Skip to content
Snippets Groups Projects
Commit e295e563 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Exclude self loops.

parent fd024dcc
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
...@@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab ...@@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab
index = heads.index(head) index = heads.index(head)
deprel = tree_tokens[i - 1]["deprel"] deprel = tree_tokens[i - 1]["deprel"]
deprel = deprel.split('>')[-1] deprel = deprel.split('>')[-1]
# TODO is this necessary? # TODO - Consider if there should be a condition,
if len(heads) >= 2: # It doesn't seem to make any sense as DEPS should contain DEPREL
heads.pop(index) # (although sometimes with different/more detailed label)
rels.pop(index) # if len(heads) >= 2:
# heads.pop(index)
# rels.pop(index)
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel tree_tokens[i - 1]["deprel"] = deprel
...@@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads ...@@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads
if len(arc_scores) != tree_heads: if len(arc_scores) != tree_heads:
arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)] arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)] rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
# Self-loops aren't allowed, mask with 0. This is an in-place operation.
np.fill_diagonal(arc_scores, 0)
parse_preds = np.array(arc_scores) > 0 parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False # set heads to False parse_preds[:, 0] = False # set heads to False
rel_scores[:, :, root_idx] = -float('inf') rel_scores[:, :, root_idx] = -float('inf')
......
...@@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase): ...@@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase):
]) ])
graph_labels = np.zeros((4, 4, 3)) graph_labels = np.zeros((4, 4, 3))
graph_labels[3][1][2] = 10e10 graph_labels[3][1][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "1:graph_label"] expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
# when # when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
...@@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase): ...@@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase):
# then # then
self.assertEqual(actual_deps, expected_deps) self.assertEqual(actual_deps, expected_deps)
def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
# given
tree = conllu.TokenList(
tokens=[
{"head": 0, "deprel": "root", "form": "word1"},
{"head": 1, "deprel": "tree_label", "form": "word2"},
{"head": 2, "deprel": "tree_label", "form": "word3"},
]
)
vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
arc_scores = np.array([
[0, 0, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1],
])
graph_labels = np.zeros((4, 4, 3))
graph_labels[3][3][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
# TODO current actual, adds self-loop
# actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
# when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
actual_deps = [t["deps"] for t in tree.tokens]
# then
self.assertEqual(expected_deps, actual_deps)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment