Skip to content
Snippets Groups Projects
Commit e295e563 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Exclude self loops.

parent fd024dcc
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
......@@ -18,10 +18,12 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab
index = heads.index(head)
deprel = tree_tokens[i - 1]["deprel"]
deprel = deprel.split('>')[-1]
# TODO is this necessary?
if len(heads) >= 2:
heads.pop(index)
rels.pop(index)
# TODO - Consider if there should be a condition,
# It doesn't seem to make any sense as DEPS should contain DEPREL
# (although sometimes with different/more detailed label)
# if len(heads) >= 2:
# heads.pop(index)
# rels.pop(index)
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel
......@@ -32,6 +34,8 @@ def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads
if len(arc_scores) != tree_heads:
arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
# Self-loops aren't allowed, mask with 0. This is an in-place operation.
np.fill_diagonal(arc_scores, 0)
parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False # set heads to False
rel_scores[:, :, root_idx] = -float('inf')
......
......@@ -67,7 +67,7 @@ class GraphTest(unittest.TestCase):
])
graph_labels = np.zeros((4, 4, 3))
graph_labels[3][1][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
# when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
......@@ -75,3 +75,32 @@ class GraphTest(unittest.TestCase):
# then
self.assertEqual(actual_deps, expected_deps)
def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
# given
tree = conllu.TokenList(
tokens=[
{"head": 0, "deprel": "root", "form": "word1"},
{"head": 1, "deprel": "tree_label", "form": "word2"},
{"head": 2, "deprel": "tree_label", "form": "word3"},
]
)
vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
arc_scores = np.array([
[0, 0, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1],
])
graph_labels = np.zeros((4, 4, 3))
graph_labels[3][3][2] = 10e10
expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
# TODO current actual, adds self-loop
# actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
# when
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
actual_deps = [t["deps"] for t in tree.tokens]
# then
self.assertEqual(expected_deps, actual_deps)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment