Skip to content
Snippets Groups Projects
Commit 422d12c6 authored by Mateusz Klimaszewski, committed by Mateusz Klimaszewski
Browse files

Working graph decoding.

parent 6d373257
Branches
Tags
2 merge requests: !9 Enhanced dependency parsing develop to master, !8 Enhanced dependency parsing
...@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
field_parsers = parser.DEFAULT_FIELD_PARSERS field_parsers = parser.DEFAULT_FIELD_PARSERS
# Do not make it nullable # Do not make it nullable
field_parsers.pop("xpostag", None) field_parsers.pop("xpostag", None)
# Ignore parsing misc
field_parsers.pop("misc", None)
if self.use_sem: if self.use_sem:
fields = list(fields) fields = list(fields)
fields.append("semrel") fields.append("semrel")
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from typing import List, Union, Tuple from typing import List, Union, Tuple
import conllu import conllu
import numpy as np
from allennlp import data as allen_data, common, models from allennlp import data as allen_data, common, models
from allennlp.common import util from allennlp.common import util
from allennlp.data import tokenizers from allennlp.data import tokenizers
...@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
import combo.utils.graph as graph import combo.utils.graph as graph
tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"], tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=predictions["enhanced_deprel"], rel_scores=np.array(predictions["enhanced_deprel"]),
tree=tree, tree=tree,
root_label="ROOT") root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
return tree, predictions["sentence_embedding"] return tree, predictions["sentence_embedding"]
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from conllu import TokenList from conllu import TokenList
def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
# adding ROOT # adding ROOT
tree_tokens = tree.tokens tree_tokens = tree.tokens
tree_heads = [0] + [t["head"] for t in tree_tokens] tree_heads = [0] + [t["head"] for t in tree_tokens]
...@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): ...@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
for i, (t, g) in enumerate(zip(tree_heads, graph)): for i, (t, g) in enumerate(zip(tree_heads, graph)):
if not i: if not i:
continue continue
rels = [x[1] for x in g] rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
heads = [x[0] for x in g] heads = [x[0] for x in g]
head = tree_tokens[i - 1]["head"] head = tree_tokens[i - 1]["head"]
index = heads.index(head) index = heads.index(head)
...@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): ...@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
if len(arc_scores) != tree_heads: if len(arc_scores) != tree_heads:
arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)] arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)] rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
parse_preds = arc_scores > 0 parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False # set heads to False parse_preds[:, 0] = False # set heads to False
# rel_labels[:, :, root_idx] = -float('inf') # rel_labels[:, :, root_idx] = -float('inf')
return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds) return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
...@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre ...@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
if not isinstance(tree_heads, np.ndarray): if not isinstance(tree_heads, np.ndarray):
tree_heads = np.array(tree_heads) tree_heads = np.array(tree_heads)
dh = np.argwhere(parse_preds) dh = np.argwhere(parse_preds)
sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True) sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
graph = [[] for _ in range(len(tree_heads))] graph = [[] for _ in range(len(tree_heads))]
for d, h in enumerate(tree_heads): for d, h in enumerate(tree_heads):
if d: if d:
...@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre ...@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
num_root = 0 num_root = 0
for h in range(len(tree_heads)): for h in range(len(tree_heads)):
for d in graph[h]: for d in graph[h]:
rel = rel_labels[d, h] rel = rel_labels[d][h]
if h == 0: if h == 0:
rel = root_label rel = root_label
assert num_root == 0 assert num_root == 0
......
...@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
use_sem: if in_targets("semrel") then true else false, use_sem: if in_targets("semrel") then true else false,
token_indexers: { token_indexers: {
token: if use_transformer then { token: if use_transformer then {
type: "pretrained_transformer_mismatched", type: "pretrained_transformer_mismatched_fixed",
model_name: pretrained_transformer_name, model_name: pretrained_transformer_name,
tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
then {use_fast: false} else {},
} else { } else {
# SingleIdTokenIndexer, token as single int # SingleIdTokenIndexer, token as single int
type: "single_id", type: "single_id",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment