From 9c283e8c4affb8e74f929e9e510abfa4d6a7d456 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 7 Apr 2021 07:46:33 +0200 Subject: [PATCH] Fix embeddings mapping during the evaluation step. --- combo/data/api.py | 5 +++-- combo/predict.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index bfec5ee..308e9e4 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke # Remove semrel to have default conllu format. if not keep_semrel: del token_dict["semrel"] + del token_dict["embeddings"] tokens.append(token_dict) # Range tokens must be tuple not list, this is conllu library requirement for t in tokens: @@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList, if sentence_embedding is None: sentence_embedding = [] tokens = [] - for idx, token in enumerate(conllu_sentence.tokens): + for token in conllu_sentence.tokens: tokens.append( Token( - **token, embeddings=embeddings[idx] + **token, embeddings=embeddings[token["id"]] ) ) return Sentence( diff --git a/combo/predict.py b/combo/predict.py index f580c01..5945e6d 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -141,7 +141,7 @@ class COMBO(predictor.Predictor): tree = instance.fields["metadata"]["input"] field_names = instance.fields["metadata"]["field_names"] tree_tokens = [t for t in tree if isinstance(t["id"], int)] - embeddings = [{} for _ in range(len(tree_tokens))] + embeddings = {t["id"]: {} for t in tree} for field_name in field_names: if field_name not in predictions: continue @@ -150,7 +150,7 @@ class COMBO(predictor.Predictor): if field_name in {"xpostag", "upostag", "semrel", "deprel"}: value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels") token[field_name] = value - embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] + embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "head": token[field_name] = int(field_predictions[idx]) elif field_name == "deps": @@ -176,7 +176,7 @@ class COMBO(predictor.Predictor): field_value = "|".join(np.array(features)[arg_indices].tolist()) token[field_name] = field_value - embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] + embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "lemma": prediction = field_predictions[idx] word_chars = [] -- GitLab