From 0b7636f8c0cff345c9def408d51915e219afa796 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 8 Jun 2020 15:44:42 +0200
Subject: [PATCH] Handle multi word tokens during prediction.

---
 combo/data/dataset.py | 9 +--------
 combo/predict.py      | 3 ++-
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 56f5b9a..f4d387a 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -74,13 +74,6 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         for conllu_file in file_path:
             with open(conllu_file, "r") as file:
                 for annotation in conllu.parse_incr(file, fields=self.fields, field_parsers=self.field_parsers):
-                    # CoNLLU annotations sometimes add back in words that have been elided
-                    # in the original sentence; we remove these, as we're just predicting
-                    # dependencies for the original sentence.
-                    # We filter by integers here as elided words have a non-integer word id,
-                    # as parsed by the conllu python library.
-                    annotation = conllu.TokenList([x for x in annotation if isinstance(x["id"], int)],
-                                                  metadata=annotation.metadata)
                     yield self.text_to_instance(annotation)
 
     @overrides
@@ -91,7 +84,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                          tag_=t.get("xpostag"),
                          lemma_=t.get("lemma"),
                          feats_=t.get("feats"))
-                  for t in tree]
+                  for t in tree if isinstance(t["id"], int)]
 
         # features
         text_field = allen_fields.TextField(tokens, self._token_indexers)
diff --git a/combo/predict.py b/combo/predict.py
index edcb817..261bd6c 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -100,7 +100,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def _predictions_as_tree(self, predictions, instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
-        for idx, token in enumerate(tree):
+        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        for idx, token in enumerate(tree_tokens):
             for field_name in field_names:
                 if field_name in predictions:
                     if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
-- 
GitLab