From 44c44d1f382492dbf1d96cb9a17109830b9b1b21 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Mon, 15 Feb 2021 16:02:09 +0100
Subject: [PATCH] Long sequences fix

---
 poldeepner2/models.py           | 5 ++---
 poldeepner2/utils/data_utils.py | 2 +-
 sample_tokenized.py             | 4 ++--
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/poldeepner2/models.py b/poldeepner2/models.py
index 1f261c6..cf6480d 100644
--- a/poldeepner2/models.py
+++ b/poldeepner2/models.py
@@ -131,13 +131,12 @@ class PolDeepNer2:
         if batch[1] < self.max_seq_length:
             return self.process_sequence(batch[0], label_map)
         else:
-            tokens = list(map(lambda x: x.text_a, batch[0]))
+            tokens = list(map(lambda x: x.text_a, batch[0]))[0].split(' ')
             sequences = split_subwords_into_sequences(tokens, self.max_seq_length)
             y_preds = []
-            examples = []
             for i, sequence in enumerate(sequences):
                 label = ["O"] * len(sequence)
-                examples.append(InputExample(guid=i, text_a=sequence, text_b=None, label=label))
+                examples = [InputExample(guid=i, text_a=' '.join(sequence), text_b=None, label=label)]
                 y_preds.append(self.process_sequence(examples, label_map))
             return merge_overlapping_sequences(y_preds, self.max_seq_length)
 
diff --git a/poldeepner2/utils/data_utils.py b/poldeepner2/utils/data_utils.py
index bdbf9a1..c2a102c 100644
--- a/poldeepner2/utils/data_utils.py
+++ b/poldeepner2/utils/data_utils.py
@@ -604,7 +604,7 @@ def split_subwords_into_sequences_train(token_ids, labels, valid, label_mask, ma
 
 def split_subwords_into_sequences(tokens, max_seq_length):
     sequences = []
-    for i in range (0, round(len(tokens)/max_seq_length)*2):
+    for i in range(0, round(len(tokens)/max_seq_length)*2):
         start = int((i*max_seq_length)/2)
         current_slice = tokens[start:start + max_seq_length]
         if len(current_slice) < max_seq_length:
diff --git a/sample_tokenized.py b/sample_tokenized.py
index 38b5e02..b913c80 100644
--- a/sample_tokenized.py
+++ b/sample_tokenized.py
@@ -1,7 +1,7 @@
 import poldeepner2.models
 from poldeepner2.utils.data_utils import wrap_annotations
 
-ner = poldeepner2.models.load("kpwr-n82-base", device="cuda:0", resources_path="/tmp")
+ner = poldeepner2.models.load("kpwr-n82-base", device="cpu", resources_path="/tmp")
 
 sentences = ["Marek Nowak z Politechniki Wrocławskiej mieszka przy ul . Sądeckiej .",
              "#PoselAdamNowak Co Pan myśli na temat fuzji Orlenu i Lotosu ?"]
@@ -10,6 +10,7 @@
 for n in [64, 128]:
     for sentence in sentences:
         tokens = ["."] * n
+        #tokens = []
         tokens.extend(sentence.split(" "))
 
         print("-"*20)
@@ -17,7 +18,6 @@
 
         labels = ner.process_tokenized([tokens])
         names = wrap_annotations(labels)
-
         for name in names:
             name_range = "%d:%d" % (min(name.token_ids), max(name.token_ids))
             text = " ".join([tokens[idx] for idx in name.token_ids])
-- 
GitLab
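
Reviewer note: the sketch below is a minimal, self-contained illustration of the 50%-overlap windowing scheme this patch repairs. The names split_into_overlapping_windows and merge_overlapping_windows are hypothetical stand-ins for poldeepner2's split_subwords_into_sequences and merge_overlapping_sequences; only the loop arithmetic mirrors the diff above, while the handling of the short tail window and the merge policy are assumptions, since those bodies are not visible in this patch.

def split_into_overlapping_windows(tokens, max_seq_length):
    """Slice tokens into windows of max_seq_length that advance by half a window."""
    sequences = []
    for i in range(0, round(len(tokens) / max_seq_length) * 2):
        start = int(i * max_seq_length / 2)  # consecutive windows overlap by half
        current_slice = tokens[start:start + max_seq_length]
        if current_slice:  # assumption: a shorter tail window is kept as-is
            sequences.append(current_slice)
    return sequences

def merge_overlapping_windows(window_outputs, max_seq_length):
    """Stitch per-window outputs back into one flat sequence.

    Assumption: keep the first window whole, then append only the second
    half of each later window, i.e. the part not already covered by the
    preceding window.
    """
    step = max_seq_length // 2
    merged = list(window_outputs[0])
    for out in window_outputs[1:]:
        merged.extend(out[step:])
    return merged

tokens = ["tok%d" % i for i in range(10)]
windows = split_into_overlapping_windows(tokens, max_seq_length=4)
print(windows)  # [[tok0..tok3], [tok2..tok5], [tok4..tok7], [tok6..tok9]]
print(merge_overlapping_windows(windows, max_seq_length=4) == tokens)  # True

The round trip above is why models.py can run process_sequence on each window independently: every token lands in the first half of exactly one later window or in the initial window, so merging the halves reconstructs one prediction per token.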