"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import Dict, List, Optional, Union

import torch

from combo import data
from combo.data import dataset
from combo.models import base, utils
from combo.models.combo_nn import Activation
from combo.utils import ConfigurationError


class MorphologicalFeatures(base.Predictor):
    """Morphological features predicting model.

    A shared feed-forward network produces one logit vector per token; the
    vector is partitioned by ``slices`` into per-category groups, and each
    category is predicted independently via an argmax over its group.
    """

    def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
        """
        Args:
            feedforward_network: network mapping encoder output to per-token
                logits over all morphological feature values.
            slices: mapping category name -> indices of that category's logits
                within the feed-forward output dimension.
        """
        super().__init__()
        self.feedforward_network = feedforward_network
        self.slices = slices

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Predict morphological features and (optionally) compute the loss.

        Returns a dict with:
            ``prediction`` -- per-category argmax indices, stacked on the last dim,
            ``probability`` -- raw feed-forward output (full logit vector),
            ``embedding`` -- last intermediate feature map of the network,
            ``loss`` -- only when ``labels`` is given.
        """
        # NOTE(review): x is presumably (batch, seq, encoder_dim) — the mask
        # default and the [:, :, idx] slicing below assume exactly rank 3.
        if mask is None:
            mask = x.new_ones(x.size()[:-1])

        x, feature_maps = self.feedforward_network(x)

        # Independent argmax per morphological category over that category's
        # slice of the logit vector.
        prediction = [x[:, :, cat_indices].argmax(dim=-1)
                      for cat_indices in self.slices.values()]

        output = {
            "prediction": torch.stack(prediction, dim=-1),
            "probability": x,
            "embedding": feature_maps[-1],
        }

        if labels is not None:
            if sample_weights is None:
                # Default: every sentence in the batch weighted equally.
                sample_weights = labels.new_ones([mask.size(0)])
            output["loss"] = self._loss(x, labels, mask, sample_weights)

        return output

    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        """Masked, sample-weighted cross-entropy summed over categories.

        ``true`` is expected one-hot per category (same size as ``pred``);
        the target class index is recovered with argmax over each slice.
        """
        assert pred.size() == true.size()
        batch_size, _, n_features = pred.size()

        # Number of real (unmasked) tokens; used to normalize the final loss.
        valid_positions = mask.sum()

        # Flatten (batch, seq, features) -> (batch * seq, features) so each
        # token is one row for the per-category cross-entropy.
        pred = pred.reshape(-1, n_features)
        true = true.reshape(-1, n_features)
        mask = mask.reshape(-1)

        loss_func = utils.masked_cross_entropy
        # Padding pseudo-categories carry no supervision signal.
        per_category_losses = [
            loss_func(pred[:, cat_indices],
                      true[:, cat_indices].argmax(dim=1),
                      mask)
            for cat, cat_indices in self.slices.items()
            if cat not in ["__PAD__", "_"]
        ]
        if not per_category_losses:
            # Previously this fell through to `None.reshape(...)` and raised
            # an opaque AttributeError; fail with a clear message instead.
            raise ConfigurationError("No morphological feature categories to compute loss for.")
        loss = sum(per_category_losses)

        # Re-shape back to (batch, seq) to apply per-sentence sample weights.
        loss = loss.reshape(batch_size, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   input_dim: int,
                   num_layers: int,
                   hidden_dims: List[int],
                   activations: Union[Activation, List[Activation]],
                   dropout: Union[float, List[float]] = 0.0,
                   ):
        """Alternate constructor: size the output layer from the vocabulary.

        ``hidden_dims`` covers only the hidden layers; the final layer's width
        (the full morphological feature vocabulary) is appended here, hence
        the ``len(hidden_dims) + 1 == num_layers`` requirement.

        Raises:
            ConfigurationError: if ``num_layers`` is inconsistent with
                ``hidden_dims``.
        """
        if len(hidden_dims) + 1 != num_layers:
            raise ConfigurationError(
                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
            )

        assert vocab_namespace in vocab.get_namespaces()
        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]

        slices = dataset.get_slices_if_not_provided(vocab)

        return cls(
            feedforward_network=base.FeedForward(
                input_dim=input_dim,
                num_layers=num_layers,
                hidden_dims=hidden_dims,
                activations=activations,
                dropout=dropout),
            slices=slices
        )