Skip to content
Snippets Groups Projects
Commit e295e563 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Exclude self loops.

parent fd024dcc
No related branches found
No related tags found
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
......@@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab
index = heads.index(head)
deprel = tree_tokens[i - 1]["deprel"]
deprel = deprel.split('>')[-1]
# TODO is this necessary?
if len(heads) >= 2:
heads.pop(index)
rels.pop(index)
# TODO - Consider if there should be a condition,
# It doesn't seem to make any sense as DEPS should contain DEPREL
# (although sometimes with different/more detailed label)
# if len(heads) >= 2:
# heads.pop(index)
# rels.pop(index)
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel
......@@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads
if len(arc_scores) != tree_heads:
arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
# Self-loops aren't allowed, mask with 0. This is an in-place operation.
np.fill_diagonal(arc_scores, 0)
parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False # set heads to False
rel_scores[:, :, root_idx] = -float('inf')
......
......@@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase):
])
graph_labels = np.zeros((4, 4, 3))
graph_labels[3][1][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
# when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
......@@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase):
# then
self.assertEqual(actual_deps, expected_deps)
def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
# given
tree = conllu.TokenList(
tokens=[
{"head": 0, "deprel": "root", "form": "word1"},
{"head": 1, "deprel": "tree_label", "form": "word2"},
{"head": 2, "deprel": "tree_label", "form": "word3"},
]
)
vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
arc_scores = np.array([
[0, 0, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1],
])
graph_labels = np.zeros((4, 4, 3))
graph_labels[3][3][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
# TODO current actual, adds self-loop
# actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
# when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
actual_deps = [t["deps"] for t in tree.tokens]
# then
self.assertEqual(expected_deps, actual_deps)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment