Commit fd024dcc authored by Mateusz Klimaszewski, committed by Mateusz Klimaszewski

Pass relation probabilities in graph extraction.

parent 50ea5469
Part of 2 merge requests: !9 Enhanced dependency parsing (develop to master), !8 Enhanced dependency parsing
......@@ -148,6 +148,7 @@ class GraphDependencyRelationModel(base.Predictor):
 
         output = head_output
         output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+        output["rel_probability"] = relation_prediction
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
......
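Note on the hunk above: `relation_prediction` is kept in the output unreduced, so downstream code sees the full per-arc score distribution rather than only its argmax. A minimal sketch of the assumed tensor shapes (batch size, sentence length and label count are invented for illustration):

import torch

# Assumed shapes: 2 sentences, 5 positions each, 10 dependency relation labels.
relation_prediction = torch.randn(2, 5, 5, 10)   # per-arc scores over relation labels
head_prediction = torch.randint(0, 5, (2, 5))    # stand-in for head_output["prediction"]

output = {
    # hard decisions, as before
    "prediction": (relation_prediction.argmax(-1), head_prediction),
    # full score tensor, newly exposed so graph extraction can re-score labels later
    "rel_probability": relation_prediction,
}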
......@@ -126,6 +126,7 @@ class SemanticMultitaskModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
+            "enhanced_deprel_prob": enhanced_parser_output["rel_probability"],
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
......
......@@ -198,9 +198,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
                                   tree_tokens=tree_tokens,
-                                  root_label="ROOT",
+                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
......
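Note on the predictor hunk: `get_token_index` and `get_index_to_token_vocabulary` are standard AllenNLP `Vocabulary` methods, so the hard-coded "ROOT" string is replaced by the numeric index of the "root" label plus an index-to-label mapping. A rough stand-in using a plain dict (the label inventory below is invented):

import numpy as np

# Hypothetical "deprel_labels" namespace.
index_to_label = {0: "root", 1: "nsubj", 2: "obj"}
label_to_index = {label: idx for idx, label in index_to_label.items()}

root_idx = label_to_index["root"]                    # what get_token_index("root", "deprel_labels") yields
rel_scores = np.zeros((4, 4, len(index_to_label)))   # per-arc label scores for one sentence (ROOT + 3 tokens)

# The call then becomes, schematically:
# graph.sdp_to_dag_deps(arc_scores=..., rel_scores=rel_scores, tree_tokens=...,
#                       root_idx=root_idx, vocab_index=index_to_label)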
......@@ -4,15 +4,15 @@ from typing import List
 import numpy as np
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
     # adding ROOT
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
-                                                      root_label)
+                                                      root_idx)
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
+        rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
......@@ -28,22 +28,23 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab
     return
 
 
-def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
         arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+        rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
     parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
-    # rel_labels[:, :, root_idx] = -float('inf')
-    return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
+    rel_scores[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
 
 
-def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
+def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
     sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
+    rel_pred = np.argmax(rel_scores, axis=-1)
     for d, h in enumerate(tree_heads):
         if d:
             graph[h].append(d)
......@@ -59,9 +60,9 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d][h]
+            rel = rel_pred[d][h]
             if h == 0:
-                rel = root_label
+                rel = root_idx
                 assert num_root == 0
                 num_root += 1
             parse_graph[d].append((h, rel))
......
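Net effect of the graph.py changes: `rel_scores` is now a 3-D score tensor indexed [dependent][head][label] instead of a matrix of label strings, the root label column is masked with -inf so it can only be assigned to the arc from position 0, and label indices are recovered with a single argmax before arcs are emitted. A condensed, self-contained sketch of that extraction step (deliberately simplified relative to sdp_to_dag_deps/add_secondary_arcs):

import numpy as np

def extract_labeled_arcs(arc_scores, rel_scores, tree_heads, root_idx):
    """Simplified take on the adjusted logic: choose secondary arcs where
    arc_scores > 0, forbid the root label on them, then argmax the label axis."""
    arc_scores = np.asarray(arc_scores, dtype=float)
    rel_scores = np.asarray(rel_scores, dtype=float).copy()

    parse_preds = arc_scores > 0
    parse_preds[:, 0] = False                   # no secondary arc may attach to ROOT (column 0) as head
    rel_scores[:, :, root_idx] = -float("inf")  # reserve the root label for the tree's root arc

    rel_pred = np.argmax(rel_scores, axis=-1)   # (n, n) matrix of label indices

    graph = [[] for _ in tree_heads]
    for dep, head in enumerate(tree_heads):     # tree arcs first; the root arc keeps root_idx
        if dep:
            graph[dep].append((head, root_idx if head == 0 else rel_pred[dep][head]))
    for dep, head in np.argwhere(parse_preds):  # extra (secondary) arcs from the graph parser
        if head != tree_heads[dep]:
            graph[dep].append((head, rel_pred[dep][head]))
    return graph

# Example: ROOT + 3 tokens, 3 labels, one extra arc 3 -> 1 carrying label 2.
arc = np.zeros((4, 4)); arc[3][1] = 1.0
rel = np.zeros((4, 4, 3)); rel[3][1][2] = 1.0
print(extract_labeled_arcs(arc, rel, tree_heads=[0, 0, 1, 2], root_idx=0))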
......@@ -10,48 +10,40 @@ class GraphTest(unittest.TestCase):
     def test_adding_empty_graph_with_the_same_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "yes", "form": "word2"},
                 {"head": 1, "deprel": "yes", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "yes"],
-            ["no", "yes", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:yes", "1:yes"]
+        graph_labels = np.zeros((4, 4, 4))
+        expected_deps = ["0:root", "3:yes", "1:yes"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
+        self.assertEqual(expected_deps, actual_deps)
 
     def test_adding_empty_graph_with_different_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "tree_label", "form": "word2"},
                 {"head": 1, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "graph_label"],
-            ["no", "graph_label", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[2][3][2] = 10e10
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
......@@ -61,28 +53,24 @@ class GraphTest(unittest.TestCase):
         # given
         tree = conllu.TokenList(
             tokens=[
-                {"head": 0, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 1, "deprel": "tree_label", "form": "word2"},
                 {"head": 2, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         arc_scores = np.array([
             [0, 0, 0, 0],
             [1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 1, 1, 0],
         ])
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["ROOT", "no", "no", "no"],
-            ["no", "tree_label", "no", "no"],
-            ["no", "graph_label", "tree_label", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
......
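The rewritten tests feed 3-D label score tensors plus a numeric vocab and expect enhanced-dependency strings in head:label form ("0:root", "3:graph_label", ...). A small illustration of how such a string is put together from an index-based prediction (values mirror the second test; the helper name is made up):

vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}

def format_dep(head: int, label_idx: int) -> str:
    # CoNLL-U enhanced dependencies are written as "<head_id>:<relation>"
    return f"{head}:{vocab_index[label_idx]}"

# word2 is attached to head 3 with label index 2 ("graph_label"), as asserted in the test:
assert format_dep(3, 2) == "3:graph_label"
# the root token keeps head 0 and the "root" relation:
assert format_dep(0, 0) == "0:root"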