Skip to content
Snippets Groups Projects

Enhanced dependency parsing

Merged Mateusz Klimaszewski requested to merge enhanced_dependency_parsing into develop
Viewing commit 23e0c9ce
Show latest version
2 files
+ 25
3
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 23
2
"""Dependency parsing models."""
import math
from typing import Tuple, Dict, Optional, Union, List
import numpy as np
@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
"""Dependency relation parsing model."""
def __init__(self,
root_idx: int,
head_predictor: HeadPredictionModel,
head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear,
relation_prediction_layer: base.Linear):
super().__init__()
self.root_idx = root_idx
self.head_predictor = head_predictor
self.head_projection_layer = head_projection_layer
self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
device = x.device
if mask is not None:
mask = mask[:, 1:]
relations_labels, head_labels = None, None
@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
relation_prediction = self.relation_prediction_layer(dep_rel_pred)
output = head_output
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
if self.training:
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
else:
# Force the root label where the predicted head is 0, and suppress it everywhere else.
relation_prediction_output = relation_prediction[:, 1:]
mask = (head_output["prediction"] == 0)
vocab_size = relation_prediction_output.size(-1)
root_idx = torch.tensor([self.root_idx], device=device)
relation_prediction_output[mask] = (relation_prediction_output
.masked_select(mask.unsqueeze(-1))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, 10e10))
relation_prediction_output[~mask] = (relation_prediction_output
.masked_select(~(mask.unsqueeze(-1)))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, -10e10))
output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
if labels is not None and labels[0] is not None:
if sample_weights is None:
@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
head_predictor=head_predictor,
head_projection_layer=head_projection_layer,
dependency_projection_layer=dependency_projection_layer,
relation_prediction_layer=relation_prediction_layer
relation_prediction_layer=relation_prediction_layer,
root_idx=vocab.get_token_index("root", vocab_namespace)
)