diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index c0c469674f4a74a6d46953e65d9715e9eabf0a2f..c72a0573625010608bbf4936b794535e8613dc2e 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -303,7 +303,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, enhanced_dependency_relation: if in_targets("deps") then { type: "combo_graph_dependency_parsing_from_vocab", - vocab_namespace: 'deprel_labels', + vocab_namespace: 'enhanced_deprel_labels', head_predictor: { local projection_dim = 512, cycle_loss_n: cycle_loss_n, diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 29b1f918c5de55f8055ebb432d3fc90e97ba2c42..bdc8b20ea42ef9cd25d757cde2829d5e3327efab 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): sequence_field=text_field_deps, labels=enhanced_deprels, # Label namespace matches regular tree parsing. - label_namespace="deprel_labels", + label_namespace="enhanced_deprel_labels", padding_value=0, ) else: diff --git a/combo/predict.py b/combo/predict.py index 5f1aaed5ab2625bb35be78591b0d5540b2e431c5..18a63a4ddad1feb19e041891ae9d8ee94c85a9f1 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -209,6 +209,8 @@ class COMBO(predictor.Predictor): graph_rel_scores=r, idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"), + graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"), + graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"), tokens=tree_tokens ) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index c9ad07e731d2baf0c1b2241db369f43b7493bf91..86dd98f659a5c9aebf9da16d0b58a6f7e809a029 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -12,6 +12,8 @@ def graph_and_tree_merge(tree_arc_scores, graph_rel_scores, label2idx, idx2label, + graph_label2idx, + graph_idx2label, tokens): graph_arc_scores = np.copy(graph_arc_scores) # Exclude self-loops, in-place operation. @@ -19,7 +21,7 @@ def graph_and_tree_merge(tree_arc_scores, # Connection to root will be handled by tree. graph_arc_scores[:, 0] = False # The same with labels. - root_idx = label2idx["root"] + root_idx = graph_label2idx["root"] graph_rel_scores[:, :, root_idx] = -float('inf') graph_rel_pred = graph_rel_scores.argmax(-1) @@ -47,7 +49,7 @@ def graph_and_tree_merge(tree_arc_scores, path = next(_dfs(graph, d, h)) except StopIteration: # There is not path from d to h - label = idx2label[graph_rel_pred[d][h]] + label = graph_idx2label[graph_rel_pred[d][h]] if label != _ACL_REL_CL: graph[h].append(d) labeled_graph[h].append((d, label)) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index c6310eae7621c9745248e368bd0782cd053613b8..8b077c3cdded0115505ad95c703357ef6f8057a1 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -115,7 +115,7 @@ def run(_): --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} --word_batch_size 2500 - --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'} + --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """