Skip to content
Snippets Groups Projects
Commit fe682946 authored by Maja Jabłońska's avatar Maja Jabłońska
Browse files

Add MorphologicalFeatures model

parent bc8bdc77
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
from combo.models.base import Predictor
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import Dict, List, Optional, Union
import torch
from combo import data
from combo.data import dataset
from combo.models import base, utils
from combo.models.combo_nn import Activation
from combo.utils import ConfigurationError
class MorphologicalFeatures(Predictor):
pass
class MorphologicalFeatures(base.Predictor):
"""Morphological features predicting model."""
def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
super().__init__()
self.feedforward_network = feedforward_network
self.slices = slices
def forward(self,
x: Union[torch.Tensor, List[torch.Tensor]],
mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
if mask is None:
mask = x.new_ones(x.size()[:-1])
x, feature_maps = self.feedforward_network(x)
prediction = []
for _, cat_indices in self.slices.items():
prediction.append(x[:, :, cat_indices].argmax(dim=-1))
output = {
"prediction": torch.stack(prediction, dim=-1),
"probability": x,
"embedding": feature_maps[-1],
}
if labels is not None:
if sample_weights is None:
sample_weights = labels.new_ones([mask.size(0)])
output["loss"] = self._loss(x, labels, mask, sample_weights)
return output
def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> torch.Tensor:
assert pred.size() == true.size()
BATCH_SIZE, _, MORPHOLOGICAL_FEATURES = pred.size()
valid_positions = mask.sum()
pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
mask = mask.reshape(-1)
loss = None
loss_func = utils.masked_cross_entropy
for cat, cat_indices in self.slices.items():
if cat not in ["__PAD__", "_"]:
if loss is None:
loss = loss_func(pred[:, cat_indices],
true[:, cat_indices].argmax(dim=1),
mask)
else:
loss += loss_func(pred[:, cat_indices],
true[:, cat_indices].argmax(dim=1),
mask)
loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
return loss.sum() / valid_positions
@classmethod
def from_vocab(cls,
vocab: data.Vocabulary,
vocab_namespace: str,
input_dim: int,
num_layers: int,
hidden_dims: List[int],
activations: Union[Activation, List[Activation]],
dropout: Union[float, List[float]] = 0.0,
):
if len(hidden_dims) + 1 != num_layers:
raise ConfigurationError(
f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
)
assert vocab_namespace in vocab.get_namespaces()
hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
slices = dataset.get_slices_if_not_provided(vocab)
return cls(
feedforward_network=base.FeedForward(
input_dim=input_dim,
num_layers=num_layers,
hidden_dims=hidden_dims,
activations=activations,
dropout=dropout),
slices=slices
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment