Skip to content
Snippets Groups Projects
Commit 23e0c9ce authored and committed by Mateusz Klimaszewski
Browse files

Sort deps when uncollapsing nodes; mask out the root label when a token's head is not the root.

parent d0dc576f
Branches
Tags
2 merge requests: !9 "Enhanced dependency parsing develop to master", !8 "Enhanced dependency parsing"
"""Dependency parsing models."""
import math
from typing import Tuple, Dict, Optional, Union, List
import numpy as np
......@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
"""Dependency relation parsing model."""
def __init__(self,
root_idx: int,
head_predictor: HeadPredictionModel,
head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear,
relation_prediction_layer: base.Linear):
super().__init__()
self.root_idx = root_idx
self.head_predictor = head_predictor
self.head_projection_layer = head_projection_layer
self.dependency_projection_layer = dependency_projection_layer
......@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
device = x.device
if mask is not None:
mask = mask[:, 1:]
relations_labels, head_labels = None, None
......@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
relation_prediction = self.relation_prediction_layer(dep_rel_pred)
output = head_output
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
if self.training:
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
else:
# Force the root label where the predicted head is 0 and mask it out where the head is not 0.
relation_prediction_output = relation_prediction[:, 1:]
mask = (head_output["prediction"] == 0)
vocab_size = relation_prediction_output.size(-1)
root_idx = torch.tensor([self.root_idx], device=device)
relation_prediction_output[mask] = (relation_prediction_output
.masked_select(mask.unsqueeze(-1))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, 10e10))
relation_prediction_output[~mask] = (relation_prediction_output
.masked_select(~(mask.unsqueeze(-1)))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, -10e10))
output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
if labels is not None and labels[0] is not None:
if sample_weights is None:
......@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
head_predictor=head_predictor,
head_projection_layer=head_projection_layer,
dependency_projection_layer=dependency_projection_layer,
relation_prediction_layer=relation_prediction_layer
relation_prediction_layer=relation_prediction_layer,
root_idx=vocab.get_token_index("root", vocab_namespace)
)
......@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
"deps": f"{head}:{empty_node_relation}"
}
)
token["deps"] = "|".join(deps)
deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
return empty_tokens
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment