diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index fb770f6aee2ba4ad93763ae49bb030852b65aaf1..bb56ac33478fb36b5fc7736f79c6b6d68b4a2f59 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -54,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
+        # Ignore parsing misc
+        field_parsers.pop("misc", None)
         if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
diff --git a/combo/predict.py b/combo/predict.py
index 55f78c3beb4266018b222d3f635d43f806ad576a..6d657f6ae38d2716fccdedf489f2a30de4f78de7 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -3,6 +3,7 @@ import os
 from typing import List, Union, Tuple
 
 import conllu
+import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -195,10 +196,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             import combo.utils.graph as graph
 
-            tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"],
-                                         rel_scores=predictions["enhanced_deprel"],
+            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                         rel_scores=np.array(predictions["enhanced_deprel"]),
                                          tree=tree,
-                                         root_label="ROOT")
+                                         root_label="ROOT",
+                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 5970b19a5e458f03046c5d6068ffa7eff71d77f9..fd59e97b38ebaf247f74c801ac09609c7dee6582 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -3,7 +3,7 @@ import numpy as np
 from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
     # adding ROOT
     tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
@@ -12,7 +12,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [x[1] for x in g]
+        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
@@ -30,9 +30,9 @@
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
-        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)]
-    parse_preds = arc_scores > 0
+        arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
+        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+    parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
     # rel_labels[:, :, root_idx] = -float('inf')
     return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
@@ -42,7 +42,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
-    sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+    sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
     for d, h in enumerate(tree_heads):
         if d:
@@ -59,7 +59,7 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d, h]
+            rel = rel_labels[d][h]
             if h == 0:
                 rel = root_label
                 assert num_root == 0
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
index bdb6f0bde07fe8ccffe568b5543259e94fe5142a..d55cb89d9c590968146f485169bcb44ce0405db4 100644
--- a/config.graph.template.jsonnet
+++ b/config.graph.template.jsonnet
@@ -114,8 +114,10 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_targets("semrel") then true else false,
     token_indexers: {
         token: if use_transformer then {
-            type: "pretrained_transformer_mismatched",
+            type: "pretrained_transformer_mismatched_fixed",
             model_name: pretrained_transformer_name,
+            tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                              then {use_fast: false} else {},
         } else {
             # SingleIdTokenIndexer, token as single int
             type: "single_id",
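
Note on the vocab_index change: sdp_to_dag_deps now takes an optional index-to-label mapping (supplied from self.vocab.get_index_to_token_vocabulary("deprel_labels") in predict.py) so that integer relation ids predicted for enhanced dependencies are decoded to deprel strings. Below is a minimal sketch of that decoding step, using a hypothetical toy mapping in place of the real AllenNLP vocabulary:

    # Hypothetical toy mapping; in predict.py it comes from
    # self.vocab.get_index_to_token_vocabulary("deprel_labels").
    vocab_index = {0: "root", 1: "nsubj", 2: "obj"}

    # Candidate arcs for one token as built by add_secondary_arcs:
    # (head, relation_id) pairs.
    g = [(0, 0), (2, 2)]

    # Mirrors the changed comprehension in sdp_to_dag_deps: decode ids via
    # vocab_index when one is given, falling back to "ROOT" for ids missing
    # from the mapping; without a vocabulary the raw ids are kept.
    rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
    print(rels)  # ['root', 'obj']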