Commit 422d12c6 authored and committed by Mateusz Klimaszewski

Working graph decoding.

parent 6d373257
Part of 2 merge requests: !9 Enhanced dependency parsing (develop to master), !8 Enhanced dependency parsing
@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
+        # Ignore parsing misc
+        field_parsers.pop("misc", None)
         if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
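
For context, a minimal sketch of the effect of dropping these parsers, assuming `parser` above refers to conllu's parser module; the sample sentence below is made up, and the exact fallback behaviour (the MISC column staying a plain value rather than a parsed dict) is an assumption about the conllu library, not something taken from this repository:

import conllu
from conllu import parser

# Copy the library defaults instead of mutating the module-level dict.
field_parsers = dict(parser.DEFAULT_FIELD_PARSERS)
field_parsers.pop("xpostag", None)  # do not make xpostag nullable
field_parsers.pop("misc", None)     # ignore parsing misc

# Made-up single-token sentence; with no "misc" parser registered, the MISC
# column is presumably kept as-is instead of being parsed into a dict.
sample = "1\tkot\tkot\tNOUN\tsubst\t_\t0\troot\t_\tSpaceAfter=No\n"
token = conllu.parse(sample, field_parsers=field_parsers)[0][0]
print(token["misc"])
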
@@ -3,6 +3,7 @@ import os
 from typing import List, Union, Tuple
 
 import conllu
+import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             import combo.utils.graph as graph
-            tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"],
-                                         rel_scores=predictions["enhanced_deprel"],
+            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                         rel_scores=np.array(predictions["enhanced_deprel"]),
                                          tree=tree,
-                                         root_label="ROOT")
+                                         root_label="ROOT",
+                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
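
For readers less familiar with the AllenNLP side of this call: Vocabulary.get_index_to_token_vocabulary returns a plain Dict[int, str] for a namespace, and that dictionary is what sdp_to_dag_deps now receives as vocab_index; the call site also wraps the raw prediction lists in np.array so the graph utilities can slice them numerically. A minimal sketch with a made-up label set (the real "deprel_labels" namespace is built from the training data):

from allennlp.data import Vocabulary

vocab = Vocabulary()
# Invented relation labels; namespaces ending in "labels" get no padding/OOV
# tokens by default, so indices start at 0.
for label in ["root", "nsubj", "obj", "punct"]:
    vocab.add_token_to_namespace(label, namespace="deprel_labels")

vocab_index = vocab.get_index_to_token_vocabulary("deprel_labels")
print(vocab_index)  # e.g. {0: 'root', 1: 'nsubj', 2: 'obj', 3: 'punct'}
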
@@ -3,7 +3,7 @@ import numpy as np
 from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
     # adding ROOT
     tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [x[1] for x in g]
+        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
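
A tiny illustration of the updated comprehension with invented (head, relation-id) pairs: without a vocab_index the relation ids pass through unchanged, and with one they are mapped to label strings, falling back to "ROOT" for ids missing from the vocabulary:

g = [(2, 1), (4, 9)]                   # made-up (head, relation-id) pairs for one token
vocab_index = {0: "root", 1: "nsubj"}  # made-up deprel_labels mapping

print([x[1] for x in g])                           # [1, 9]            (raw ids)
print([vocab_index.get(x[1], "ROOT") for x in g])  # ['nsubj', 'ROOT'] (labels with fallback)
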
@@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
-        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)]
-    parse_preds = arc_scores > 0
+        arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
+        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+    parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
     # rel_labels[:, :, root_idx] = -float('inf')
     return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
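
To make the thresholding step concrete, a small sketch with toy scores (all names and values are illustrative). Two points it demonstrates: NumPy's two-axis slice scores[:n, :n] trims rows and columns while the chained scores[:n][:n] used above only trims rows, and the > 0 comparison needs a NumPy array (a plain list compared to an int raises a TypeError), hence np.array(arc_scores):

import numpy as np

n = 3  # stands in for len(tree_heads)
arc_scores = [[0.0, -1.0, 0.5, 0.9],
              [1.3, 0.1, -0.2, 0.0],
              [-0.4, 0.8, 0.6, 0.2],
              [0.0, 0.0, 0.0, 0.0]]

print(np.array(arc_scores)[:n, :n].shape)  # (3, 3): rows and columns trimmed
print(np.array(arc_scores[:n][:n]).shape)  # (3, 4): chained slices trim rows only

# Boolean candidate mask; column 0 is then cleared, as in parse_preds[:, 0] = False above.
parse_preds = np.array(arc_scores) > 0
parse_preds[:, 0] = False
print(parse_preds)
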
@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
-    sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+    sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
     for d, h in enumerate(tree_heads):
         if d:
@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d, h]
+            rel = rel_labels[d][h]
             if h == 0:
                 rel = root_label
                 assert num_root == 0
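
To summarise the decoding pattern visible in the hunks above, a short self-contained sketch with toy values: the tree arcs are recorded first (graph[h] collects the dependents of head h), then the surviving cells of parse_preds are enumerated with np.argwhere and sorted by descending score; the variable name dh suggests (dependent, head) index pairs. Cycle handling and the final deps assembly fall outside the shown hunks and are omitted here:

import numpy as np

tree_heads = [0, 2, 0]  # toy primary heads; index 0 is the artificial ROOT
arc_scores = np.array([
    [0.0, 0.0, 0.0],
    [1.5, 0.0, 0.7],
    [-0.3, 2.1, 0.0],
])

parse_preds = arc_scores > 0
parse_preds[:, 0] = False  # as in adjust_root_score_then_add_secondary_arcs

# Record the tree arcs.
graph = [[] for _ in range(len(tree_heads))]
for d, h in enumerate(tree_heads):
    if d:
        graph[h].append(d)

# Candidate extra arcs, highest score first.
dh = np.argwhere(parse_preds)
sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
print(graph)  # [[2], [], [1]]
print(sdh)    # highest-scoring candidate first: arc (2, 1) with score 2.1, then (1, 2) with 0.7
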
@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_targets("semrel") then true else false,
     token_indexers: {
         token: if use_transformer then {
-            type: "pretrained_transformer_mismatched",
+            type: "pretrained_transformer_mismatched_fixed",
             model_name: pretrained_transformer_name,
+            tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+            then {use_fast: false} else {},
         } else {
             # SingleIdTokenIndexer, token as single int
             type: "single_id",
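
As a rough illustration only, rendered in Python rather than Jsonnet and using an example model name, the indexer entry above evaluates to roughly the following when a HerBERT model is configured; for other transformer names the template falls back to an empty tokenizer_kwargs object:

pretrained_transformer_name = "allegro/herbert-base-cased"  # example value only

token_indexer = {
    "type": "pretrained_transformer_mismatched_fixed",
    "model_name": pretrained_transformer_name,
    # Mirrors std.startsWith(...) in the template: fast tokenization is
    # switched off for HerBERT models only.
    "tokenizer_kwargs": (
        {"use_fast": False}
        if pretrained_transformer_name.startswith("allegro/herbert")
        else {}
    ),
}
print(token_indexer)
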