From a2137989f0a28537f75c193bc2e3915b20ff8bf2 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 23 Nov 2020 16:12:47 +0100
Subject: [PATCH] Fix deps prediction for MWE expressions.

---
 combo/predict.py          | 13 ++++++-------
 combo/utils/graph.py      |  8 ++++----
 tests/utils/test_graph.py |  8 ++++----
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index 6d657f6..42e8bed 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -12,7 +12,7 @@ from overrides import overrides
 
 from combo import data
 from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
-from combo.utils import download
+from combo.utils import download, graph
 
 logger = logging.getLogger(__name__)
 
@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            import combo.utils.graph as graph
-            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                         rel_scores=np.array(predictions["enhanced_deprel"]),
-                                         tree=tree,
-                                         root_label="ROOT",
-                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  tree_tokens=tree_tokens,
+                                  root_label="ROOT",
+                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index fd59e97..814341b 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -1,11 +1,11 @@
 """Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+from typing import List
+
 import numpy as np
-from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
     # adding ROOT
-    tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
                                                       root_label)
@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
         deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
         tree_tokens[i - 1]["deps"] = deps
         tree_tokens[i - 1]["deprel"] = deprel
-    return tree
+    return
 
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 4a65182..4c5c7d3 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:yes", "1:yes"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
+        self.assertEqual(actual_deps, expected_deps)
-- 
GitLab