# Provenance: body of the patch "Tree and graph merging algorithm."
# (commit c38ef44af1ae83f4c43d04d6f2709afcf86910d2), file combo/utils/graph.py.
from typing import Dict, List, MutableMapping, Sequence, Set, Tuple

import numpy as np

_ACL_REL_CL = "acl:relcl"


def _has_path(children: List[Set[int]], start: int, goal: int) -> bool:
    """Return True when ``goal`` can be reached from ``start`` via ``children``.

    ``children[node]`` is the set of nodes that ``node`` points to.  The search
    is an iterative depth-first walk with a visited set, so every node is
    expanded at most once.
    """
    pending = [start]
    seen = {start}
    while pending:
        node = pending.pop()
        for child in children[node]:
            if child == goal:
                return True
            if child not in seen:
                seen.add(child)
                pending.append(child)
    return False


def graph_and_tree_merge(tree_arc_scores: Sequence[int],
                         tree_rel_scores: Sequence[int],
                         graph_arc_scores,
                         graph_rel_scores,
                         label2idx: Dict[str, int],
                         idx2label: Dict[int, str],
                         tokens: Sequence[MutableMapping]) -> None:
    """Merge a dependency tree with graph edges and write ``deps`` on tokens.

    Node 0 is the artificial ROOT; node ``i`` (``i >= 1``) corresponds to
    ``tokens[i - 1]``.  Edges are merged in three passes:

    1. every tree edge whose label is not ``acl:relcl``;
    2. every graph edge (in row-major order of ``graph_arc_scores``) that is
       not already present, does not involve ROOT, is not labelled
       ``acl:relcl`` and would not close a cycle;
    3. the tree's ``acl:relcl`` edges, added without any cycle check.

    Each token then receives ``token["deps"]``: its ``head:label`` pairs
    sorted by head (then label) and joined with ``|``.

    Args:
        tree_arc_scores: head index for each token (0 means ROOT).  Any
            sequence is accepted, including a NumPy array.
        tree_rel_scores: label index for each token, looked up in
            ``idx2label``.
        graph_arc_scores: square array indexed ``[dependent, head]``.  Every
            non-zero entry is treated as a candidate edge.
            NOTE(review): this assumes the caller has already binarised or
            thresholded the scores -- confirm against the predictor.
        graph_rel_scores: array indexed ``[dependent, head, label]``; the
            best label per pair is taken with ``argmax`` after the ``root``
            label has been masked out.
        label2idx: label to index mapping; must contain ``"root"``.
        idx2label: index to label mapping.
        tokens: mutable mappings, one per token, updated in place.

    Returns:
        None.  The result is delivered through ``tokens[i]["deps"]``.

    Raises:
        ValueError: if the tree does not attach exactly one token to ROOT.
    """
    # Work on private copies so neither input array is modified.
    arc_mask = np.array(graph_arc_scores)
    np.fill_diagonal(arc_mask, 0)  # no self-loops
    arc_mask[:, 0] = 0  # attachment to ROOT comes from the tree only
    rel_scores = np.array(graph_rel_scores, dtype=float)
    rel_scores[:, :, label2idx["root"]] = -np.inf  # graph never predicts root
    rel_pred = rel_scores.argmax(-1)

    # list(...) keeps "+" a concatenation even when a NumPy array is passed.
    heads = [0] + [int(h) for h in list(tree_arc_scores)]
    n_nodes = len(heads)
    tree_edges = [
        (dep, heads[dep], idx2label[tree_rel_scores[dep - 1]])
        for dep in range(1, n_nodes)
    ]

    root_count = sum(1 for _, head, _ in tree_edges if head == 0)
    if root_count != 1:
        raise ValueError(
            f"expected exactly one token attached to ROOT, found {root_count}"
        )

    # children[h]: dependents of h (cycle checks); incoming[d]: (head, label).
    children: List[Set[int]] = [set() for _ in range(n_nodes)]
    incoming: List[List[Tuple[int, str]]] = [[] for _ in range(n_nodes)]

    def add_edge(head: int, dep: int, label: str) -> None:
        children[head].add(dep)
        incoming[dep].append((head, label))

    for dep, head, label in tree_edges:
        if label != _ACL_REL_CL:
            add_edge(head, dep, label)

    for dep, head in np.argwhere(arc_mask):
        dep, head = int(dep), int(head)
        if dep == 0 or head == 0 or dep in children[head]:
            continue
        # An edge head -> dep closes a cycle iff dep already reaches head.
        if _has_path(children, dep, head):
            continue
        label = idx2label[int(rel_pred[dep, head])]
        if label != _ACL_REL_CL:
            add_edge(head, dep, label)

    # 'acl:relcl' edges are allowed to form cycles, hence added last.
    for dep, head, label in tree_edges:
        if label == _ACL_REL_CL:
            add_edge(head, dep, label)

    # Start at 1: node 0 is ROOT and has no token of its own.
    for dep in range(1, n_nodes):
        tokens[dep - 1]["deps"] = "|".join(
            f"{head}:{label}" for head, label in sorted(incoming[dep])
        )