diff --git a/combo/predict.py b/combo/predict.py index 6d657f6ae38d2716fccdedf489f2a30de4f78de7..42e8bed149595e06bee8979610e95fc38686ffbf 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -12,7 +12,7 @@ from overrides import overrides from combo import data from combo.data import sentence2conllu, tokens2conllu, conllu2sentence -from combo.utils import download +from combo.utils import download, graph logger = logging.getLogger(__name__) @@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor): raise NotImplementedError(f"Unknown field name {field_name}!") if "enhanced_head" in predictions and predictions["enhanced_head"]: - import combo.utils.graph as graph - tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), - rel_scores=np.array(predictions["enhanced_deprel"]), - tree=tree, - root_label="ROOT", - vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) + graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), + rel_scores=np.array(predictions["enhanced_deprel"]), + tree_tokens=tree_tokens, + root_label="ROOT", + vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) return tree, predictions["sentence_embedding"] diff --git a/combo/utils/graph.py b/combo/utils/graph.py index fd59e97b38ebaf247f74c801ac09609c7dee6582..814341b5fb473d1ec33dc93ca8b4cef5cff33b3d 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -1,11 +1,11 @@ """Based on https://github.com/emorynlp/iwpt-shared-task-2020.""" +from typing import List + import numpy as np -from conllu import TokenList -def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None): +def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None: # adding ROOT - tree_tokens = tree.tokens tree_heads = [0] + [t["head"] for t in tree_tokens] graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_label) @@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) tree_tokens[i - 1]["deps"] = deps tree_tokens[i - 1]["deprel"] = deprel - return tree + return def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 4a65182000d3595c71531d8835ff346fac41e451..4c5c7d33fbd2c11736f902892c72e83065866466 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase): expected_deps = ["2:ROOT", "3:yes", "1:yes"] # when - tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label) + graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label) actual_deps = [t["deps"] for t in tree.tokens] # then @@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase): expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"] # when - tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label) + graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label) actual_deps = [t["deps"] for t in tree.tokens] # then @@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase): expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"] # when - tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label) + graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label) actual_deps = [t["deps"] for t in tree.tokens] # then - self.assertEqual(actual_deps, expected_deps) \ No newline at end of file + self.assertEqual(actual_deps, expected_deps)