Skip to content
Snippets Groups Projects
Commit 208e57a4 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Split vocabulary for graph and tree parsing.

parent c38ef44a
No related branches found
No related tags found
No related merge requests found
......@@ -303,7 +303,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
},
enhanced_dependency_relation: if in_targets("deps") then {
type: "combo_graph_dependency_parsing_from_vocab",
vocab_namespace: 'deprel_labels',
vocab_namespace: 'enhanced_deprel_labels',
head_predictor: {
local projection_dim = 512,
cycle_loss_n: cycle_loss_n,
......
......@@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
sequence_field=text_field_deps,
labels=enhanced_deprels,
# Label namespace is kept separate from the one used for regular tree parsing.
label_namespace="deprel_labels",
label_namespace="enhanced_deprel_labels",
padding_value=0,
)
else:
......
......@@ -206,6 +206,8 @@ class COMBO(predictor.Predictor):
graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"),
graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"),
tokens=tree_tokens
)
......
......@@ -12,6 +12,8 @@ def graph_and_tree_merge(tree_arc_scores,
graph_rel_scores,
label2idx,
idx2label,
graph_label2idx,
graph_idx2label,
tokens):
graph_arc_scores = np.copy(graph_arc_scores)
# Exclude self-loops, in-place operation.
......@@ -19,7 +21,7 @@ def graph_and_tree_merge(tree_arc_scores,
# Connection to root will be handled by tree.
graph_arc_scores[:, 0] = False
# The same with labels.
root_idx = label2idx["root"]
root_idx = graph_label2idx["root"]
graph_rel_scores[:, :, root_idx] = -float('inf')
graph_rel_pred = graph_rel_scores.argmax(-1)
......@@ -47,7 +49,7 @@ def graph_and_tree_merge(tree_arc_scores,
path = next(_dfs(graph, d, h))
except StopIteration:
# There is no path from d to h
label = idx2label[graph_rel_pred[d][h]]
label = graph_idx2label[graph_rel_pred[d][h]]
if label != _ACL_REL_CL:
graph[h].append(d)
labeled_graph[h].append((d, label))
......
......@@ -115,7 +115,7 @@ def run(_):
--serialization_dir {serialization_dir}
--cuda_device {FLAGS.cuda_device}
--word_batch_size 2500
--config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
--config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
--notensorboard
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment