Commit 9c283e8c authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski

Fix embeddings mapping during the evaluation step.

parent 06efc043
Pipeline #3640 passed with stage
in 3 minutes and 11 seconds
......@@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
# Remove semrel to have default conllu format.
if not keep_semrel:
del token_dict["semrel"]
del token_dict["embeddings"]
tokens.append(token_dict)
# Range tokens must be tuple not list, this is conllu library requirement
for t in tokens:
......@@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
if sentence_embedding is None:
sentence_embedding = []
tokens = []
for idx, token in enumerate(conllu_sentence.tokens):
for token in conllu_sentence.tokens:
tokens.append(
Token(
**token, embeddings=embeddings[idx]
**token, embeddings=embeddings[token["id"]]
)
)
return Sentence(
......
......@@ -141,7 +141,7 @@ class COMBO(predictor.Predictor):
tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"]
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
embeddings = [{} for _ in range(len(tree_tokens))]
embeddings = {t["id"]: {} for t in tree}
for field_name in field_names:
if field_name not in predictions:
continue
......@@ -150,7 +150,7 @@ class COMBO(predictor.Predictor):
if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
token[field_name] = value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "head":
token[field_name] = int(field_predictions[idx])
elif field_name == "deps":
......@@ -176,7 +176,7 @@ class COMBO(predictor.Predictor):
field_value = "|".join(np.array(features)[arg_indices].tolist())
token[field_name] = field_value
embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
elif field_name == "lemma":
prediction = field_predictions[idx]
word_chars = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment