From c38ef44af1ae83f4c43d04d6f2709afcf86910d2 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 30 Apr 2021 07:48:20 +0200
Subject: [PATCH] Tree and graph merging algorithm.

---
 combo/predict.py     | 21 +++++++++---
 combo/utils/graph.py | 76 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 5 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index e528a18..c299faa 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -198,11 +198,22 @@ class COMBO(predictor.Predictor):
             h = np.concatenate((h[-1:], h[:-1]))
             r = np.array(predictions["enhanced_deprel_prob"])
             r = np.concatenate((r[-1:], r[:-1]))
-            graph.sdp_to_dag_deps(arc_scores=h,
-                                  rel_scores=r,
-                                  tree_tokens=tree_tokens,
-                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
-                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+
+            graph.graph_and_tree_merge(
+                tree_arc_scores=predictions["head"],
+                tree_rel_scores=predictions["deprel"],
+                graph_arc_scores=h,
+                graph_rel_scores=r,
+                idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
+                label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
+                tokens=tree_tokens
+            )
+
+            # graph.sdp_to_dag_deps(arc_scores=h,
+            #                       rel_scores=r,
+            #                       tree_tokens=tree_tokens,
+            #                       root_idx=self.vocab.get_token_index("root", "deprel_labels"),
+            #                       vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
             empty_tokens = graph.restore_collapse_edges(tree_tokens)
             tree.tokens.extend(empty_tokens)
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 3352625..c9ad07e 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -3,6 +3,82 @@ from typing import List
 
 import numpy as np
 
+_ACL_REL_CL = "acl:relcl"
+
+
+def graph_and_tree_merge(tree_arc_scores,
+                         tree_rel_scores,
+                         graph_arc_scores,
+                         graph_rel_scores,
+                         label2idx,
+                         idx2label,
+                         tokens):
+    graph_arc_scores = np.copy(graph_arc_scores)
+    # Exclude self-loops, in-place operation.
+    np.fill_diagonal(graph_arc_scores, 0)
+    # Connection to root will be handled by tree.
+    graph_arc_scores[:, 0] = False
+    # The same with labels.
+    root_idx = label2idx["root"]
+    graph_rel_scores[:, :, root_idx] = -float('inf')
+    graph_rel_pred = graph_rel_scores.argmax(-1)
+
+    # Add tree edges to graph
+    tree_heads = [0] + tree_arc_scores
+    graph = [[] for _ in range(len(tree_heads))]
+    labeled_graph = [[] for _ in range(len(tree_heads))]
+    for d, h in enumerate(tree_heads):
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        if label != _ACL_REL_CL:
+            graph[h].append(d)
+            labeled_graph[h].append((d, label))
+
+    # Debug only
+    # Extract graph edges
+    graph_edges = np.argwhere(graph_arc_scores)
+
+    # Add graph edges which aren't creating a cycle
+    for (d, h) in graph_edges:
+        if not d or not h or d in graph[h]:
+            continue
+        try:
+            path = next(_dfs(graph, d, h))
+        except StopIteration:
+            # There is not path from d to h
+            label = idx2label[graph_rel_pred[d][h]]
+            if label != _ACL_REL_CL:
+                graph[h].append(d)
+                labeled_graph[h].append((d, label))
+
+    # Add 'acl:relcl' without checking for cycles.
+    for d, h in enumerate(tree_heads):
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        if label == _ACL_REL_CL:
+            graph[h].append(d)
+            labeled_graph[h].append((d, label))
+
+    assert len(labeled_graph[0]) == 1
+    d = graph[0][0]
+    graph[d].append(0)
+    labeled_graph[d].append((0, "root"))
+
+    parse_graph = [[] for _ in range(len(tree_heads))]
+    for h in range(len(tree_heads)):
+        for d, label in labeled_graph[h]:
+            parse_graph[d].append((h, label))
+        parse_graph[d] = sorted(parse_graph[d])
+
+    for i, g in enumerate(parse_graph):
+        heads = [x[0] for x in g]
+        rels = [x[1] for x in g]
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tokens[i - 1]["deps"] = deps
+    return
+
 
 def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
     # adding ROOT
-- 
GitLab