diff --git a/combo/models/parser.py b/combo/models/parser.py
index 486b2481b96bf17bb19fd8557916f21dcb6c4584..4b5b12606874c91e9303060d1c8bb68b3d3ac016 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -1,4 +1,5 @@
 """Dependency parsing models."""
+import math
 from typing import Tuple, Dict, Optional, Union, List
 
 import numpy as np
@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 root_idx: int,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
                  dependency_projection_layer: base.Linear,
                  relation_prediction_layer: base.Linear):
         super().__init__()
+        self.root_idx = root_idx
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        device = x.device
         if mask is not None:
             mask = mask[:, 1:]
         relations_labels, head_labels = None, None
@@ -151,7 +155,26 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
 
-        output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        if self.training:
+            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        else:
+            # Tie the root label to the predicted head: force it where the head
+            # is 0 and rule it out everywhere else. Work on a copy under a new
+            # mask name so the raw scores and the padding `mask` stay intact
+            # for any code that reads them after this branch.
+            relation_prediction_output = relation_prediction[:, 1:].clone()
+            root_head_mask = (head_output["prediction"] == 0)
+            vocab_size = relation_prediction_output.size(-1)
+            root_idx = torch.tensor([self.root_idx], device=device)
+            relation_prediction_output[root_head_mask] = (relation_prediction_output
+                                                          .masked_select(root_head_mask.unsqueeze(-1))
+                                                          .reshape(-1, vocab_size)
+                                                          .index_fill(-1, root_idx, 10e10))
+            relation_prediction_output[~root_head_mask] = (relation_prediction_output
+                                                           .masked_select(~(root_head_mask.unsqueeze(-1)))
+                                                           .reshape(-1, vocab_size)
+                                                           .index_fill(-1, root_idx, -10e10))
+            output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
@@ -195,5 +218,6 @@ class DependencyRelationModel(base.Predictor):
             head_predictor=head_predictor,
             head_projection_layer=head_projection_layer,
             dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
+            relation_prediction_layer=relation_prediction_layer,
+            root_idx=vocab.get_token_index("root", vocab_namespace)
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 32e7dd944999c9c2f7c8e800289eb2f35d14a4bc..651c14a7d79b7ea3c277b9466f5e050435a7a01b 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
                         "deps": f"{head}:{empty_node_relation}"
                     }
                 )
-        token["deps"] = "|".join(deps)
+        # Order the entries by the numeric value of the head id before each ":".
+        token["deps"] = "|".join(sorted(deps, key=lambda d: float(d.split(":", 1)[0])))
     return empty_tokens