Skip to content
Snippets Groups Projects
Commit 23e0c9ce authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Sort deps when uncollapsing nodes, mask root label possibility when root isn't head of a token.

parent d0dc576f
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
"""Dependency parsing models.""" """Dependency parsing models."""
import math
from typing import Tuple, Dict, Optional, Union, List from typing import Tuple, Dict, Optional, Union, List
import numpy as np import numpy as np
...@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor): ...@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
"""Dependency relation parsing model.""" """Dependency relation parsing model."""
def __init__(self, def __init__(self,
root_idx: int,
head_predictor: HeadPredictionModel, head_predictor: HeadPredictionModel,
head_projection_layer: base.Linear, head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear, dependency_projection_layer: base.Linear,
relation_prediction_layer: base.Linear): relation_prediction_layer: base.Linear):
super().__init__() super().__init__()
self.root_idx = root_idx
self.head_predictor = head_predictor self.head_predictor = head_predictor
self.head_projection_layer = head_projection_layer self.head_projection_layer = head_projection_layer
self.dependency_projection_layer = dependency_projection_layer self.dependency_projection_layer = dependency_projection_layer
...@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor): ...@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
mask: Optional[torch.BoolTensor] = None, mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
device = x.device
if mask is not None: if mask is not None:
mask = mask[:, 1:] mask = mask[:, 1:]
relations_labels, head_labels = None, None relations_labels, head_labels = None, None
...@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor): ...@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
relation_prediction = self.relation_prediction_layer(dep_rel_pred) relation_prediction = self.relation_prediction_layer(dep_rel_pred)
output = head_output output = head_output
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) if self.training:
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
else:
# Mask root label whenever head is not 0.
relation_prediction_output = relation_prediction[:, 1:]
mask = (head_output["prediction"] == 0)
vocab_size = relation_prediction_output.size(-1)
root_idx = torch.tensor([self.root_idx], device=device)
relation_prediction_output[mask] = (relation_prediction_output
.masked_select(mask.unsqueeze(-1))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, 10e10))
relation_prediction_output[~mask] = (relation_prediction_output
.masked_select(~(mask.unsqueeze(-1)))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, -10e10))
output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
if labels is not None and labels[0] is not None: if labels is not None and labels[0] is not None:
if sample_weights is None: if sample_weights is None:
...@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor): ...@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
head_predictor=head_predictor, head_predictor=head_predictor,
head_projection_layer=head_projection_layer, head_projection_layer=head_projection_layer,
dependency_projection_layer=dependency_projection_layer, dependency_projection_layer=dependency_projection_layer,
relation_prediction_layer=relation_prediction_layer relation_prediction_layer=relation_prediction_layer,
root_idx=vocab.get_token_index("root", vocab_namespace)
) )
...@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens): ...@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
"deps": f"{head}:{empty_node_relation}" "deps": f"{head}:{empty_node_relation}"
} }
) )
token["deps"] = "|".join(deps) deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
return empty_tokens return empty_tokens
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment