From 422d12c61e50f31d7ad8f075d1b21abe6ef94d04 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 19 Nov 2020 10:28:44 +0100 Subject: [PATCH] Working graph decoding. --- combo/data/dataset.py | 2 ++ combo/predict.py | 8 +++++--- combo/utils/graph.py | 14 +++++++------- config.graph.template.jsonnet | 4 +++- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/combo/data/dataset.py b/combo/data/dataset.py index fb770f6..bb56ac3 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): field_parsers = parser.DEFAULT_FIELD_PARSERS # Do not make it nullable field_parsers.pop("xpostag", None) + # Ignore parsing misc + field_parsers.pop("misc", None) if self.use_sem: fields = list(fields) fields.append("semrel") diff --git a/combo/predict.py b/combo/predict.py index 55f78c3..6d657f6 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -3,6 +3,7 @@ import os from typing import List, Union, Tuple import conllu +import numpy as np from allennlp import data as allen_data, common, models from allennlp.common import util from allennlp.data import tokenizers @@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor): if "enhanced_head" in predictions and predictions["enhanced_head"]: import combo.utils.graph as graph - tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"], - rel_scores=predictions["enhanced_deprel"], + tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]), + rel_scores=np.array(predictions["enhanced_deprel"]), tree=tree, - root_label="ROOT") + root_label="ROOT", + vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) return tree, predictions["sentence_embedding"] diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 5970b19..fd59e97 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -3,7 +3,7 @@ import numpy as np from conllu import TokenList -def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): +def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None): # adding ROOT tree_tokens = tree.tokens tree_heads = [0] + [t["head"] for t in tree_tokens] @@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): for i, (t, g) in enumerate(zip(tree_heads, graph)): if not i: continue - rels = [x[1] for x in g] + rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g] heads = [x[0] for x in g] head = tree_tokens[i - 1]["head"] index = heads.index(head) @@ -30,9 +30,9 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label): def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx): if len(arc_scores) != tree_heads: - arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)] - rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)] - parse_preds = arc_scores > 0 + arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)] + rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)] + parse_preds = np.array(arc_scores) > 0 parse_preds[:, 0] = False # set heads to False # rel_labels[:, :, root_idx] = -float('inf') return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds) @@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre if not isinstance(tree_heads, np.ndarray): tree_heads = np.array(tree_heads) dh = np.argwhere(parse_preds) - sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True) + sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True) graph = [[] for _ in range(len(tree_heads))] for d, h in enumerate(tree_heads): if d: @@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre num_root = 0 for h in range(len(tree_heads)): for d in graph[h]: - rel = rel_labels[d, h] + rel = rel_labels[d][h] if h == 0: rel = root_label assert num_root == 0 diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet index bdb6f0b..d55cb89 100644 --- a/config.graph.template.jsonnet +++ b/config.graph.template.jsonnet @@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use_sem: if in_targets("semrel") then true else false, token_indexers: { token: if use_transformer then { - type: "pretrained_transformer_mismatched", + type: "pretrained_transformer_mismatched_fixed", model_name: pretrained_transformer_name, + tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") + then {use_fast: false} else {}, } else { # SingleIdTokenIndexer, token as single int type: "single_id", -- GitLab