diff --git a/poldeepner2/models.py b/poldeepner2/models.py
index 1f261c65d92c520fd14a6a7826d0731eaf737610..cf6480de552b04000a395e44b3c900737a053732 100644
--- a/poldeepner2/models.py
+++ b/poldeepner2/models.py
@@ -131,13 +131,12 @@ class PolDeepNer2:
         if batch[1] < self.max_seq_length:
             return self.process_sequence(batch[0], label_map)
         else:
-            tokens = list(map(lambda x: x.text_a, batch[0]))
+            tokens = list(map(lambda x: x.text_a, batch[0]))[0].split(' ')
             sequences = split_subwords_into_sequences(tokens, self.max_seq_length)
             y_preds = []
-            examples = []
             for i, sequence in enumerate(sequences):
                 label = ["O"] * len(sequence)
-                examples.append(InputExample(guid=i, text_a=sequence, text_b=None, label=label))
+                examples = [InputExample(guid=i, text_a=' '.join(sequence), text_b=None, label=label)]
                 y_preds.append(self.process_sequence(examples, label_map))
             return merge_overlapping_sequences(y_preds, self.max_seq_length)
diff --git a/poldeepner2/utils/data_utils.py b/poldeepner2/utils/data_utils.py
index bdbf9a19a47f876e41de84b73f6dc3addc8b9f29..c2a102c55aadfed67aa4942f4b4dc9796e2d068e 100644
--- a/poldeepner2/utils/data_utils.py
+++ b/poldeepner2/utils/data_utils.py
@@ -604,7 +604,7 @@ def split_subwords_into_sequences_train(token_ids, labels, valid, label_mask, ma

 def split_subwords_into_sequences(tokens, max_seq_length):
     sequences = []
-    for i in range (0, round(len(tokens)/max_seq_length)*2):
+    for i in range(0, round(len(tokens)/max_seq_length)*2):
         start = int((i*max_seq_length)/2)
         current_slice = tokens[start:start + max_seq_length]
         if len(current_slice) < max_seq_length:
diff --git a/sample_tokenized.py b/sample_tokenized.py
index 38b5e023b4975b371f0c29d55f6eb758889fbdd9..b913c80f79b839bb16ea2441e3eafb7b2f1836c4 100644
--- a/sample_tokenized.py
+++ b/sample_tokenized.py
@@ -1,7 +1,7 @@
 import poldeepner2.models
 from poldeepner2.utils.data_utils import wrap_annotations

-ner = poldeepner2.models.load("kpwr-n82-base", device="cuda:0", resources_path="/tmp")
+ner = poldeepner2.models.load("kpwr-n82-base", device="cpu", resources_path="/tmp")

 sentences = ["Marek Nowak z Politechniki Wrocławskiej mieszka przy ul . Sądeckiej .",
              "#PoselAdamNowak Co Pan myśli na temat fuzji Orlenu i Lotosu ?"]
@@ -10,6 +10,7 @@
 for n in [64, 128]:
     for sentence in sentences:
         tokens = ["."] * n
+        #tokens = []
         tokens.extend(sentence.split(" "))

         print("-"*20)
@@ -17,7 +18,6 @@
         labels = ner.process_tokenized([tokens])
         names = wrap_annotations(labels)
-
         for name in names:
             name_range = "%d:%d" % (min(name.token_ids), max(name.token_ids))
             text = " ".join([tokens[idx] for idx in name.token_ids])
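
For reference, a minimal standalone sketch of the overlapping-window slicing that split_subwords_into_sequences performs on long token lists. Only the loop bounds and the start-index arithmetic are taken from the hunk above; the function name split_into_overlapping_windows and the handling of the final short slice are assumptions made for illustration, not the library's actual code.

# Standalone sketch (not part of the patch): overlapping-window slicing.
# Loop bounds and start-index arithmetic mirror split_subwords_into_sequences;
# the treatment of a short trailing window is assumed here.
def split_into_overlapping_windows(tokens, max_seq_length):
    windows = []
    for i in range(0, round(len(tokens) / max_seq_length) * 2):
        start = int((i * max_seq_length) / 2)
        window = tokens[start:start + max_seq_length]
        if window:
            windows.append(window)
    return windows

# Windows start every max_seq_length/2 tokens, so neighbours overlap by half:
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
print(split_into_overlapping_windows(list(range(10)), 4))

The half-window overlap is what lets merge_overlapping_sequences stitch the per-window predictions back into a single label sequence in models.py.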