From 9c283e8c4affb8e74f929e9e510abfa4d6a7d456 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 7 Apr 2021 07:46:33 +0200
Subject: [PATCH] Fix embeddings mapping during the evaluation step.

---
 combo/data/api.py | 5 +++--
 combo/predict.py  | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index bfec5ee..308e9e4 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
         # Remove semrel to have default conllu format.
         if not keep_semrel:
             del token_dict["semrel"]
+        del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
@@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
     if sentence_embedding is None:
         sentence_embedding = []
     tokens = []
-    for idx, token in enumerate(conllu_sentence.tokens):
+    for token in conllu_sentence.tokens:
         tokens.append(
             Token(
-                **token, embeddings=embeddings[idx]
+                **token, embeddings=embeddings[token["id"]]
             )
         )
     return Sentence(
diff --git a/combo/predict.py b/combo/predict.py
index f580c01..5945e6d 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -141,7 +141,7 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        embeddings = [{} for _ in range(len(tree_tokens))]
+        embeddings = {t["id"]: {} for t in tree}
         for field_name in field_names:
             if field_name not in predictions:
                 continue
@@ -150,7 +150,7 @@ class COMBO(predictor.Predictor):
                 if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
                     value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
                     token[field_name] = value
-                    embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
                 elif field_name == "head":
                     token[field_name] = int(field_predictions[idx])
                 elif field_name == "deps":
@@ -176,7 +176,7 @@ class COMBO(predictor.Predictor):
                         field_value = "|".join(np.array(features)[arg_indices].tolist())
 
                     token[field_name] = field_value
-                    embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx]
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
                 elif field_name == "lemma":
                     prediction = field_predictions[idx]
                     word_chars = []
-- 
GitLab