From 8e09e3360bbcc24face4b076e78cc47b96691cdb Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 11 Dec 2020 13:12:05 +0100
Subject: [PATCH] Add restoring collapsed edges (gapping).

---
 combo/predict.py     |  2 ++
 combo/utils/graph.py | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/combo/predict.py b/combo/predict.py
index e52b42e..c58db25 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -207,6 +207,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                                   tree_tokens=tree_tokens,
                                   root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            empty_tokens = graph.restore_collapse_edges(tree_tokens)
+            tree.tokens.extend(empty_tokens)
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 1785b4b..32e7dd9 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -85,3 +85,30 @@ def _dfs(graph, start, end):
             if next_state in path:
                 continue
             fringe.append((next_state, path + [next_state]))
+
+
+def restore_collapse_edges(tree_tokens):
+    empty_tokens = []
+    for token in tree_tokens:
+        deps = token["deps"].split("|")
+        for i, d in enumerate(deps):
+            if ">" in d:
+                # {head}:{empty_node_relation}>{current_node_relation}
+                # should map to
+                # For new, empty node:
+                # {head}:{empty_node_relation}
+                # For current node:
+                # {new_empty_node_id}:{current_node_relation}
+                # TODO consider where to put new_empty_node_id (currently at the end)
+                head, relation = d.split(':', 1)
+                ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}"
+                empty_node_relation, current_node_relation = relation.split(">", 1)
+                deps[i] = f"{ehead}:{current_node_relation}"
+                empty_tokens.append(
+                    {
+                        "id": ehead,
+                        "deps": f"{head}:{empty_node_relation}"
+                    }
+                )
+        token["deps"] = "|".join(deps)
+    return empty_tokens
-- 
GitLab