Commit 0f6faf2a authored by Mateusz Klimaszewski

Remove emorynlp merging algorithm.

parent 94220582
......@@ -214,11 +214,6 @@ class COMBO(predictor.Predictor):
tokens=tree_tokens
)
# graph.sdp_to_dag_deps(arc_scores=h,
# rel_scores=r,
# tree_tokens=tree_tokens,
# root_idx=self.vocab.get_token_index("root", "deprel_labels"),
# vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
empty_tokens = graph.restore_collapse_edges(tree_tokens)
tree.tokens.extend(empty_tokens)
......
"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
from typing import List
import numpy as np
......@@ -82,76 +81,6 @@ def graph_and_tree_merge(tree_arc_scores,
return
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
    """Write the enhanced ``deps`` field of every tree token, in place.

    The tree heads (with an artificial ROOT entry prepended) are merged with
    the graph scores by ``adjust_root_score_then_add_secondary_arcs``.  Each
    token then gets ``deps`` set to its ``head:rel`` pairs joined with ``|``
    and ``deprel`` reduced to the part after the last ``>``.

    Args:
        arc_scores: arc score matrix handed to the merging step.
        rel_scores: relation score tensor handed to the merging step.
        tree_tokens: tokens carrying ``"head"`` and ``"deprel"`` keys; they are
            modified in place.
        root_idx: label index used for the root relation.
        vocab_index: optional mapping from label index to label string.  When
            it is missing or empty the raw label indices are written instead;
            indices absent from the mapping fall back to ``"root"``.
    """
    # Position 0 stands for the artificial ROOT, so token k sits at index k - 1.
    tree_heads = [0] + [t["head"] for t in tree_tokens]
    graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
                                                      root_idx)
    # graph[0] belongs to ROOT and has no token, hence the [1:].
    for token, arcs in zip(tree_tokens, graph[1:]):
        heads = [arc[0] for arc in arcs]
        if vocab_index:
            rels = [vocab_index.get(arc[1], "root") for arc in arcs]
        else:
            rels = [arc[1] for arc in arcs]
        # TODO - Consider whether the tree arc should be dropped from DEPS when
        #  a token has two or more heads. It doesn't seem to make sense, as
        #  DEPS should contain DEPREL (although sometimes with a
        #  different/more detailed label), so it is kept for now.
        token["deps"] = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
        token["deprel"] = token["deprel"].split('>')[-1]
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
    """Mask invalid arcs/labels, then merge the graph arcs into the tree.

    Args:
        arc_scores: 2-D numpy array of arc scores; its diagonal is zeroed
            in place.
        rel_scores: 3-D numpy array of relation scores; the ``root_idx`` label
            is set to ``-inf`` in place.
        tree_heads: head index of every position, ROOT entry included.
        root_idx: label index of the root relation.

    Returns:
        Whatever ``add_secondary_arcs`` returns for the masked inputs.
    """
    size = len(tree_heads)
    # Score matrices may be larger than the sentence; cut BOTH leading axes.
    # (Chained ``[:size][:size]`` would slice the first axis twice and leave
    # the columns untouched.)  Basic slicing yields views, so the in-place
    # masking below still reaches the caller's arrays.
    arc_scores = arc_scores[:size, :size]
    rel_scores = rel_scores[:size, :size]
    # Self-loops aren't allowed, mask with 0. This is an in-place operation.
    np.fill_diagonal(arc_scores, 0)
    parse_preds = np.array(arc_scores) > 0
    # Never propose an arc whose head column is 0.
    parse_preds[:, 0] = False
    # The root label must never win the argmax over relation scores.
    rel_scores[:, :, root_idx] = -float('inf')
    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
    """Extend the dependency tree with predicted arcs that keep it acyclic.

    Candidate arcs are the ``True`` cells of ``parse_preds`` (row = dependent,
    column = head), visited from the highest ``arc_scores`` value down.  A
    candidate is skipped when it involves position 0 or is already present;
    otherwise it is added only if ``_dfs`` finds no path from the dependent to
    the head.

    Args:
        arc_scores: arc scores, indexed ``[dependent][head]``.
        rel_scores: relation scores; the label of an arc is the argmax over
            the last axis.
        tree_heads: head index of every position (position 0 is skipped when
            seeding the graph).
        root_idx: label assigned to the arc whose head is 0.
        parse_preds: boolean matrix marking candidate arcs.

    Returns:
        A list with one entry per position: the sorted ``(head, rel)`` pairs of
        that position.

    Raises:
        AssertionError: if more than one position has head 0.
    """
    if not isinstance(tree_heads, np.ndarray):
        tree_heads = np.array(tree_heads)
    size = len(tree_heads)

    # Highest-scoring candidates first.
    candidates = sorted(
        ((arc_scores[d][h], [d, h]) for d, h in np.argwhere(parse_preds)),
        reverse=True,
    )
    rel_pred = np.argmax(rel_scores, axis=-1)

    # children[h] lists the dependents of h; seed it with the tree arcs.
    children = [[] for _ in range(size)]
    for d, h in enumerate(tree_heads):
        if d:
            children[h].append(d)

    for _, (d, h) in candidates:
        if not d or not h or d in children[h]:
            continue
        try:
            next(_dfs(children, d, h))
        except StopIteration:
            # no path from d to h
            children[h].append(d)

    parse_graph = [[] for _ in range(size)]
    num_root = 0
    for h, dependents in enumerate(children):
        for d in dependents:
            if h == 0:
                assert num_root == 0, "more than one token is attached to ROOT"
                num_root += 1
                rel = root_idx
            else:
                rel = rel_pred[d][h]
            parse_graph[d].append((h, rel))
    # One sort per position at the end instead of re-sorting after each append.
    return [sorted(arcs) for arcs in parse_graph]
