From 208e57a48dfc1dbc06c84c28b69d2b1abc21752f Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 30 Apr 2021 09:50:40 +0200 Subject: [PATCH] Split vocabulary for graph and tree parsing. --- combo/config.graph.template.jsonnet | 2 +- combo/data/dataset.py | 2 +- combo/predict.py | 2 ++ combo/utils/graph.py | 6 ++++-- scripts/train_iwpt21.py | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index c0c4696..c72a057 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -303,7 +303,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, enhanced_dependency_relation: if in_targets("deps") then { type: "combo_graph_dependency_parsing_from_vocab", - vocab_namespace: 'deprel_labels', + vocab_namespace: 'enhanced_deprel_labels', head_predictor: { local projection_dim = 512, cycle_loss_n: cycle_loss_n, diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 29b1f91..bdc8b20 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): sequence_field=text_field_deps, labels=enhanced_deprels, # Label namespace matches regular tree parsing. - label_namespace="deprel_labels", + label_namespace="enhanced_deprel_labels", padding_value=0, ) else: diff --git a/combo/predict.py b/combo/predict.py index c299faa..d5f06a2 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -206,6 +206,8 @@ class COMBO(predictor.Predictor): graph_rel_scores=r, idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"), + graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"), + graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"), tokens=tree_tokens ) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index c9ad07e..86dd98f 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -12,6 +12,8 @@ def graph_and_tree_merge(tree_arc_scores, graph_rel_scores, label2idx, idx2label, + graph_label2idx, + graph_idx2label, tokens): graph_arc_scores = np.copy(graph_arc_scores) # Exclude self-loops, in-place operation. @@ -19,7 +21,7 @@ def graph_and_tree_merge(tree_arc_scores, # Connection to root will be handled by tree. graph_arc_scores[:, 0] = False # The same with labels. - root_idx = label2idx["root"] + root_idx = graph_label2idx["root"] graph_rel_scores[:, :, root_idx] = -float('inf') graph_rel_pred = graph_rel_scores.argmax(-1) @@ -47,7 +49,7 @@ def graph_and_tree_merge(tree_arc_scores, path = next(_dfs(graph, d, h)) except StopIteration: # There is not path from d to h - label = idx2label[graph_rel_pred[d][h]] + label = graph_idx2label[graph_rel_pred[d][h]] if label != _ACL_REL_CL: graph[h].append(d) labeled_graph[h].append((d, label)) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index c6310ea..8b077c3 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -115,7 +115,7 @@ def run(_): --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} --word_batch_size 2500 - --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'} + --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """ -- GitLab