Skip to content
Snippets Groups Projects

Release 1.0.4

Merged Mateusz Klimaszewski requested to merge candidate_release_1.0.4 into develop
Viewing commit a3774ab6
Show latest version
2 files
+ 92
5
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 76
0
@@ -3,6 +3,82 @@ from typing import List
import numpy as np
_ACL_REL_CL = "acl:relcl"


def graph_and_tree_merge(tree_arc_scores,
                         tree_rel_scores,
                         graph_arc_scores,
                         graph_rel_scores,
                         label2idx,
                         idx2label,
                         tokens) -> None:
    """Merge a predicted dependency tree with predicted graph edges into ``deps``.

    Node 0 is the artificial root; token ``k`` (0-based in ``tokens``) is
    node ``k + 1``.  The tree edges are always kept.  A graph edge is added
    only if it is not already present and ``_dfs`` finds no path from its
    dependent to its head (i.e. it would not close a cycle).  Edges labelled
    ``acl:relcl`` are taken from the tree only and are added without any
    cycle check.

    Args:
        tree_arc_scores: Sequence holding, for each token, the node index of
            its head in the tree (0 means root).
        tree_rel_scores: Sequence holding, for each token, the index of its
            tree relation label (looked up in ``idx2label``).
        graph_arc_scores: Square 2-D array; a non-zero entry ``[d, h]`` is a
            candidate edge from head ``h`` to dependent ``d``.  Not modified.
        graph_rel_scores: 3-D array indexed ``[d, h, label_idx]``; the best
            non-root label per ``(d, h)`` is taken with ``argmax``.  Not
            modified.
        label2idx: Mapping from label string to index; must contain "root".
        idx2label: Lookup from label index to label string.
        tokens: Sequence of mutable mappings, one per token.  Each token's
            "deps" entry is set to its ``head:label`` pairs, sorted and joined
            with ``|``.

    Raises:
        AssertionError: If the tree does not attach exactly one token to the
            root.
    """
    # Work on copies so the caller's score arrays are left untouched.
    graph_arc_scores = np.copy(graph_arc_scores)
    graph_rel_scores = np.copy(graph_rel_scores)
    # Exclude self-loops.
    np.fill_diagonal(graph_arc_scores, 0)
    # Connection to root will be handled by the tree.
    graph_arc_scores[:, 0] = False
    # The same with labels: the graph must never pick "root".
    graph_rel_scores[:, :, label2idx["root"]] = -float('inf')
    graph_rel_pred = graph_rel_scores.argmax(-1)

    node_count = len(tree_arc_scores) + 1
    # graph[h] lists the dependents of h; labeled_graph[h] adds the labels.
    graph = [[] for _ in range(node_count)]
    labeled_graph = [[] for _ in range(node_count)]

    tree_edges = [
        (dep, head, idx2label[tree_rel_scores[dep - 1]])
        for dep, head in enumerate(tree_arc_scores, start=1)
    ]

    # Add tree edges to the graph ('acl:relcl' ones are deferred, see below).
    for dep, head, label in tree_edges:
        if label != _ACL_REL_CL:
            graph[head].append(dep)
            labeled_graph[head].append((dep, label))

    # Add graph edges which aren't creating a cycle.
    for dep, head in np.argwhere(graph_arc_scores):
        if not dep or not head or dep in graph[head]:
            continue
        try:
            # `_dfs` is defined elsewhere in this module; it is expected to
            # yield at least once when `head` is reachable from `dep`.
            next(_dfs(graph, dep, head))
        except StopIteration:
            # There is no path from dep to head, so the edge is safe to add.
            label = idx2label[graph_rel_pred[dep][head]]
            if label != _ACL_REL_CL:
                graph[head].append(dep)
                labeled_graph[head].append((dep, label))

    # Add 'acl:relcl' tree edges last, without checking for cycles.
    for dep, head, label in tree_edges:
        if label == _ACL_REL_CL:
            labeled_graph[head].append((dep, label))

    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if len(labeled_graph[0]) != 1:
        raise AssertionError(
            "expected exactly one token attached to the root, "
            f"found {len(labeled_graph[0])}"
        )

    # Invert head -> dependents into dependent -> heads.
    incoming = [[] for _ in range(node_count)]
    for head, dependents in enumerate(labeled_graph):
        for dep, label in dependents:
            incoming[dep].append((head, label))

    # Node 0 is the artificial root and has no token, so start at node 1.
    for token_idx, edges in enumerate(incoming[1:]):
        tokens[token_idx]["deps"] = '|'.join(
            f'{head}:{label}' for head, label in sorted(edges)
        )
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
# adding ROOT