From 23e0c9ce45e63dbcf031ba8965f85590f331a92d Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 23 Dec 2020 09:14:09 +0100 Subject: [PATCH] Sort deps when uncollapsing nodes, mask root label possibility when root isn't head of a token. --- combo/models/parser.py | 25 +++++++++++++++++++++++-- combo/utils/graph.py | 3 ++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/combo/models/parser.py b/combo/models/parser.py index 486b248..4b5b126 100644 --- a/combo/models/parser.py +++ b/combo/models/parser.py @@ -1,4 +1,5 @@ """Dependency parsing models.""" +import math from typing import Tuple, Dict, Optional, Union, List import numpy as np @@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor): """Dependency relation parsing model.""" def __init__(self, + root_idx: int, head_predictor: HeadPredictionModel, head_projection_layer: base.Linear, dependency_projection_layer: base.Linear, relation_prediction_layer: base.Linear): super().__init__() + self.root_idx = root_idx self.head_predictor = head_predictor self.head_projection_layer = head_projection_layer self.dependency_projection_layer = dependency_projection_layer @@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor): mask: Optional[torch.BoolTensor] = None, labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: + device = x.device if mask is not None: mask = mask[:, 1:] relations_labels, head_labels = None, None @@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor): relation_prediction = self.relation_prediction_layer(dep_rel_pred) output = head_output - output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) + if self.training: + output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) + else: + # Mask root label whenever head is not 0. + relation_prediction_output = relation_prediction[:, 1:] + mask = (head_output["prediction"] == 0) + vocab_size = relation_prediction_output.size(-1) + root_idx = torch.tensor([self.root_idx], device=device) + relation_prediction_output[mask] = (relation_prediction_output + .masked_select(mask.unsqueeze(-1)) + .reshape(-1, vocab_size) + .index_fill(-1, root_idx, 10e10)) + relation_prediction_output[~mask] = (relation_prediction_output + .masked_select(~(mask.unsqueeze(-1))) + .reshape(-1, vocab_size) + .index_fill(-1, root_idx, -10e10)) + output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"]) if labels is not None and labels[0] is not None: if sample_weights is None: @@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor): head_predictor=head_predictor, head_projection_layer=head_projection_layer, dependency_projection_layer=dependency_projection_layer, - relation_prediction_layer=relation_prediction_layer + relation_prediction_layer=relation_prediction_layer, + root_idx=vocab.get_token_index("root", vocab_namespace) ) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 32e7dd9..651c14a 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens): "deps": f"{head}:{empty_node_relation}" } ) - token["deps"] = "|".join(deps) + deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0])) + token["deps"] = "|".join([f"{k}:{v}" for k, v in deps]) return empty_tokens -- GitLab