Skip to content
Snippets Groups Projects
Commit 9c283e8c authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix embeddings mapping during the evaluation step.

parent 06efc043
No related branches found
No related tags found
2 merge requests!37Release 1.0.4.,!36Release 1.0.4
Pipeline #3640 passed
......@@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
# Remove semrel to have default conllu format.
if not keep_semrel:
del token_dict["semrel"]
del token_dict["embeddings"]
tokens.append(token_dict)
# Range tokens must be tuple not list, this is conllu library requirement
for t in tokens:
......@@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
if sentence_embedding is None:
sentence_embedding = []
tokens = []
for idx, token in enumerate(conllu_sentence.tokens):
for token in conllu_sentence.tokens:
tokens.append(
Token(
**token, embeddings=embeddings[idx]
**token, embeddings=embeddings[token["id"]]
)
)
return Sentence(
......
......@@ -141,7 +141,7 @@ class COMBO(predictor.Predictor):
tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"]
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
embeddings = [{} for _ in range(len(tree_tokens))]
embeddings = {t["id"]: {} for t in tree}
for field_name in field_names:
if field_name not in predictions:
continue
......@@ -150,7 +150,7 @@ class COMBO(predictor.Predictor):
if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
token[field_name] = value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "head":
token[field_name] = int(field_predictions[idx])
elif field_name == "deps":
......@@ -176,7 +176,7 @@ class COMBO(predictor.Predictor):
field_value = "|".join(np.array(features)[arg_indices].tolist())
token[field_name] = field_value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "lemma":
prediction = field_predictions[idx]
word_chars = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment