Commit a3774ab6 authored by Mateusz Klimaszewski

Tree and graph merging algorithm.

parent 3123cced
......@@ -201,11 +201,22 @@ class COMBO(predictor.Predictor):
h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"])
r = np.concatenate((r[-1:], r[:-1]))
graph.sdp_to_dag_deps(arc_scores=h,
rel_scores=r,
tree_tokens=tree_tokens,
root_idx=self.vocab.get_token_index("root", "deprel_labels"),
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
graph.graph_and_tree_merge(
tree_arc_scores=predictions["head"],
tree_rel_scores=predictions["deprel"],
graph_arc_scores=h,
graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
tokens=tree_tokens
)
# graph.sdp_to_dag_deps(arc_scores=h,
# rel_scores=r,
# tree_tokens=tree_tokens,
# root_idx=self.vocab.get_token_index("root", "deprel_labels"),
# vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
empty_tokens = graph.restore_collapse_edges(tree_tokens)
tree.tokens.extend(empty_tokens)
......
......@@ -3,6 +3,82 @@ from typing import List
import numpy as np
_ACL_REL_CL = "acl:relcl"


def graph_and_tree_merge(tree_arc_scores,
                         tree_rel_scores,
                         graph_arc_scores,
                         graph_rel_scores,
                         label2idx,
                         idx2label,
                         tokens):
    """Merge the predicted tree with the predicted graph and store the result in ``tokens``.

    Every tree edge is kept.  A graph edge is added only when the head cannot
    already be reached from the dependent, i.e. when it would not close a
    cycle.  Tree edges labelled ``acl:relcl`` are added last, without a cycle
    check.  Index 0 denotes the artificial root; token ``k`` (1-based)
    corresponds to ``tokens[k - 1]``.

    Args:
        tree_arc_scores: Sequence with the tree head index of each token.
        tree_rel_scores: Sequence with the tree label index of each token.
        graph_arc_scores: 2-D array indexed ``[dependent, head]``; every
            non-zero entry is treated as a candidate graph edge.
        graph_rel_scores: 3-D array indexed ``[dependent, head, label]``.
        label2idx: Mapping from label to label index; must contain ``"root"``.
        idx2label: Mapping from label index to label.
        tokens: Sequence of mutable mappings; each receives a ``"deps"`` entry
            formatted as ``head:label|head:label``, sorted by head and label.

    Returns:
        None.  ``tokens`` is modified in place; neither score array is.

    Raises:
        AssertionError: If the number of edges leaving the root is not one.
    """
    # Work on copies so that the caller's score arrays stay untouched.
    graph_arc_scores = np.copy(graph_arc_scores)
    graph_rel_scores = np.copy(graph_rel_scores)
    # Exclude self-loops (in-place on the copy).
    np.fill_diagonal(graph_arc_scores, 0)
    # Attachment to the root is decided by the tree alone ...
    graph_arc_scores[:, 0] = False
    # ... so the graph must never predict the root label either.
    graph_rel_scores[:, :, label2idx["root"]] = -float('inf')
    graph_rel_pred = graph_rel_scores.argmax(-1)

    # `list(...)` keeps this a concatenation even when a NumPy array is given
    # (`[0] + ndarray` would broadcast an addition instead of prepending).
    tree_heads = [0] + list(tree_arc_scores)
    size = len(tree_heads)
    tree_edges = [(dep, tree_heads[dep], idx2label[tree_rel_scores[dep - 1]])
                  for dep in range(1, size)]

    # graph[h] lists the dependents of h; labeled_graph[h] adds their labels.
    graph = [[] for _ in range(size)]
    labeled_graph = [[] for _ in range(size)]

    # Add tree edges to the graph; 'acl:relcl' edges are postponed.
    for dep, head, label in tree_edges:
        if label != _ACL_REL_CL:
            graph[head].append(dep)
            labeled_graph[head].append((dep, label))

    # Add graph edges which aren't creating a cycle.
    no_path = object()
    for dep, head in np.argwhere(graph_arc_scores):
        if not dep or not head or dep in graph[head]:
            continue
        # `_dfs` is defined elsewhere in this module; presumably it yields the
        # paths leading from `dep` to `head` (verify against its definition).
        # Any such path means that the edge head -> dep would close a cycle.
        if next(_dfs(graph, dep, head), no_path) is not no_path:
            continue
        # There is no path from dep to head.
        label = idx2label[graph_rel_pred[dep][head]]
        if label != _ACL_REL_CL:
            graph[head].append(dep)
            labeled_graph[head].append((dep, label))

    # Add 'acl:relcl' tree edges without checking for cycles.
    for dep, head, label in tree_edges:
        if label == _ACL_REL_CL:
            graph[head].append(dep)
            labeled_graph[head].append((dep, label))

    assert len(labeled_graph[0]) == 1

    # Invert the adjacency lists: for every token collect its (head, label) pairs.
    parse_graph = [[] for _ in range(size)]
    for head, dependents in enumerate(labeled_graph):
        for dep, label in dependents:
            parse_graph[dep].append((head, label))

    # Index 0 is the artificial root and has no token, so start at 1; this also
    # avoids the wrap-around write to tokens[-1] that index 0 would cause.
    for dep in range(1, size):
        tokens[dep - 1]["deps"] = '|'.join(
            f'{h}:{r}' for h, r in sorted(parse_graph[dep]))
    return
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
# adding ROOT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment