Skip to content
Snippets Groups Projects
Commit a2137989 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Fix deps prediction for MWE expressions.

parent 422d12c6
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
...@@ -12,7 +12,7 @@ from overrides import overrides ...@@ -12,7 +12,7 @@ from overrides import overrides
from combo import data from combo import data
from combo.data import sentence2conllu, tokens2conllu, conllu2sentence from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
from combo.utils import download from combo.utils import download, graph
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!") raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
import combo.utils.graph as graph graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), rel_scores=np.array(predictions["enhanced_deprel"]),
rel_scores=np.array(predictions["enhanced_deprel"]), tree_tokens=tree_tokens,
tree=tree, root_label="ROOT",
root_label="ROOT", vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
return tree, predictions["sentence_embedding"] return tree, predictions["sentence_embedding"]
......
"""Based on https://github.com/emorynlp/iwpt-shared-task-2020.""" """Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
from typing import List
import numpy as np import numpy as np
from conllu import TokenList
def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None): def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
# adding ROOT # adding ROOT
tree_tokens = tree.tokens
tree_heads = [0] + [t["head"] for t in tree_tokens] tree_heads = [0] + [t["head"] for t in tree_tokens]
graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
root_label) root_label)
...@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i ...@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel tree_tokens[i - 1]["deprel"] = deprel
return tree return
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
......
...@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase): ...@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:yes", "1:yes"] expected_deps = ["2:ROOT", "3:yes", "1:yes"]
# when # when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label) graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens] actual_deps = [t["deps"] for t in tree.tokens]
# then # then
...@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase): ...@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"] expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
# when # when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label) graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens] actual_deps = [t["deps"] for t in tree.tokens]
# then # then
...@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase): ...@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"] expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
# when # when
tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label) graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens] actual_deps = [t["deps"] for t in tree.tokens]
# then # then
self.assertEqual(actual_deps, expected_deps) self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment