diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py index a31e6d052dc8a7969bb07bf92a5eb779b7aa24d4..6799d4994222802af3cd1c6b869bd8a760523ffe 100644 --- a/combo/models/graph_parser.py +++ b/combo/models/graph_parser.py @@ -148,6 +148,7 @@ class GraphDependencyRelationModel(base.Predictor): output = head_output output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"]) + output["rel_probability"] = relation_prediction if labels is not None and labels[0] is not None: if sample_weights is None: diff --git a/combo/models/model.py b/combo/models/model.py index 124f49a19145a32647b1698637115e238ee7bdd9..710f72cf8dae932fc8f2b5c92abbaf7639a52ec2 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -126,6 +126,7 @@ class SemanticMultitaskModel(allen_models.Model): "deprel": relations_pred, "enhanced_head": enhanced_head_pred, "enhanced_deprel": enhanced_relations_pred, + "enhanced_deprel_prob": enhanced_parser_output["rel_probability"], "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0], } diff --git a/combo/predict.py b/combo/predict.py index e262b7061e8ac9dcc1f970bd456d3a3923055ebd..070975f3ed3c7e3e7a75be91d311683bf91ff5f4 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -198,9 +198,9 @@ class SemanticMultitaskPredictor(predictor.Predictor): if "enhanced_head" in predictions and predictions["enhanced_head"]: graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), - rel_scores=np.array(predictions["enhanced_deprel"]), + rel_scores=np.array(predictions["enhanced_deprel_prob"]), tree_tokens=tree_tokens, - root_label="ROOT", + root_idx=self.vocab.get_token_index("root", "deprel_labels"), vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) return tree, predictions["sentence_embedding"] diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 814341b5fb473d1ec33dc93ca8b4cef5cff33b3d..8a55cb9ae37a5c8985ece5dd2b11758859ce751e 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -4,15 +4,15 @@ from typing import List import numpy as np -def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None: +def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None: # adding ROOT tree_heads = [0] + [t["head"] for t in tree_tokens] graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, - root_label) + root_idx) for i, (t, g) in enumerate(zip(tree_heads, graph)): if not i: continue - rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g] + rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g] heads = [x[0] for x in g] head = tree_tokens[i - 1]["head"] index = heads.index(head) @@ -28,22 +28,23 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab return -def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): +def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx): if len(arc_scores) != tree_heads: arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)] - rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)] + rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)] parse_preds = np.array(arc_scores) > 0 parse_preds[:, 0] = False # set heads to False - # rel_labels[:, :, root_idx] = -float('inf') - return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds) + rel_scores[:, :, root_idx] = -float('inf') + return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds) -def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds): +def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds): if not isinstance(tree_heads, np.ndarray): tree_heads = np.array(tree_heads) dh = np.argwhere(parse_preds) sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True) graph = [[] for _ in range(len(tree_heads))] + rel_pred = np.argmax(rel_scores, axis=-1) for d, h in enumerate(tree_heads): if d: graph[h].append(d) @@ -59,9 +60,9 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre num_root = 0 for h in range(len(tree_heads)): for d in graph[h]: - rel = rel_labels[d][h] + rel = rel_pred[d][h] if h == 0: - rel = root_label + rel = root_idx assert num_root == 0 num_root += 1 parse_graph[d].append((h, rel)) diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 4c5c7d33fbd2c11736f902892c72e83065866466..0a662122019f2f4ada380a685e207856d84a3646 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -10,48 +10,40 @@ class GraphTest(unittest.TestCase): def test_adding_empty_graph_with_the_same_labels(self): tree = conllu.TokenList( tokens=[ - {"head": 2, "deprel": "ROOT", "form": "word1"}, + {"head": 0, "deprel": "root", "form": "word1"}, {"head": 3, "deprel": "yes", "form": "word2"}, {"head": 1, "deprel": "yes", "form": "word3"}, ] ) + vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"} empty_graph = np.zeros((4, 4)) - graph_labels = np.array([ - ["no", "no", "no", "no"], - ["no", "no", "ROOT", "no"], - ["no", "no", "no", "yes"], - ["no", "yes", "no", "no"], - ]) - root_label = "ROOT" - expected_deps = ["2:ROOT", "3:yes", "1:yes"] + graph_labels = np.zeros((4, 4, 4)) + expected_deps = ["0:root", "3:yes", "1:yes"] # when - graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label) + graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) actual_deps = [t["deps"] for t in tree.tokens] # then - self.assertEqual(actual_deps, expected_deps) + self.assertEqual(expected_deps, actual_deps) def test_adding_empty_graph_with_different_labels(self): tree = conllu.TokenList( tokens=[ - {"head": 2, "deprel": "ROOT", "form": "word1"}, + {"head": 0, "deprel": "root", "form": "word1"}, {"head": 3, "deprel": "tree_label", "form": "word2"}, {"head": 1, "deprel": "tree_label", "form": "word3"}, ] ) + vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} empty_graph = np.zeros((4, 4)) - graph_labels = np.array([ - ["no", "no", "no", "no"], - ["no", "no", "ROOT", "no"], - ["no", "no", "no", "graph_label"], - ["no", "graph_label", "no", "no"], - ]) - root_label = "ROOT" - expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"] + graph_labels = np.zeros((4, 4, 3)) + graph_labels[2][3][2] = 10e10 + graph_labels[3][1][2] = 10e10 + expected_deps = ["0:root", "3:graph_label", "1:graph_label"] # when - graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label) + graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) actual_deps = [t["deps"] for t in tree.tokens] # then @@ -61,28 +53,24 @@ class GraphTest(unittest.TestCase): # given tree = conllu.TokenList( tokens=[ - {"head": 0, "deprel": "ROOT", "form": "word1"}, + {"head": 0, "deprel": "root", "form": "word1"}, {"head": 1, "deprel": "tree_label", "form": "word2"}, {"head": 2, "deprel": "tree_label", "form": "word3"}, ] ) + vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} arc_scores = np.array([ [0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0], ]) - graph_labels = np.array([ - ["no", "no", "no", "no"], - ["ROOT", "no", "no", "no"], - ["no", "tree_label", "no", "no"], - ["no", "graph_label", "tree_label", "no"], - ]) - root_label = "ROOT" - expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"] + graph_labels = np.zeros((4, 4, 3)) + graph_labels[3][1][2] = 10e10 + expected_deps = ["0:root", "1:tree_label", "1:graph_label"] # when - graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label) + graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) actual_deps = [t["deps"] for t in tree.tokens] # then