Skip to content
Snippets Groups Projects
Commit a2137989, authored by Mateusz Klimaszewski, committed by Mateusz Klimaszewski
Browse files

Fix deps prediction for MWE expressions.

parent 422d12c6
Branches
Tags
2 merge requests: !9 "Enhanced dependency parsing develop to master", !8 "Enhanced dependency parsing"
......@@ -12,7 +12,7 @@ from overrides import overrides
from combo import data
from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
from combo.utils import download
from combo.utils import download, graph
logger = logging.getLogger(__name__)
......@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]:
import combo.utils.graph as graph
tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel"]),
tree=tree,
root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=np.array(predictions["enhanced_deprel"]),
tree_tokens=tree_tokens,
root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
return tree, predictions["sentence_embedding"]
......
"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
from typing import List
import numpy as np
from conllu import TokenList
def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
# adding ROOT
tree_tokens = tree.tokens
tree_heads = [0] + [t["head"] for t in tree_tokens]
graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
root_label)
......@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
tree_tokens[i - 1]["deps"] = deps
tree_tokens[i - 1]["deprel"] = deprel
return tree
return
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
......
......@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:yes", "1:yes"]
# when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
......@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
# when
tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
......@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
# when
tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
actual_deps = [t["deps"] for t in tree.tokens]
# then
self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
self.assertEqual(actual_deps, expected_deps)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment