Skip to content
Snippets Groups Projects
Commit a2137989 authored by Mateusz Klimaszewski, committed by Mateusz Klimaszewski
Browse files

Fix deps prediction for MWE expressions.

parent 422d12c6
2 merge requests: !9 "Enhanced dependency parsing develop to master", !8 "Enhanced dependency parsing"
......@@ -12,7 +12,7 @@ from overrides import overrides
from combo import data
from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
from combo.utils import download
from combo.utils import download, graph
logger = logging.getLogger(__name__)
......@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]:
import combo.utils.graph as graph
tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel"]),
tree=tree,
root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel"]),
tree_tokens=tree_tokens,
root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
return tree, predictions["sentence_embedding"]
......
"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
from typing import List
import numpy as np
from conllu import TokenList
def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
# adding ROOT
tree_tokens = tree.tokens
tree_heads = [0] + [t["head"] for t in tree_tokens]
graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
root_label)
......@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel
return tree
return
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
......
......@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:yes", "1:yes"]
# when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
......@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
# when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
......@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
# when
tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
self.assertEqual(actual_deps, expected_deps)
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment