From e295e563cd154f48a9b1f180b91355c2e6426497 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 7 Dec 2020 16:00:09 +0100 Subject: [PATCH] Exclude self loops. --- combo/utils/graph.py | 12 ++++++++---- tests/utils/test_graph.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 8a55cb9..1785b4b 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab index = heads.index(head) deprel = tree_tokens[i - 1]["deprel"] deprel = deprel.split('>')[-1] - # TODO is this necessary? - if len(heads) >= 2: - heads.pop(index) - rels.pop(index) + # TODO - Consider if there should be a condition, + # It doesn't seem to make any sense as DEPS should contain DEPREL + # (although sometimes with different/more detailed label) + # if len(heads) >= 2: + # heads.pop(index) + # rels.pop(index) deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) tree_tokens[i - 1]["deps"] = deps tree_tokens[i - 1]["deprel"] = deprel @@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads if len(arc_scores) != tree_heads: arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)] rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)] + # Self-loops aren't allowed, mask with 0. This is an in-place operation. + np.fill_diagonal(arc_scores, 0) parse_preds = np.array(arc_scores) > 0 parse_preds[:, 0] = False # set heads to False rel_scores[:, :, root_idx] = -float('inf') diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 0a66212..74e3744 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase): ]) graph_labels = np.zeros((4, 4, 3)) graph_labels[3][1][2] = 10e10 - expected_deps = ["0:root", "1:tree_label", "1:graph_label"] + expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"] # when graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) @@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase): # then self.assertEqual(actual_deps, expected_deps) + + def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self): + # given + tree = conllu.TokenList( + tokens=[ + {"head": 0, "deprel": "root", "form": "word1"}, + {"head": 1, "deprel": "tree_label", "form": "word2"}, + {"head": 2, "deprel": "tree_label", "form": "word3"}, + ] + ) + vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} + arc_scores = np.array([ + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1], + ]) + graph_labels = np.zeros((4, 4, 3)) + graph_labels[3][3][2] = 10e10 + expected_deps = ["0:root", "1:tree_label", "2:tree_label"] + # TODO current actual, adds self-loop + # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"] + + # when + graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) + actual_deps = [t["deps"] for t in tree.tokens] + + # then + self.assertEqual(expected_deps, actual_deps) -- GitLab