def _dfs(graph, start, end):
fringe = [(start, [])]
while fringe:
......
import unittest
import combo.utils.graph as graph
import conllu
import numpy as np
class GraphTest(unittest.TestCase):
    """Tests for merging a dependency tree with graph arcs (``graph.sdp_to_dag_deps``)."""

    @staticmethod
    def _make_tree(*arcs):
        """Build a TokenList from ``(head, deprel)`` pairs; forms are word1, word2, ..."""
        return conllu.TokenList(
            tokens=[
                {"head": head, "deprel": deprel, "form": f"word{position}"}
                for position, (head, deprel) in enumerate(arcs, start=1)
            ]
        )

    @staticmethod
    def _merged_deps(arc_scores, graph_labels, tree, vocab_index):
        """Run the merge on ``tree`` and return the ``deps`` value of every token."""
        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
        return [t["deps"] for t in tree.tokens]

    def test_adding_empty_graph_with_the_same_labels(self):
        # given
        tree = self._make_tree((0, "root"), (3, "yes"), (1, "yes"))
        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
        empty_graph = np.zeros((4, 4))
        graph_labels = np.zeros((4, 4, 4))
        expected_deps = ["0:root", "3:yes", "1:yes"]

        # when
        actual_deps = self._merged_deps(empty_graph, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_adding_empty_graph_with_different_labels(self):
        # given
        tree = self._make_tree((0, "root"), (3, "tree_label"), (1, "tree_label"))
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        empty_graph = np.zeros((4, 4))
        graph_labels = np.zeros((4, 4, 3))
        graph_labels[2][3][2] = 10e10
        graph_labels[3][1][2] = 10e10
        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]

        # when
        actual_deps = self._merged_deps(empty_graph, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_extending_tree_with_graph(self):
        # given
        tree = self._make_tree((0, "root"), (1, "tree_label"), (2, "tree_label"))
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        arc_scores = np.array([
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 1, 1, 0],
        ])
        graph_labels = np.zeros((4, 4, 3))
        graph_labels[3][1][2] = 10e10
        expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]

        # when
        actual_deps = self._merged_deps(arc_scores, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)

    def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
        # given
        tree = self._make_tree((0, "root"), (1, "tree_label"), (2, "tree_label"))
        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
        arc_scores = np.array([
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 1],
        ])
        graph_labels = np.zeros((4, 4, 3))
        graph_labels[3][3][2] = 10e10
        expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
        # TODO current actual, adds self-loop
        # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]

        # when
        actual_deps = self._merged_deps(arc_scores, graph_labels, tree, vocab_index)

        # then
        self.assertEqual(expected_deps, actual_deps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment