Skip to content
Snippets Groups Projects
Commit c38ef44a authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Tree and graph merging algorithm.

parent 3365d6ad
Branches
No related merge requests found
...@@ -198,11 +198,22 @@ class COMBO(predictor.Predictor): ...@@ -198,11 +198,22 @@ class COMBO(predictor.Predictor):
h = np.concatenate((h[-1:], h[:-1])) h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"]) r = np.array(predictions["enhanced_deprel_prob"])
r = np.concatenate((r[-1:], r[:-1])) r = np.concatenate((r[-1:], r[:-1]))
graph.sdp_to_dag_deps(arc_scores=h,
rel_scores=r, graph.graph_and_tree_merge(
tree_tokens=tree_tokens, tree_arc_scores=predictions["head"],
root_idx=self.vocab.get_token_index("root", "deprel_labels"), tree_rel_scores=predictions["deprel"],
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) graph_arc_scores=h,
graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
tokens=tree_tokens
)
# graph.sdp_to_dag_deps(arc_scores=h,
# rel_scores=r,
# tree_tokens=tree_tokens,
# root_idx=self.vocab.get_token_index("root", "deprel_labels"),
# vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
empty_tokens = graph.restore_collapse_edges(tree_tokens) empty_tokens = graph.restore_collapse_edges(tree_tokens)
tree.tokens.extend(empty_tokens) tree.tokens.extend(empty_tokens)
......
...@@ -3,6 +3,82 @@ from typing import List ...@@ -3,6 +3,82 @@ from typing import List
import numpy as np import numpy as np
_ACL_REL_CL = "acl:relcl"
def graph_and_tree_merge(tree_arc_scores,
tree_rel_scores,
graph_arc_scores,
graph_rel_scores,
label2idx,
idx2label,
tokens):
graph_arc_scores = np.copy(graph_arc_scores)
# Exclude self-loops, in-place operation.
np.fill_diagonal(graph_arc_scores, 0)
# Connection to root will be handled by tree.
graph_arc_scores[:, 0] = False
# The same with labels.
root_idx = label2idx["root"]
graph_rel_scores[:, :, root_idx] = -float('inf')
graph_rel_pred = graph_rel_scores.argmax(-1)
# Add tree edges to graph
tree_heads = [0] + tree_arc_scores
graph = [[] for _ in range(len(tree_heads))]
labeled_graph = [[] for _ in range(len(tree_heads))]
for d, h in enumerate(tree_heads):
if not d:
continue
label = idx2label[tree_rel_scores[d - 1]]
if label != _ACL_REL_CL:
graph[h].append(d)
labeled_graph[h].append((d, label))
# Debug only
# Extract graph edges
graph_edges = np.argwhere(graph_arc_scores)
# Add graph edges which aren't creating a cycle
for (d, h) in graph_edges:
if not d or not h or d in graph[h]:
continue
try:
path = next(_dfs(graph, d, h))
except StopIteration:
# There is not path from d to h
label = idx2label[graph_rel_pred[d][h]]
if label != _ACL_REL_CL:
graph[h].append(d)
labeled_graph[h].append((d, label))
# Add 'acl:relcl' without checking for cycles.
for d, h in enumerate(tree_heads):
if not d:
continue
label = idx2label[tree_rel_scores[d - 1]]
if label == _ACL_REL_CL:
graph[h].append(d)
labeled_graph[h].append((d, label))
assert len(labeled_graph[0]) == 1
d = graph[0][0]
graph[d].append(0)
labeled_graph[d].append((0, "root"))
parse_graph = [[] for _ in range(len(tree_heads))]
for h in range(len(tree_heads)):
for d, label in labeled_graph[h]:
parse_graph[d].append((h, label))
parse_graph[d] = sorted(parse_graph[d])
for i, g in enumerate(parse_graph):
heads = [x[0] for x in g]
rels = [x[1] for x in g]
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tokens[i - 1]["deps"] = deps
return
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None: def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
# adding ROOT # adding ROOT
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment