From 23e0c9ce45e63dbcf031ba8965f85590f331a92d Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 23 Dec 2020 09:14:09 +0100
Subject: [PATCH] Sort deps when uncollapsing nodes; mask out the root label
 when a token's head is not the root.

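A rough sketch of the two changes (shapes, values and names such as ROOT_IDX
below are illustrative placeholders, not the model's actual attributes):

At inference time, a token whose predicted head is the artificial root (0) is
forced to take the "root" relation label, and every other token is prevented
from taking it:

    import torch
    ROOT_IDX = 3                                 # index of "root" in the label vocab (placeholder)
    logits = torch.randn(2, 6, 10)               # (batch, tokens, labels), dummy values
    heads = torch.randint(0, 7, (2, 6))          # predicted heads, dummy values
    root_head = heads == 0                       # tokens attached to the artificial root
    root_col = logits[..., ROOT_IDX]
    logits[..., ROOT_IDX] = (root_col.masked_fill(root_head, 10e10)
                                     .masked_fill(~root_head, -10e10))
    labels = logits.argmax(-1)                   # "root" label exactly where head == 0

When collapsed edges are restored, the "deps" column is additionally sorted
numerically by head id; float keys keep empty-node ids such as "3.1" in order:

    deps = ["5:conj", "2:obj", "3.1:orphan"]
    deps = sorted((d.split(":", 1) for d in deps), key=lambda kv: float(kv[0]))
    "|".join(f"{h}:{rel}" for h, rel in deps)    # -> "2:obj|3.1:orphan|5:conj"
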
---
 combo/models/parser.py | 25 +++++++++++++++++++++++--
 combo/utils/graph.py   |  3 ++-
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/combo/models/parser.py b/combo/models/parser.py
index 486b248..4b5b126 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -1,4 +1,5 @@
 """Dependency parsing models."""
+import math
 from typing import Tuple, Dict, Optional, Union, List
 
 import numpy as np
@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 root_idx: int,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
                  dependency_projection_layer: base.Linear,
                  relation_prediction_layer: base.Linear):
         super().__init__()
+        self.root_idx = root_idx
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        device = x.device
         if mask is not None:
             mask = mask[:, 1:]
         relations_labels, head_labels = None, None
@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
 
-        output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        if self.training:
+            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        else:
+            # Force the root label when the predicted head is 0 (root) and mask it out otherwise.
+            relation_prediction_output = relation_prediction[:, 1:]
+            mask = (head_output["prediction"] == 0)
+            vocab_size = relation_prediction_output.size(-1)
+            root_idx = torch.tensor([self.root_idx], device=device)
+            relation_prediction_output[mask] = (relation_prediction_output
+                                                .masked_select(mask.unsqueeze(-1))
+                                                .reshape(-1, vocab_size)
+                                                .index_fill(-1, root_idx, 10e10))
+            relation_prediction_output[~mask] = (relation_prediction_output
+                                                 .masked_select(~(mask.unsqueeze(-1)))
+                                                 .reshape(-1, vocab_size)
+                                                 .index_fill(-1, root_idx, -10e10))
+            output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
             head_predictor=head_predictor,
             head_projection_layer=head_projection_layer,
             dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
+            relation_prediction_layer=relation_prediction_layer,
+            root_idx=vocab.get_token_index("root", vocab_namespace)
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 32e7dd9..651c14a 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
                         "deps": f"{head}:{empty_node_relation}"
                     }
                 )
-        token["deps"] = "|".join(deps)
+        deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
+        token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
     return empty_tokens
-- 
GitLab