Skip to content
Snippets Groups Projects
Commit d0f87449 authored by Maja Jabłońska's avatar Maja Jabłońska Committed by Martyna Wiącek
Browse files

Add MorphologicalFeatures model

parent b6614125
No related branches found
No related tags found
1 merge request: !46 "Merge COMBO 3.0 into master"
from combo.models.base import Predictor
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import Dict, List, Optional, Union
import torch
from combo import data
from combo.data import dataset
from combo.models import base, utils
from combo.models.combo_nn import Activation
from combo.utils import ConfigurationError
# Bare stub with no members of its own. A class with the same name is defined
# again right after this one, and that later definition rebinds the name.
class MorphologicalFeatures(Predictor):
    """Placeholder predictor for morphological features."""
class MorphologicalFeatures(base.Predictor):
    """Morphological features predicting model.

    The last dimension of the feed-forward output is partitioned into named
    groups of indices (``slices``). A prediction is made independently inside
    every group, and the training loss is accumulated over all groups except
    the ones named ``"__PAD__"`` and ``"_"``.
    """

    def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
        """Store the scoring network and the category-to-indices mapping.

        Args:
            feedforward_network: Network called in ``forward``; it is expected
                to return a ``(output, feature_maps)`` pair.
            slices: Maps a category name to the indices (along the last
                dimension of the network output) that belong to it.
        """
        super().__init__()
        self.feedforward_network = feedforward_network
        self.slices = slices

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Predict one value per category and optionally compute the loss.

        Args:
            x: Input passed to the feed-forward network. The network output is
                indexed as a 3-D tensor here.
                NOTE(review): presumably (batch, sequence, features) - confirm.
            mask: Positions that count towards the loss. Defaults to all ones
                over every dimension of ``x`` except the last.
            labels: Targets with the same size as the network output. When
                given, a ``"loss"`` entry is added to the result.
            sample_weights: One weight per batch element. Defaults to all ones.

        Returns:
            Dict with ``"prediction"`` (per-category argmax, stacked along a
            new last dimension), ``"probability"`` (the network output as
            returned; NOTE(review): whether it is normalised depends on the
            network - confirm), ``"embedding"`` (the last feature map) and,
            when ``labels`` is given, ``"loss"``.
        """
        if mask is None:
            # Built from the *input* x, before the network replaces it.
            mask = x.new_ones(x.size()[:-1])

        x, feature_maps = self.feedforward_network(x)

        # The slice names are irrelevant for prediction; only the index groups matter.
        prediction = [
            x[:, :, cat_indices].argmax(dim=-1)
            for cat_indices in self.slices.values()
        ]

        output = {
            "prediction": torch.stack(prediction, dim=-1),
            "probability": x,
            "embedding": feature_maps[-1],
        }

        if labels is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            output["loss"] = self._loss(x, labels, mask, sample_weights)

        return output

    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        """Sum the per-category masked cross-entropy and average it over the mask.

        Args:
            pred: Network output; must be 3-D and the same size as ``true``.
            true: Targets; within every category the argmax is used as the
                class index.
            mask: Positions to include; its sum is the normalising constant.
            sample_weights: One weight per batch element, broadcast over the
                remaining positions.

        Returns:
            The weighted loss summed over all positions, divided by
            ``mask.sum()``.

        Raises:
            ConfigurationError: If ``slices`` contains no category other than
                ``"__PAD__"`` and ``"_"``, so there is nothing to compute a
                loss for.
        """
        # Internal invariant between the network output and the label encoding.
        assert pred.size() == true.size(), "pred and true must have the same size"
        batch_size, _, num_features = pred.size()

        valid_positions = mask.sum()

        # Flatten to one row per position so every category can be sliced by column.
        pred = pred.reshape(-1, num_features)
        true = true.reshape(-1, num_features)
        mask = mask.reshape(-1)

        loss = None
        for cat, cat_indices in self.slices.items():
            if cat in ("__PAD__", "_"):
                continue
            cat_loss = utils.masked_cross_entropy(
                pred[:, cat_indices],
                true[:, cat_indices].argmax(dim=1),
                mask,
            )
            # Out-of-place addition: never modifies the first category's tensor.
            loss = cat_loss if loss is None else loss + cat_loss

        if loss is None:
            # Previously this surfaced as an AttributeError on `None.reshape`.
            raise ConfigurationError(
                "Cannot compute morphological features loss: "
                "slices contain no category other than '__PAD__' and '_'."
            )

        loss = loss.reshape(batch_size, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   input_dim: int,
                   num_layers: int,
                   hidden_dims: List[int],
                   activations: Union[Activation, List[Activation]],
                   dropout: Union[float, List[float]] = 0.0,
                   ) -> "MorphologicalFeatures":
        """Build the model with an output layer sized from the vocabulary.

        Args:
            vocab: Vocabulary providing the namespace size and the slices.
            vocab_namespace: Namespace whose size becomes the width of the
                final layer.
            input_dim: Input dimension of the feed-forward network.
            num_layers: Total number of layers, including the output layer.
            hidden_dims: Widths of the hidden layers; must contain exactly
                ``num_layers - 1`` entries. The list is not modified.
            activations: Forwarded to ``base.FeedForward``.
            dropout: Forwarded to ``base.FeedForward``.

        Returns:
            A new instance of ``cls``.

        Raises:
            ConfigurationError: If ``hidden_dims`` and ``num_layers`` disagree,
                or if ``vocab_namespace`` is not present in ``vocab``.
        """
        if len(hidden_dims) + 1 != num_layers:
            raise ConfigurationError(
                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
            )

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if vocab_namespace not in vocab.get_namespaces():
            raise ConfigurationError(
                f"Namespace '{vocab_namespace}' not found in the vocabulary."
            )

        # Concatenation creates a new list, leaving the caller's `hidden_dims` intact.
        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]

        slices = dataset.get_slices_if_not_provided(vocab)

        return cls(
            feedforward_network=base.FeedForward(
                input_dim=input_dim,
                num_layers=num_layers,
                hidden_dims=hidden_dims,
                activations=activations,
                dropout=dropout),
            slices=slices
        )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment