Skip to content
Snippets Groups Projects
Commit 9c283e8c authored by Mateusz Klimaszewski
Browse files

Fix embeddings mapping during the evaluation step.

parent 06efc043
No related branches found
No related tags found
2 merge requests: !37 (Release 1.0.4.), !36 (Release 1.0.4)
Pipeline #3640 passed
This commit is part of merge request !36. Comments created here will be created in the context of that merge request.
...@@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke ...@@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
# Remove semrel to have default conllu format. # Remove semrel to have default conllu format.
if not keep_semrel: if not keep_semrel:
del token_dict["semrel"] del token_dict["semrel"]
del token_dict["embeddings"]
tokens.append(token_dict) tokens.append(token_dict)
# Range tokens must be tuple not list, this is conllu library requirement # Range tokens must be tuple not list, this is conllu library requirement
for t in tokens: for t in tokens:
...@@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList, ...@@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
if sentence_embedding is None: if sentence_embedding is None:
sentence_embedding = [] sentence_embedding = []
tokens = [] tokens = []
for idx, token in enumerate(conllu_sentence.tokens): for token in conllu_sentence.tokens:
tokens.append( tokens.append(
Token( Token(
**token, embeddings=embeddings[idx] **token, embeddings=embeddings[token["id"]]
) )
) )
return Sentence( return Sentence(
......
...@@ -141,7 +141,7 @@ class COMBO(predictor.Predictor): ...@@ -141,7 +141,7 @@ class COMBO(predictor.Predictor):
tree = instance.fields["metadata"]["input"] tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"] field_names = instance.fields["metadata"]["field_names"]
tree_tokens = [t for t in tree if isinstance(t["id"], int)] tree_tokens = [t for t in tree if isinstance(t["id"], int)]
embeddings = [{} for _ in range(len(tree_tokens))] embeddings = {t["id"]: {} for t in tree}
for field_name in field_names: for field_name in field_names:
if field_name not in predictions: if field_name not in predictions:
continue continue
...@@ -150,7 +150,7 @@ class COMBO(predictor.Predictor): ...@@ -150,7 +150,7 @@ class COMBO(predictor.Predictor):
if field_name in {"xpostag", "upostag", "semrel", "deprel"}: if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels") value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
token[field_name] = value token[field_name] = value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "head": elif field_name == "head":
token[field_name] = int(field_predictions[idx]) token[field_name] = int(field_predictions[idx])
elif field_name == "deps": elif field_name == "deps":
...@@ -176,7 +176,7 @@ class COMBO(predictor.Predictor): ...@@ -176,7 +176,7 @@ class COMBO(predictor.Predictor):
field_value = "|".join(np.array(features)[arg_indices].tolist()) field_value = "|".join(np.array(features)[arg_indices].tolist())
token[field_name] = field_value token[field_name] = field_value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "lemma": elif field_name == "lemma":
prediction = field_predictions[idx] prediction = field_predictions[idx]
word_chars = [] word_chars = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment