From 826e57a756c8ba96b26c06406c03ac57807b6e19 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 10:19:12 +0100
Subject: [PATCH] Hotfix off by one in enhanced graphs.

---
 combo/models/graph_parser.py | 6 +++---
 combo/predict.py             | 9 +++++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 6799d49..2dc02dc 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
               heads_true: torch.Tensor,
               mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> torch.Tensor:
-
-        true = true[true.long() > 0]
-        pred = pred[heads_true.long() == 1]
+        correct_heads_mask = heads_true.long() == 1
+        true = true[correct_heads_mask]
+        pred = pred[correct_heads_mask]
         loss = F.cross_entropy(pred, true.long())
         return loss.sum() / pred.size(0)
 
diff --git a/combo/predict.py b/combo/predict.py
index 070975f..e52b42e 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
+            # TODO off-by-one hotfix, refactor
+            h = np.array(predictions["enhanced_head"])
+            h = np.concatenate((h[-1:], h[:-1]))
+            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.concatenate((r[-1:], r[:-1]))
+            graph.sdp_to_dag_deps(arc_scores=h,
+                                  rel_scores=r,
                                   tree_tokens=tree_tokens,
                                   root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
-- 
GitLab