From 208e57a48dfc1dbc06c84c28b69d2b1abc21752f Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 30 Apr 2021 09:50:40 +0200
Subject: [PATCH] Split vocabulary for graph and tree parsing.

---
 combo/config.graph.template.jsonnet | 2 +-
 combo/data/dataset.py               | 2 +-
 combo/predict.py                    | 2 ++
 combo/utils/graph.py                | 6 ++++--
 scripts/train_iwpt21.py             | 2 +-
 5 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index c0c4696..c72a057 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -303,7 +303,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         enhanced_dependency_relation: if in_targets("deps") then {
             type: "combo_graph_dependency_parsing_from_vocab",
-            vocab_namespace: 'deprel_labels',
+            vocab_namespace: 'enhanced_deprel_labels',
             head_predictor: {
                 local projection_dim = 512,
                 cycle_loss_n: cycle_loss_n,
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 29b1f91..bdc8b20 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                             sequence_field=text_field_deps,
                             labels=enhanced_deprels,
                             # Label namespace matches regular tree parsing.
-                            label_namespace="deprel_labels",
+                            label_namespace="enhanced_deprel_labels",
                             padding_value=0,
                         )
                     else:
diff --git a/combo/predict.py b/combo/predict.py
index c299faa..d5f06a2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -206,6 +206,8 @@ class COMBO(predictor.Predictor):
                 graph_rel_scores=r,
                 idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
                 label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
+                graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"),
+                graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"),
                 tokens=tree_tokens
             )
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index c9ad07e..86dd98f 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -12,6 +12,8 @@ def graph_and_tree_merge(tree_arc_scores,
                          graph_rel_scores,
                          label2idx,
                          idx2label,
+                         graph_label2idx,
+                         graph_idx2label,
                          tokens):
     graph_arc_scores = np.copy(graph_arc_scores)
     # Exclude self-loops, in-place operation.
@@ -19,7 +21,7 @@ def graph_and_tree_merge(tree_arc_scores,
     # Connection to root will be handled by tree.
     graph_arc_scores[:, 0] = False
     # The same with labels.
-    root_idx = label2idx["root"]
+    root_idx = graph_label2idx["root"]
     graph_rel_scores[:, :, root_idx] = -float('inf')
     graph_rel_pred = graph_rel_scores.argmax(-1)
 
@@ -47,7 +49,7 @@ def graph_and_tree_merge(tree_arc_scores,
             path = next(_dfs(graph, d, h))
         except StopIteration:
             # There is not path from d to h
-            label = idx2label[graph_rel_pred[d][h]]
+            label = graph_idx2label[graph_rel_pred[d][h]]
             if label != _ACL_REL_CL:
                 graph[h].append(d)
                 labeled_graph[h].append((d, label))
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index c6310ea..8b077c3 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -115,7 +115,7 @@ def run(_):
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
         --word_batch_size 2500
-        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
         --notensorboard
         """
 
-- 
GitLab