Skip to content
Snippets Groups Projects
Commit 44c44d1f authored by Jarema Radom's avatar Jarema Radom
Browse files

Long sequences fix

parent fb0a1335
No related tags found
1 merge request: !11 "Resolve: Handle long sequences"
Pipeline #2504 passed
...@@ -131,13 +131,12 @@ class PolDeepNer2: ...@@ -131,13 +131,12 @@ class PolDeepNer2:
if batch[1] < self.max_seq_length: if batch[1] < self.max_seq_length:
return self.process_sequence(batch[0], label_map) return self.process_sequence(batch[0], label_map)
else: else:
tokens = list(map(lambda x: x.text_a, batch[0])) tokens = list(map(lambda x: x.text_a, batch[0]))[0].split(' ')
sequences = split_subwords_into_sequences(tokens, self.max_seq_length) sequences = split_subwords_into_sequences(tokens, self.max_seq_length)
y_preds = [] y_preds = []
examples = []
for i, sequence in enumerate(sequences): for i, sequence in enumerate(sequences):
label = ["O"] * len(sequence) label = ["O"] * len(sequence)
examples.append(InputExample(guid=i, text_a=sequence, text_b=None, label=label)) examples = [InputExample(guid=i, text_a=' '.join(sequence), text_b=None, label=label)]
y_preds.append(self.process_sequence(examples, label_map)) y_preds.append(self.process_sequence(examples, label_map))
return merge_overlapping_sequences(y_preds, self.max_seq_length) return merge_overlapping_sequences(y_preds, self.max_seq_length)
......
import poldeepner2.models import poldeepner2.models
from poldeepner2.utils.data_utils import wrap_annotations from poldeepner2.utils.data_utils import wrap_annotations
ner = poldeepner2.models.load("kpwr-n82-base", device="cuda:0", resources_path="/tmp") ner = poldeepner2.models.load("kpwr-n82-base", device="cpu", resources_path="/tmp")
sentences = ["Marek Nowak z Politechniki Wrocławskiej mieszka przy ul . Sądeckiej .", sentences = ["Marek Nowak z Politechniki Wrocławskiej mieszka przy ul . Sądeckiej .",
"#PoselAdamNowak Co Pan myśli na temat fuzji Orlenu i Lotosu ?"] "#PoselAdamNowak Co Pan myśli na temat fuzji Orlenu i Lotosu ?"]
...@@ -10,6 +10,7 @@ for n in [64, 128]: ...@@ -10,6 +10,7 @@ for n in [64, 128]:
for sentence in sentences: for sentence in sentences:
tokens = ["."] * n tokens = ["."] * n
#tokens = []
tokens.extend(sentence.split(" ")) tokens.extend(sentence.split(" "))
print("-"*20) print("-"*20)
...@@ -17,7 +18,6 @@ for n in [64, 128]: ...@@ -17,7 +18,6 @@ for n in [64, 128]:
labels = ner.process_tokenized([tokens]) labels = ner.process_tokenized([tokens])
names = wrap_annotations(labels) names = wrap_annotations(labels)
for name in names: for name in names:
name_range = "%d:%d" % (min(name.token_ids), max(name.token_ids)) name_range = "%d:%d" % (min(name.token_ids), max(name.token_ids))
text = " ".join([tokens[idx] for idx in name.token_ids]) text = " ".join([tokens[idx] for idx in name.token_ids])
......
0% — Loading, or an error occurred while loading the content.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment