Skip to content
Snippets Groups Projects
Commit 422d12c6 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Working graph decoding.

parent 6d373257
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
...@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): ...@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
field_parsers = parser.DEFAULT_FIELD_PARSERS field_parsers = parser.DEFAULT_FIELD_PARSERS
# Do not make it nullable # Do not make it nullable
field_parsers.pop("xpostag", None) field_parsers.pop("xpostag", None)
# Ignore parsing misc
field_parsers.pop("misc", None)
if self.use_sem: if self.use_sem:
fields = list(fields) fields = list(fields)
fields.append("semrel") fields.append("semrel")
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from typing import List, Union, Tuple from typing import List, Union, Tuple
import conllu import conllu
import numpy as np
from allennlp import data as allen_data, common, models from allennlp import data as allen_data, common, models
from allennlp.common import util from allennlp.common import util
from allennlp.data import tokenizers from allennlp.data import tokenizers
...@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
import combo.utils.graph as graph import combo.utils.graph as graph
tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"], tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
rel_scores=predictions["enhanced_deprel"], rel_scores=np.array(predictions["enhanced_deprel"]),
tree=tree, tree=tree,
root_label="ROOT") root_label="ROOT",
vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
return tree, predictions["sentence_embedding"] return tree, predictions["sentence_embedding"]
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from conllu import TokenList from conllu import TokenList
def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
# adding ROOT # adding ROOT
tree_tokens = tree.tokens tree_tokens = tree.tokens
tree_heads = [0] + [t["head"] for t in tree_tokens] tree_heads = [0] + [t["head"] for t in tree_tokens]
...@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): ...@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
for i, (t, g) in enumerate(zip(tree_heads, graph)): for i, (t, g) in enumerate(zip(tree_heads, graph)):
if not i: if not i:
continue continue
rels = [x[1] for x in g] rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
heads = [x[0] for x in g] heads = [x[0] for x in g]
head = tree_tokens[i - 1]["head"] head = tree_tokens[i - 1]["head"]
index = heads.index(head) index = heads.index(head)
...@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): ...@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
if len(arc_scores) != tree_heads: if len(arc_scores) != tree_heads:
arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)] arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)] rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
parse_preds = arc_scores > 0 parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False # set heads to False parse_preds[:, 0] = False # set heads to False
# rel_labels[:, :, root_idx] = -float('inf') # rel_labels[:, :, root_idx] = -float('inf')
return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds) return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
...@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre ...@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
if not isinstance(tree_heads, np.ndarray): if not isinstance(tree_heads, np.ndarray):
tree_heads = np.array(tree_heads) tree_heads = np.array(tree_heads)
dh = np.argwhere(parse_preds) dh = np.argwhere(parse_preds)
sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True) sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
graph = [[] for _ in range(len(tree_heads))] graph = [[] for _ in range(len(tree_heads))]
for d, h in enumerate(tree_heads): for d, h in enumerate(tree_heads):
if d: if d:
...@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre ...@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
num_root = 0 num_root = 0
for h in range(len(tree_heads)): for h in range(len(tree_heads)):
for d in graph[h]: for d in graph[h]:
rel = rel_labels[d, h] rel = rel_labels[d][h]
if h == 0: if h == 0:
rel = root_label rel = root_label
assert num_root == 0 assert num_root == 0
......
...@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't ...@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
use_sem: if in_targets("semrel") then true else false, use_sem: if in_targets("semrel") then true else false,
token_indexers: { token_indexers: {
token: if use_transformer then { token: if use_transformer then {
type: "pretrained_transformer_mismatched", type: "pretrained_transformer_mismatched_fixed",
model_name: pretrained_transformer_name, model_name: pretrained_transformer_name,
tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
then {use_fast: false} else {},
} else { } else {
# SingleIdTokenIndexer, token as single int # SingleIdTokenIndexer, token as single int
type: "single_id", type: "single_id",
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment