From 74ff65e0e2a3da580fc028281c9ec2d8bff32425 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sun, 19 Nov 2023 19:22:53 +1100
Subject: [PATCH] Minor fixes

---
 combo/models/base.py                          | 274 ---------
 combo/models/embeddings.py                    | 221 --------
 combo/models/graph_parser.py                  |  21 +-
 combo/models/lemma.py                         | 107 ----
 combo/models/morpho.py                        | 103 ----
 combo/models/time_distributed.py              |  74 +++
 combo/modules/lemma.py                        |   2 +-
 .../basic_text_field_embedder.py              |   2 +-
 .../modules/token_embedders/token_embedder.py |   2 +-
 .../model_training.ipynb                      | 536 ++++++++++--------
 tests/data/tokenizers/test_spacy_tokenizer.py |   6 -
 11 files changed, 387 insertions(+), 961 deletions(-)
 delete mode 100644 combo/models/base.py
 delete mode 100644 combo/models/embeddings.py
 delete mode 100644 combo/models/lemma.py
 delete mode 100644 combo/models/morpho.py
 create mode 100644 combo/models/time_distributed.py
 rename combo/polish_model_training.ipynb => notebooks/model_training.ipynb (64%)

diff --git a/combo/models/base.py b/combo/models/base.py
deleted file mode 100644
index ad2c88c..0000000
--- a/combo/models/base.py
+++ /dev/null
@@ -1,274 +0,0 @@
-from typing import Dict, Optional, List, Union, Tuple
-
-import torch
-import torch.nn as nn
-from overrides import overrides
-
-from combo.nn import Activation
-import combo.utils.checks as checks
-from combo.data.vocabulary import Vocabulary
-from combo.models.utils import masked_cross_entropy
-from combo.predictors.predictor import Predictor
-
-
-class Linear(nn.Linear):
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 activation: Optional[Activation] = None,
-                 dropout_rate: Optional[float] = 0.0):
-        super().__init__(in_features, out_features)
-        self.activation = activation if activation else self.identity
-        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity
-
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
-        x = super().forward(x)
-        x = self.activation(x)
-        return self.dropout(x)
-
-    def get_output_dim(self) -> int:
-        return self.out_features
-
-    @staticmethod
-    def identity(x):
-        return x
-
-
-class FeedForward(torch.nn.Module):
-    """
-    Modified copy of allennlp.modules.feedforward.FeedForward
-
-    This `Module` is a feed-forward neural network, just a sequence of `Linear` layers with
-    activation functions in between.
-
-    # Parameters
-
-    input_dim : `int`, required
-        The dimensionality of the input.  We assume the input has shape `(batch_size, input_dim)`.
-    num_layers : `int`, required
-        The number of `Linear` layers to apply to the input.
-    hidden_dims : `Union[int, List[int]]`, required
-        The output dimension of each of the `Linear` layers.  If this is a single `int`, we use
-        it for all `Linear` layers.  If it is a `List[int]`, `len(hidden_dims)` must be
-        `num_layers`.
-    activations : `Union[Activation, List[Activation]]`, required
-        The activation function to use after each `Linear` layer.  If this is a single function,
-        we use it after all `Linear` layers.  If it is a `List[Activation]`,
-        `len(activations)` must be `num_layers`. Activation must have torch.nn.Module type.
-    dropout : `Union[float, List[float]]`, optional (default = `0.0`)
-        If given, we will apply this amount of dropout after each layer.  Semantics of `float`
-        versus `List[float]` is the same as with other parameters.
-
-    # Examples
-
-    ```python
-    FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
-    #> FeedForward(
-    #>   (_activations): ModuleList(
-    #>     (0): ReLU()
-    #>     (1): ReLU()
-    #>   )
-    #>   (_linear_layers): ModuleList(
-    #>     (0): Linear(in_features=124, out_features=64, bias=True)
-    #>     (1): Linear(in_features=64, out_features=32, bias=True)
-    #>   )
-    #>   (_dropout): ModuleList(
-    #>     (0): Dropout(p=0.2, inplace=False)
-    #>     (1): Dropout(p=0.2, inplace=False)
-    #>   )
-    #> )
-    ```
-    """
-
-    def __init__(
-            self,
-            input_dim: int,
-            num_layers: int,
-            hidden_dims: Union[int, List[int]],
-            activations: Union[Activation, List[Activation]],
-            dropout: Union[float, List[float]] = 0.0,
-    ) -> None:
-
-        super().__init__()
-        if not isinstance(hidden_dims, list):
-            hidden_dims = [hidden_dims] * num_layers  # type: ignore
-        if not isinstance(activations, list):
-            activations = [activations] * num_layers  # type: ignore
-        if not isinstance(dropout, list):
-            dropout = [dropout] * num_layers  # type: ignore
-        if len(hidden_dims) != num_layers:
-            raise checks.ConfigurationError(
-                "len(hidden_dims) (%d) != num_layers (%d)" % (len(hidden_dims), num_layers)
-            )
-        if len(activations) != num_layers:
-            raise checks.ConfigurationError(
-                "len(activations) (%d) != num_layers (%d)" % (len(activations), num_layers)
-            )
-        if len(dropout) != num_layers:
-            raise checks.ConfigurationError(
-                "len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers)
-            )
-        self._activations = torch.nn.ModuleList(activations)
-        input_dims = [input_dim] + hidden_dims[:-1]
-        linear_layers = []
-        for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
-            linear_layers.append(torch.nn.Linear(layer_input_dim, layer_output_dim))
-        self._linear_layers = torch.nn.ModuleList(linear_layers)
-        dropout_layers = [torch.nn.Dropout(p=value) for value in dropout]
-        self._dropout = torch.nn.ModuleList(dropout_layers)
-        self._output_dim = hidden_dims[-1]
-        self.input_dim = input_dim
-
-    def get_output_dim(self):
-        return self._output_dim
-
-    def get_input_dim(self):
-        return self.input_dim
-
-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
-        output = inputs
-        feature_maps = []
-        for layer, activation, dropout in zip(
-                self._linear_layers, self._activations, self._dropout
-        ):
-            feature_maps.append(output)
-            output = dropout(activation(layer(output)))
-        return output, feature_maps
-
-
-class FeedForwardPredictor(Predictor):
-    """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
-
-    def __init__(self, feedforward_network: "FeedForward"):
-        super().__init__()
-        self.feedforward_network = feedforward_network
-
-    def forward(self,
-                x: Union[torch.Tensor, List[torch.Tensor]],
-                mask: Optional[torch.BoolTensor] = None,
-                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        if mask is None:
-            mask = x.new_ones(x.size()[:-1])
-
-        x, feature_maps = self.feedforward_network(x)
-        output = {
-            "prediction": x.argmax(-1),
-            "probability": x,
-            "embedding": feature_maps[-1],
-        }
-
-        if labels is not None:
-            if sample_weights is None:
-                sample_weights = labels.new_ones([mask.size(0)])
-            output["loss"] = self._loss(x, labels, mask, sample_weights)
-
-        return output
-
-    def _loss(self,
-              pred: torch.Tensor,
-              true: torch.Tensor,
-              mask: torch.BoolTensor,
-              sample_weights: torch.Tensor) -> torch.Tensor:
-        BATCH_SIZE, _, CLASSES = pred.size()
-        valid_positions = mask.sum()
-        pred = pred.reshape(-1, CLASSES)
-        true = true.reshape(-1)
-        mask = mask.reshape(-1)
-        loss = masked_cross_entropy(pred, true, mask)
-        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
-        return loss.sum() / valid_positions
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: Vocabulary,
-                   vocab_namespace: str,
-                   input_dim: int,
-                   num_layers: int,
-                   hidden_dims: List[int],
-                   activations: Union[Activation, List[Activation]],
-                   dropout: Union[float, List[float]] = 0.0,
-                   ):
-        if len(hidden_dims) + 1 != num_layers:
-            raise checks.ConfigurationError(
-                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
-            )
-
-        assert vocab_namespace in vocab.get_namespaces(), \
-            f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
-        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
-
-        return cls(FeedForward(
-            input_dim=input_dim,
-            num_layers=num_layers,
-            hidden_dims=hidden_dims,
-            activations=activations,
-            dropout=dropout))
-
-
-"""
-Adapted from AllenNLP
-"""
-
-
-class TimeDistributed(torch.nn.Module):
-    """
-    Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes
-    inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be
-    `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back.
-
-    Note that while the above gives shapes with `batch_size` first, this `Module` also works if
-    `batch_size` is second - we always just combine the first two dimensions, then split them.
-
-    It also reshapes keyword arguments unless they are not tensors or their name is specified in
-    the optional `pass_through` iterable.
-    """
-
-    def __init__(self, module):
-        super().__init__()
-        self._module = module
-
-    @overrides
-    def forward(self, *inputs, pass_through: List[str] = None, **kwargs):
-
-        pass_through = pass_through or []
-
-        reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs]
-
-        # Need some input to then get the batch_size and time_steps.
-        some_input = None
-        if inputs:
-            some_input = inputs[-1]
-
-        reshaped_kwargs = {}
-        for key, value in kwargs.items():
-            if isinstance(value, torch.Tensor) and key not in pass_through:
-                if some_input is None:
-                    some_input = value
-
-                value = self._reshape_tensor(value)
-
-            reshaped_kwargs[key] = value
-
-        reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs)
-
-        if some_input is None:
-            raise RuntimeError("No input tensor to time-distribute")
-
-        # Now get the output back into the right shape.
-        # (batch_size, time_steps, **output_size)
-        new_size = some_input.size()[:2] + reshaped_outputs.size()[1:]
-        outputs = reshaped_outputs.contiguous().view(new_size)
-
-        return outputs
-
-    @staticmethod
-    def _reshape_tensor(input_tensor):
-        input_size = input_tensor.size()
-        if len(input_size) <= 2:
-            raise RuntimeError(f"No dimension to distribute: {input_size}")
-        # Squash batch_size and time_steps into a single axis; result has shape
-        # (batch_size * time_steps, **input_size).
-        squashed_shape = [-1] + list(input_size[2:])
-        return input_tensor.contiguous().view(*squashed_shape)
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
deleted file mode 100644
index 46bb709..0000000
--- a/combo/models/embeddings.py
+++ /dev/null
@@ -1,221 +0,0 @@
-from typing import Optional
-
-import torch
-from overrides import overrides
-from torch import nn
-from torchtext.vocab import Vectors, GloVe, FastText, CharNGram
-
-from combo.data import Vocabulary
-from combo.models.base import TimeDistributed
-from combo.models.dilated_cnn import DilatedCnnEncoder
-from combo.models.utils import tiny_value_of_dtype
-from combo.utils import ConfigurationError
-
-
-class TokenEmbedder(nn.Module):
-    def __init__(self):
-        super(TokenEmbedder, self).__init__()
-
-    @property
-    def output_dim(self) -> int:
-        raise NotImplementedError()
-
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        raise NotImplementedError()
-
-
-class _TorchEmbedder(TokenEmbedder):
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: Optional[int] = None,
-                 max_norm: Optional[float] = None,
-                 norm_type: float = 2.,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 vocab_namespace: str = "tokens",
-                 vocab: Vocabulary = None,
-                 weight: Optional[torch.Tensor] = None,
-                 trainable: bool = True,
-                 projection_dim: Optional[int] = None):
-        super(_TorchEmbedder, self).__init__()
-        self._embedding_dim = embedding_dim
-        self._embedding = nn.Embedding(num_embeddings=num_embeddings,
-                                       embedding_dim=embedding_dim,
-                                       padding_idx=padding_idx,
-                                       max_norm=max_norm,
-                                       norm_type=norm_type,
-                                       scale_grad_by_freq=scale_grad_by_freq,
-                                       sparse=sparse)
-        self.__vocab_namespace = vocab_namespace
-        self.__vocab = vocab
-
-        if weight is not None:
-            if weight.shape() != (num_embeddings, embedding_dim):
-                raise ConfigurationError(
-                    "Weight matrix must be of shape (num_embeddings, embedding_dim)." +
-                    f"Got: ({weight.shape()})"
-                )
-
-            self.__weight = torch.nn.Parameter(weight, requires_grad=trainable)
-        else:
-            self.__weight = torch.nn.Parameter(torch.FloatTensor(num_embeddings, embedding_dim),
-                                               requires_grad=trainable)
-            torch.nn.init.xavier_uniform_(self.__weight)
-
-        if padding_idx is not None:
-            self.__weight.data[padding_idx].fill_(0)
-
-        if projection_dim:
-            self._projection = torch.nn.Linear(embedding_dim, projection_dim)
-            self._output_dim = projection_dim
-        else:
-            self._projection = None
-            self._output_dim = embedding_dim
-
-    @overrides
-    def output_dim(self) -> int:
-        return self._output_dim
-
-    @overrides
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        embedded = self._embedding(x)
-        if self._projection:
-            projection = self._projection
-            for p in range(embedded.dim()-2):
-                projection = TimeDistributed(p)
-            embedded = projection(embedded)
-        return embedded
-
-
-class _TorchtextVectorsEmbedder(TokenEmbedder):
-    """
-    Torchtext Vectors object wrapper
-    """
-
-    def __init__(self,
-                 torchtext_embedder: Vectors,
-                 lower_case_backup: bool = False):
-        """
-        :param torchtext_embedder: Torchtext Vectors object
-        :param lower_case_backup: whether to look up the token in the
-        lower case. Default: False.
-        """
-        super(_TorchtextVectorsEmbedder, self).__init__()
-        self.__torchtext_embedder = torchtext_embedder
-        self.__lower_case_backup = lower_case_backup
-
-    @overrides
-    def output_dim(self) -> int:
-        return len(self.__torchtext_embedder)
-
-    @overrides
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        return self.__torchtext_embedder.get_vecs_by_tokens(x, self.__lower_case_backup)
-
-
-class GloVe42BEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self, dim: int = 300):
-        super(GloVe42BEmbedder, self).__init__(GloVe("42B", dim))
-
-
-class GloVe840BEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self, dim: int = 300):
-        super(GloVe840BEmbedder, self).__init__(GloVe("840B", dim))
-
-
-class GloVeTwitter27BEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self, dim: int = 300):
-        super(GloVeTwitter27BEmbedder, self).__init__(GloVe("twitter.27B", dim))
-
-
-class GloVe6BEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self, dim: int = 300):
-        super(GloVe6BEmbedder, self).__init__(GloVe("6B", dim))
-
-
-class FastTextEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self, language: str = "en"):
-        super(FastTextEmbedder, self).__init__(FastText(language))
-
-
-class CharNGramEmbedder(_TorchtextVectorsEmbedder):
-    def __init__(self):
-        super(CharNGramEmbedder, self).__init__(CharNGram())
-
-
-class CharacterBasedWordEmbedder(TokenEmbedder):
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 dilated_cnn_encoder: DilatedCnnEncoder):
-        super(CharacterBasedWordEmbedder, self).__init__()
-        self.__embedding_dim = embedding_dim
-        self.__dilated_cnn_encoder = dilated_cnn_encoder
-        self.char_embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
-
-    @overrides
-    def output_dim(self) -> int:
-        return self.__embedding_dim
-
-    @overrides
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        if char_mask is None:
-            char_mask = x.new_ones(x.size())
-
-        x = self.char_embed(x)
-        x = x * char_mask.unsqueeze(-1).float()
-        x = self.__dilated_cnn_encoder(x.transpose(2, 3))
-        return torch.max(x, dim=-1)[0]
-
-
-class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
-    pass
-
-
-class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder):
-    pass
-
-
-class FeatsTokenEmbedder(_TorchEmbedder):
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: Optional[int] = None,
-                 max_norm: Optional[float] = None,
-                 norm_type: float = 2.,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 vocab_namespace: str = "feats",
-                 vocab: Vocabulary = None,
-                 weight: Optional[torch.Tensor] = None,
-                 trainable: bool = True):
-        super(FeatsTokenEmbedder, self).__init__(num_embeddings,
-                                                 embedding_dim,
-                                                 padding_idx,
-                                                 max_norm,
-                                                 norm_type,
-                                                 scale_grad_by_freq,
-                                                 sparse,
-                                                 vocab_namespace,
-                                                 vocab,
-                                                 weight,
-                                                 trainable)
-
-    @overrides
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        mask = x.gt(0)
-        x = super().forward(x)
-        return x.sum(dim=-2)/(
-            (mask.sum(dim=-1)+tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1)
-        )
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 6dffef5..7fc4a1d 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -6,19 +6,20 @@ Author: Mateusz Klimaszewski
 from typing import List, Optional, Union, Tuple, Dict
 
 from combo import data
-from combo.models import base
-from combo.models.base import Predictor
+from combo.predictors import Predictor
 
 import torch
 import torch.nn.functional as F
 
+from combo.nn.base import Linear
+
 
 class GraphHeadPredictionModel(Predictor):
     """Head prediction model."""
 
     def __init__(self,
-                 head_projection_layer: base.Linear,
-                 dependency_projection_layer: base.Linear,
+                 head_projection_layer: Linear,
+                 dependency_projection_layer: Linear,
                  cycle_loss_n: int = 0,
                  graph_weighting: float = 0.2):
         super().__init__()
@@ -107,9 +108,9 @@ class GraphDependencyRelationModel(Predictor):
 
     def __init__(self,
                  head_predictor: GraphHeadPredictionModel,
-                 head_projection_layer: base.Linear,
-                 dependency_projection_layer: base.Linear,
-                 relation_prediction_layer: base.Linear):
+                 head_projection_layer: Linear,
+                 dependency_projection_layer: Linear,
+                 relation_prediction_layer: Linear):
         super().__init__()
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
@@ -173,12 +174,12 @@ class GraphDependencyRelationModel(Predictor):
                    vocab: data.Vocabulary,
                    vocab_namespace: str,
                    head_predictor: GraphHeadPredictionModel,
-                   head_projection_layer: base.Linear,
-                   dependency_projection_layer: base.Linear
+                   head_projection_layer: Linear,
+                   dependency_projection_layer: Linear
                    ):
         """Creates parser combining model configuration and vocabulary data."""
         assert vocab_namespace in vocab.get_namespaces()
-        relation_prediction_layer = base.Linear(
+        relation_prediction_layer = Linear(
             in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
             out_features=vocab.get_vocab_size(vocab_namespace)
         )
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
deleted file mode 100644
index b1293e0..0000000
--- a/combo/models/lemma.py
+++ /dev/null
@@ -1,107 +0,0 @@
-from typing import Optional, Dict, List, Union
-
-import torch
-import torch.nn as nn
-
-from combo import data
-from combo.models import dilated_cnn, base, utils
-from combo.models.base import Predictor, TimeDistributed
-from combo.nn import Activation
-from combo.utils import ConfigurationError
-
-
-class LemmatizerModel(Predictor):
-    """Lemmatizer model."""
-
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
-                 input_projection_layer: base.Linear):
-        super().__init__()
-        self.char_embed = nn.Embedding(
-            num_embeddings=num_embeddings,
-            embedding_dim=embedding_dim,
-        )
-        self.dilated_cnn_encoder = TimeDistributed(dilated_cnn_encoder)
-        self.input_projection_layer = input_projection_layer
-
-    def forward(self,
-                x: Union[torch.Tensor, List[torch.Tensor]],
-                mask: Optional[torch.BoolTensor] = None,
-                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        encoder_emb, chars = x
-
-        encoder_emb = self.input_projection_layer(encoder_emb)
-        char_embeddings = self.char_embed(chars)
-
-        BATCH_SIZE, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
-        encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)
-
-        x = torch.cat((char_embeddings, encoder_emb), dim=-1).transpose(2, 3)
-        x = self.dilated_cnn_encoder(x).transpose(2, 3)
-        output = {
-            "prediction": x.argmax(-1),
-            "probability": x
-        }
-
-        if labels is not None:
-            if mask is None:
-                mask = encoder_emb.new_ones(encoder_emb.size()[:-2])
-            if sample_weights is None:
-                sample_weights = labels.new_ones(BATCH_SIZE)
-            mask = mask.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH).bool()
-            output["loss"] = self._loss(x, labels, mask, sample_weights)
-
-        return output
-
-    @staticmethod
-    def _loss(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
-              sample_weights: torch.Tensor) -> torch.Tensor:
-        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
-        pred = pred.reshape(-1, CHAR_CLASSES)
-
-        true = true.reshape(-1)
-        mask = true.gt(0)
-        loss = utils.masked_cross_entropy(pred, true, mask)
-        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
-        valid_positions = mask.sum()
-        return loss.sum() / valid_positions
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   char_vocab_namespace: str,
-                   lemma_vocab_namespace: str,
-                   embedding_dim: int,
-                   input_projection_layer: base.Linear,
-                   filters: List[int],
-                   kernel_size: List[int],
-                   stride: List[int],
-                   padding: List[int],
-                   dilation: List[int],
-                   activations: List[Activation],
-                   ):
-        assert char_vocab_namespace in vocab.get_namespaces()
-        assert lemma_vocab_namespace in vocab.get_namespaces()
-
-        if len(filters) + 1 != len(kernel_size):
-            raise ConfigurationError(
-                f"len(filters) ({len(filters):d}) + 1 != kernel_size ({len(kernel_size):d})"
-            )
-        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]
-
-        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
-            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
-            filters=filters,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            activations=activations,
-        )
-        return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
-                   embedding_dim=embedding_dim,
-                   dilated_cnn_encoder=dilated_cnn_encoder,
-                   input_projection_layer=input_projection_layer)
diff --git a/combo/models/morpho.py b/combo/models/morpho.py
deleted file mode 100644
index 2415fb5..0000000
--- a/combo/models/morpho.py
+++ /dev/null
@@ -1,103 +0,0 @@
-"""
-Adapted from COMBO
-Author: Mateusz Klimaszewski
-"""
-from typing import Dict, List, Optional, Union
-import torch
-
-from combo import data
-from combo.data import dataset
-from combo.models import base, utils
-from combo.nn import Activation
-from combo.utils import ConfigurationError
-
-
-class MorphologicalFeatures(base.Predictor):
-    """Morphological features predicting model."""
-
-    def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
-        super().__init__()
-        self.feedforward_network = feedforward_network
-        self.slices = slices
-
-    def forward(self,
-                x: Union[torch.Tensor, List[torch.Tensor]],
-                mask: Optional[torch.BoolTensor] = None,
-                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        if mask is None:
-            mask = x.new_ones(x.size()[:-1])
-
-        x, feature_maps = self.feedforward_network(x)
-
-        prediction = []
-        for _, cat_indices in self.slices.items():
-            prediction.append(x[:, :, cat_indices].argmax(dim=-1))
-
-        output = {
-            "prediction": torch.stack(prediction, dim=-1),
-            "probability": x,
-            "embedding": feature_maps[-1],
-        }
-
-        if labels is not None:
-            if sample_weights is None:
-                sample_weights = labels.new_ones([mask.size(0)])
-            output["loss"] = self._loss(x, labels, mask, sample_weights)
-
-        return output
-
-    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
-              sample_weights: torch.Tensor) -> torch.Tensor:
-        assert pred.size() == true.size()
-        BATCH_SIZE, _, MORPHOLOGICAL_FEATURES = pred.size()
-
-        valid_positions = mask.sum()
-
-        pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
-        true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
-        mask = mask.reshape(-1)
-        loss = None
-        loss_func = utils.masked_cross_entropy
-        for cat, cat_indices in self.slices.items():
-            if cat not in ["__PAD__", "_"]:
-                if loss is None:
-                    loss = loss_func(pred[:, cat_indices],
-                                     true[:, cat_indices].argmax(dim=1),
-                                     mask)
-                else:
-                    loss += loss_func(pred[:, cat_indices],
-                                      true[:, cat_indices].argmax(dim=1),
-                                      mask)
-        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
-        return loss.sum() / valid_positions
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   vocab_namespace: str,
-                   input_dim: int,
-                   num_layers: int,
-                   hidden_dims: List[int],
-                   activations: Union[Activation, List[Activation]],
-                   dropout: Union[float, List[float]] = 0.0,
-                   ):
-        if len(hidden_dims) + 1 != num_layers:
-            raise ConfigurationError(
-                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
-            )
-
-        assert vocab_namespace in vocab.get_namespaces()
-        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
-
-        slices = dataset.get_slices_if_not_provided(vocab)
-
-        return cls(
-            feedforward_network=base.FeedForward(
-                input_dim=input_dim,
-                num_layers=num_layers,
-                hidden_dims=hidden_dims,
-                activations=activations,
-                dropout=dropout),
-            slices=slices
-        )
diff --git a/combo/models/time_distributed.py b/combo/models/time_distributed.py
new file mode 100644
index 0000000..d8c20d5
--- /dev/null
+++ b/combo/models/time_distributed.py
@@ -0,0 +1,74 @@
+"""
+Adapted from AllenNLP
+"""
+from typing import List
+
+import torch
+from overrides import overrides
+
+from combo.config.registry import Registry
+from combo.config.from_parameters import FromParameters, register_arguments
+
+
+@Registry.register('time_distributed')
+class TimeDistributed(torch.nn.Module, FromParameters):
+    """
+    Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes
+    inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be
+    `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back.
+
+    Note that while the above gives shapes with `batch_size` first, this `Module` also works if
+    `batch_size` is second - we always just combine the first two dimensions, then split them.
+
+    It also reshapes keyword arguments unless they are not tensors or their name is specified in
+    the optional `pass_through` iterable.
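+
+    # Examples
+
+    A minimal usage sketch (the wrapped module and the shapes below are purely
+    illustrative, not taken from the COMBO models):
+
+    ```python
+    distributed = TimeDistributed(torch.nn.Linear(5, 3))
+    x = torch.randn(2, 7, 5)  # (batch_size, time_steps, in_features)
+    distributed(x).size()
+    #> torch.Size([2, 7, 3])
+    ```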
+    """
+
+    @register_arguments
+    def __init__(self, module):
+        super().__init__()
+        self._module = module
+
+    @overrides
+    def forward(self, *inputs, pass_through: List[str] = None, **kwargs):
+
+        pass_through = pass_through or []
+
+        reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs]
+
+        # Need some input to then get the batch_size and time_steps.
+        some_input = None
+        if inputs:
+            some_input = inputs[-1]
+
+        reshaped_kwargs = {}
+        for key, value in kwargs.items():
+            if isinstance(value, torch.Tensor) and key not in pass_through:
+                if some_input is None:
+                    some_input = value
+
+                value = self._reshape_tensor(value)
+
+            reshaped_kwargs[key] = value
+
+        reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs)
+
+        if some_input is None:
+            raise RuntimeError("No input tensor to time-distribute")
+
+        # Now get the output back into the right shape.
+        # (batch_size, time_steps, **output_size)
+        new_size = some_input.size()[:2] + reshaped_outputs.size()[1:]
+        outputs = reshaped_outputs.contiguous().view(new_size)
+
+        return outputs
+
+    @staticmethod
+    def _reshape_tensor(input_tensor):
+        input_size = input_tensor.size()
+        if len(input_size) <= 2:
+            raise RuntimeError(f"No dimension to distribute: {input_size}")
+        # Squash batch_size and time_steps into a single axis; result has shape
+        # (batch_size * time_steps, **input_size).
+        squashed_shape = [-1] + list(input_size[2:])
+        return input_tensor.contiguous().view(*squashed_shape)
diff --git a/combo/modules/lemma.py b/combo/modules/lemma.py
index 960f382..8fa959b 100644
--- a/combo/modules/lemma.py
+++ b/combo/modules/lemma.py
@@ -12,7 +12,7 @@ from combo.nn import base
 from combo.nn.activations import Activation
 from combo.nn.utils import masked_cross_entropy
 from combo.utils import ConfigurationError
-from combo.models.base import TimeDistributed
+from combo.models.time_distributed import TimeDistributed
 from combo.predictors import Predictor
 
 
diff --git a/combo/modules/text_field_embedders/basic_text_field_embedder.py b/combo/modules/text_field_embedders/basic_text_field_embedder.py
index dea5f8a..545bce3 100644
--- a/combo/modules/text_field_embedders/basic_text_field_embedder.py
+++ b/combo/modules/text_field_embedders/basic_text_field_embedder.py
@@ -15,7 +15,7 @@ from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbe
 from combo.modules.token_embedders import EmptyEmbedder
 from combo.modules.token_embedders.token_embedder import TokenEmbedder
 from combo.utils import ConfigurationError
-from combo.models.base import TimeDistributed
+from combo.models.time_distributed import TimeDistributed
 
 
 @Registry.register("base_text_field_embedder")
diff --git a/combo/modules/token_embedders/token_embedder.py b/combo/modules/token_embedders/token_embedder.py
index 0417713..0a60351 100644
--- a/combo/modules/token_embedders/token_embedder.py
+++ b/combo/modules/token_embedders/token_embedder.py
@@ -12,7 +12,7 @@ from combo.data import Vocabulary
 from combo.nn.utils import tiny_value_of_dtype, uncombine_initial_dims, combine_initial_dims
 from combo.modules.module import Module
 from combo.utils import ConfigurationError
-from combo.models.base import TimeDistributed
+from combo.models.time_distributed import TimeDistributed
 
 
 class TokenEmbedder(Module, FromParameters):
diff --git a/combo/polish_model_training.ipynb b/notebooks/model_training.ipynb
similarity index 64%
rename from combo/polish_model_training.ipynb
rename to notebooks/model_training.ipynb
index 005fe4b..99c3a7b 100644
--- a/combo/polish_model_training.ipynb
+++ b/notebooks/model_training.ipynb
@@ -1,83 +1,179 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Training the model\n",
+    "\n",
+    "Apart from training using the CLI, COMBO can be used as a Python package.\n",
+    "\n",
+    "This notebook will demonstrate how to train a model from scratch."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "1d09bbd85e03d60c"
+  },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 17,
    "outputs": [],
    "source": [
-    "# The path where the training and validation datasets are stored\n",
+    "# The path where the training and validation datasets are stored.\n",
+    "# The datasets should be in CoNLL-U format.\n",
     "TRAINING_DATA_PATH: str = '/Users/majajablonska/Documents/PDBUD/train.conllu'\n",
     "VALIDATION_DATA_PATH: str = '/Users/majajablonska/Documents/PDBUD/val.conllu'\n",
-    "# The path where the model can be saved to\n",
-    "SERIALIZATION_DIR: str = \"/Users/majajablonska/Documents/Workspace/combotest\""
+    "\n",
+    "TESTING_DATA_PATH: str = ''\n",
+    "\n",
+    "# The directory where the trained model will be archived.\n",
+    "SERIALIZATION_DIR: str = \"/Users/majajablonska/Documents/combo\""
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:21.197003Z",
-     "start_time": "2023-11-13T12:15:19.886422Z"
+     "end_time": "2023-11-19T07:40:13.624770Z",
+     "start_time": "2023-11-19T07:40:13.601998Z"
     }
    },
    "id": "b28c7d8bacb08d02"
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Reading datasets\n",
+    "\n",
+    "Training and validation datasets are read using the ```DataLoader``` class.\n",
+    "A ```DataLoader``` can be created either from scratch or from a ```DatasetReader```; the latter is usually easier, and the ```DatasetReader``` can also be re-used for testing or inspecting the data."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "57c3b4e46a63e0d2"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## DatasetReader\n",
+    "\n",
+    "The ```DatasetReader``` reads a file and outputs ```Instance``` objects.\n",
+    "It needs to have ```Tokenizer``` and ```TokenIndexer``` classes defined."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "ae6cf89c1131ba71"
+  },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 18,
    "outputs": [],
    "source": [
-    "from combo.predict import COMBO\n",
-    "from combo.combo_model import ComboModel\n",
-    "from combo.data.vocabulary import Vocabulary\n",
-    "from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM\n",
-    "from combo.modules.text_field_embedders import BasicTextFieldEmbedder\n",
-    "from combo.nn.base import Linear\n",
-    "from combo.modules.token_embedders import CharacterBasedWordEmbedder, TransformersWordEmbedder\n",
-    "from combo.modules import FeedForwardPredictor\n",
-    "from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation\n",
-    "from combo.models.dilated_cnn import DilatedCnnEncoder\n",
-    "from combo.data.tokenizers import LamboTokenizer, CharacterTokenizer\n",
-    "from combo.data.token_indexers import PretrainedTransformerIndexer, TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer\n",
     "from combo.data.dataset_readers import UniversalDependenciesDatasetReader\n",
-    "import torch\n",
-    "from combo.data.dataset_loaders import SimpleDataLoader\n",
-    "from combo.modules.parser import DependencyRelationModel, HeadPredictionModel\n",
-    "from combo.modules.lemma import LemmatizerModel\n",
-    "from combo.modules.morpho import MorphologicalFeatures\n",
-    "from combo.nn.regularizers.regularizers import L2Regularizer\n",
-    "import pytorch_lightning as pl\n",
-    "from combo.training.trainable_combo import TrainableCombo\n",
-    "from itertools import chain"
+    "from combo.data.token_indexers import PretrainedTransformerIndexer, TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer\n",
+    "from combo.data.tokenizers import LamboTokenizer, CharacterTokenizer"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:28.665585Z",
-     "start_time": "2023-11-13T12:15:19.907198Z"
+     "end_time": "2023-11-19T07:51:42.290849Z",
+     "start_time": "2023-11-19T07:51:42.254199Z"
     }
    },
-   "id": "initial_id"
+   "id": "585dbe0fb61123b4"
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 19,
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
-      "To disable this warning, you can either:\n",
-      "\t- Avoid using `tokenizers` before the fork if possible\n",
-      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+      "Using model LAMBO-UD_Polish-PDB\n"
      ]
-    },
+    }
+   ],
+   "source": [
+    "# We are going to use a pretrained transformer indexer class.\n",
+    "# Pretrained transformer weights are downloaded from the Hugging Face Hub, so we need a valid model name.\n",
+    "# For this example, we are going to train a Polish model, so the HerBERT model is used.\n",
+    "MODEL_NAME = 'allegro/herbert-base-cased'\n",
+    "\n",
+    "# We are using a Universal Dependencies Dataset Reader to read CoNLL-U files.\n",
+    "dataset_reader = UniversalDependenciesDatasetReader(\n",
+    "    # Features that are going to be indexed - tokens and characters will be independently\n",
+    "    # assigned their indexes\n",
+    "    features=[\"token\", \"char\"],\n",
+    "    # LamboTokenizer is the default tokenizer. We have trained Lambo models for various languages.\n",
+    "    tokenizer=LamboTokenizer(\"Polish\"),\n",
+    "    lemma_indexers={\n",
+    "        \"char\": TokenConstPaddingCharactersIndexer(\n",
+    "            tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n",
+    "            start_tokens=[\"__START__\"]),\n",
+    "            min_padding_length=32,\n",
+    "            namespace=\"lemma_characters\"\n",
+    "        )\n",
+    "    },\n",
+    "    # Features that are going to be predicted.\n",
+    "    targets=[\"deprel\", \"head\", \"upostag\", \"lemma\", \"feats\", \"xpostag\"],\n",
+    "    token_indexers={\n",
+    "        \"char\": TokenConstPaddingCharactersIndexer(\n",
+    "            tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n",
+    "            start_tokens=[\"__START__\"]),\n",
+    "            min_padding_length=32\n",
+    "        ),\n",
+    "        \"feats\": TokenFeatsIndexer(),\n",
+    "        \"lemma\": TokenConstPaddingCharactersIndexer(\n",
+    "            tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n",
+    "            start_tokens=[\"__START__\"]),\n",
+    "            min_padding_length=32\n",
+    "        ),\n",
+    "        \"token\": PretrainedTransformerFixedMismatchedIndexer(MODEL_NAME),\n",
+    "        \"upostag\": SingleIdTokenIndexer(\n",
+    "            feature_name=\"pos_\",\n",
+    "            namespace=\"upostag\"\n",
+    "        ),\n",
+    "        \"xpostag\": SingleIdTokenIndexer(\n",
+    "            feature_name=\"tag_\",\n",
+    "            namespace=\"xpostag\"\n",
+    "        )\n",
+    "    },\n",
+    "    use_sem=False\n",
+    ")\n"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-11-19T07:51:43.160536Z",
+     "start_time": "2023-11-19T07:51:42.995675Z"
+    }
+   },
+   "id": "d74957f422f0b05b"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## DataLoader\n",
+    "\n",
+    "We can create a ```DataLoader``` from our dataset reader. It will let us iterate through batches of ```Instance``` objects."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "af063d408b6b9a99"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "outputs": [
     {
      "data": {
       "text/plain": "loading instances: 0it [00:00, ?it/s]",
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "2179b1be2f484a33948a76d087002182"
+       "model_id": "32330372f8f4473bb8692eb5864ea738"
       }
      },
      "metadata": {},
@@ -89,19 +185,62 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "86762d681ee0467e8501de2b34061aad"
+       "model_id": "b759a042f43442e7a944faefba8c2728"
       }
      },
      "metadata": {},
      "output_type": "display_data"
-    },
+    }
+   ],
+   "source": [
+    "from combo.data import Vocabulary\n",
+    "from combo.data.dataset_loaders import SimpleDataLoader\n",
+    "\n",
+    "data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n",
+    "                                                   data_path=TRAINING_DATA_PATH,\n",
+    "                                                   batch_size=16,\n",
+    "                                                   batches_per_epoch=4,\n",
+    "                                                   shuffle=True)\n",
+    "val_data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n",
+    "                                                       data_path=VALIDATION_DATA_PATH,\n",
+    "                                                       batch_size=16,\n",
+    "                                                       batches_per_epoch=4,\n",
+    "                                                       shuffle=True)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-11-19T07:53:48.707900Z",
+     "start_time": "2023-11-19T07:53:21.403165Z"
+    }
+   },
+   "id": "4c8f8fa3d94a30c7"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Vocabulary\n",
+    "\n",
+    "A ```Vocabulary``` is a collection mapping tokens to indices. It is used to \"translate\" between the model output and the actual words we are\n",
+    "predicting.\n",
+    "In this case, we are going to build a ```Vocabulary``` from the training dataset."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "8bd38739a23d474f"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "outputs": [
     {
      "data": {
       "text/plain": "building vocabulary: 0it [00:00, ?it/s]",
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "b9e631cb77594ea5aae60e6d15809885"
+       "model_id": "fe7e90f3d2de42f3a11ff33cd71f870f"
       }
      },
      "metadata": {},
@@ -109,57 +248,8 @@
     }
    ],
    "source": [
-    "def default_const_character_indexer(namespace = None):\n",
-    "    if namespace:\n",
-    "        return TokenConstPaddingCharactersIndexer(\n",
-    "            tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n",
-    "            start_tokens=[\"__START__\"]),\n",
-    "            min_padding_length=32,\n",
-    "            namespace=namespace\n",
-    "        )\n",
-    "    else:\n",
-    "        return TokenConstPaddingCharactersIndexer(\n",
-    "            tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n",
-    "            start_tokens=[\"__START__\"]),\n",
-    "            min_padding_length=32\n",
-    "        )\n",
-    "\n",
-    "dataset_reader = UniversalDependenciesDatasetReader(\n",
-    "    features=[\"token\", \"char\"],\n",
-    "    lemma_indexers={\n",
-    "        \"char\": default_const_character_indexer(\"lemma_characters\")\n",
-    "    },\n",
-    "    targets=[\"deprel\", \"head\", \"upostag\", \"lemma\", \"feats\", \"xpostag\"],\n",
-    "    token_indexers={\n",
-    "        \"char\": default_const_character_indexer(),\n",
-    "        \"feats\": TokenFeatsIndexer(),\n",
-    "        \"lemma\": default_const_character_indexer(),\n",
-    "        \"token\": PretrainedTransformerFixedMismatchedIndexer(\"bert-base-cased\"),\n",
-    "        \"upostag\": SingleIdTokenIndexer(\n",
-    "            feature_name=\"pos_\",\n",
-    "            namespace=\"upostag\"\n",
-    "        ),\n",
-    "        \"xpostag\": SingleIdTokenIndexer(\n",
-    "            feature_name=\"tag_\",\n",
-    "            namespace=\"xpostag\"\n",
-    "        )\n",
-    "    },\n",
-    "    use_sem=False\n",
-    ")\n",
-    "\n",
-    "data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n",
-    "                                                   data_path=TRAINING_DATA_PATH,\n",
-    "                                                   batch_size=16,\n",
-    "                                                   batches_per_epoch=4,\n",
-    "                                                   shuffle=True)\n",
-    "val_data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n",
-    "                                                   data_path=VALIDATION_DATA_PATH,\n",
-    "                                                   batch_size=16,\n",
-    "                                                    batches_per_epoch=4,\n",
-    "                                                    shuffle=True)\n",
-    "\n",
     "vocabulary = Vocabulary.from_instances_extended(\n",
-    "    chain(data_loader.iter_instances(), val_data_loader.iter_instances()),\n",
+    "    data_loader.iter_instances(),\n",
     "    non_padded_namespaces=['head_labels'],\n",
     "    only_include_pretrained_words=False,\n",
     "    oov_token='_',\n",
@@ -169,17 +259,29 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:51.717065Z",
-     "start_time": "2023-11-13T12:15:28.442131Z"
+     "end_time": "2023-11-19T07:54:45.394172Z",
+     "start_time": "2023-11-19T07:54:37.718909Z"
     }
    },
-   "id": "d74957f422f0b05b"
+   "id": "e9528f51dc64a80f"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Word embeddings"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "45029e32ff539ee7"
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "outputs": [],
    "source": [
+    "from combo.models import ComboEncoder, ComboStackedBidirectionalLSTM\n",
+    "\n",
     "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n",
     "                           stacked_bilstm=ComboStackedBidirectionalLSTM(\n",
     "                               hidden_size=512,\n",
@@ -192,33 +294,21 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:52.574303Z",
-     "start_time": "2023-11-13T12:15:51.724469Z"
+     "end_time": "2023-11-19T05:01:11.804611Z",
+     "start_time": "2023-11-19T05:01:11.160755Z"
     }
    },
    "id": "fa724d362fd6bd23"
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Using model LAMBO-UD_English-EWT\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fb512dc4f20>"
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "execution_count": 22,
+   "outputs": [],
    "source": [
+    "from combo.nn import LinearActivation, ReLUActivation\n",
+    "from combo.models.dilated_cnn import DilatedCnnEncoder\n",
+    "from combo.modules import CharacterBasedWordEmbedder\n",
+    "\n",
     "char_words_embedder = CharacterBasedWordEmbedder(\n",
     "    dilated_cnn_encoder = DilatedCnnEncoder(\n",
     "        input_dim=64,\n",
@@ -231,60 +321,44 @@
     "    ),\n",
     "    embedding_dim=64,\n",
     "    vocabulary=vocabulary\n",
-    ")\n",
-    "tokenizer = LamboTokenizer()\n",
-    "indexer = PretrainedTransformerIndexer('bert-base-cased')\n",
-    "data_loader.iter_instances()"
+    ")"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:52.641199Z",
-     "start_time": "2023-11-13T12:15:52.583194Z"
+     "end_time": "2023-11-19T07:57:35.373107Z",
+     "start_time": "2023-11-19T07:57:35.305579Z"
     }
    },
    "id": "f8a10f9892005fca"
   },
   {
-   "cell_type": "code",
-   "execution_count": 6,
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Directory /Users/majajablonska/PycharmProjects/combo-lightning/tests/fixtures/train_vocabulary is not empty\n"
-     ]
-    }
-   ],
+   "cell_type": "markdown",
    "source": [
-    "vocabulary.save_to_files('/Users/majajablonska/PycharmProjects/combo-lightning/tests/fixtures/train_vocabulary')"
+    "## COMBO model\n",
+    "\n",
+    "This is the main class: the actual model that predicts the requested features.\n",
+    "\n",
+    "Various parameters can be tweaked here - this is an example architecture."
    ],
    "metadata": {
-    "collapsed": false,
-    "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:52.659289Z",
-     "start_time": "2023-11-13T12:15:52.625700Z"
-    }
+    "collapsed": false
    },
-   "id": "14413692656b68ac"
+   "id": "82726af7ee13dece"
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
-      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
-     ]
-    }
-   ],
+   "execution_count": 23,
+   "outputs": [],
    "source": [
-    "from nn import RegularizerApplicator\n",
+    "from combo.modules import TransformersWordEmbedder, FeedForwardPredictor\n",
+    "from combo.modules.text_field_embedders import BasicTextFieldEmbedder\n",
+    "from combo.modules.morpho import MorphologicalFeatures\n",
+    "from combo.modules.lemma import LemmatizerModel\n",
+    "from combo.modules.parser import DependencyRelationModel, HeadPredictionModel\n",
+    "from combo.nn import TanhActivation, L2Regularizer, RegularizerApplicator\n",
+    "from combo.nn.base import Linear\n",
+    "from combo.combo_model import ComboModel\n",
     "\n",
     "model = ComboModel(\n",
     "    vocabulary=vocabulary,\n",
@@ -384,7 +458,7 @@
     "                ),\n",
     "                embedding_dim=64\n",
     "            ),\n",
-    "            \"token\": TransformersWordEmbedder(\"allegro/herbert-base-cased\", projection_dim=100)\n",
+    "            \"token\": TransformersWordEmbedder(MODEL_NAME, projection_dim=100)\n",
     "        }\n",
     "    ),\n",
     "    upos_tagger=FeedForwardPredictor.from_vocab(\n",
@@ -411,50 +485,46 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:15:56.509687Z",
-     "start_time": "2023-11-13T12:15:52.658879Z"
+     "end_time": "2023-11-19T08:08:15.012836Z",
+     "start_time": "2023-11-19T08:08:12.154480Z"
     }
    },
    "id": "437d12054baaffa1"
   },
   {
-   "cell_type": "code",
-   "execution_count": 8,
-   "outputs": [],
+   "cell_type": "markdown",
    "source": [
-    "data_loader.index_with(vocabulary)\n",
-    "a = 0\n",
-    "for i in data_loader:\n",
-    "    break"
+    "# Training\n",
+    "\n",
+    "Data loaders need to be indexed with vocabularies.\n",
+    "\n",
+    "For training, the ```TrainableCombo``` class can be used. It is a ```PyTorch Lightning``` module, so all Lightning functionality can be used with it - a ```Trainer```, various callbacks, etc."
    ],
    "metadata": {
-    "collapsed": false,
-    "ExecuteTime": {
-     "end_time": "2023-11-13T12:16:30.663344Z",
-     "start_time": "2023-11-13T12:15:56.529656Z"
-    }
+    "collapsed": false
    },
-   "id": "e131e0ec75dc6927"
+   "id": "c9b05beb65935af4"
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 24,
    "outputs": [],
    "source": [
+    "data_loader.index_with(vocabulary)\n",
     "val_data_loader.index_with(vocabulary)"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:16:45.453326Z",
-     "start_time": "2023-11-13T12:16:30.488388Z"
+     "end_time": "2023-11-19T08:08:58.057064Z",
+     "start_time": "2023-11-19T08:08:18.882438Z"
     }
    },
-   "id": "195c71fcf8170ff"
+   "id": "e131e0ec75dc6927"
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 26,
    "outputs": [
     {
      "name": "stderr",
@@ -463,13 +533,15 @@
       "GPU available: False, used: False\n",
       "TPU available: False, using: 0 TPU cores\n",
       "IPU available: False, using: 0 IPUs\n",
-      "HPU available: False, using: 0 HPUs\n",
-      "/Users/majajablonska/miniconda/envs/combo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
-      "  warning_cache.warn(\n"
+      "HPU available: False, using: 0 HPUs\n"
      ]
     }
    ],
    "source": [
+    "import pytorch_lightning as pl\n",
+    "import torch\n",
+    "from training.trainable_combo import TrainableCombo\n",
+    "\n",
     "nlp = TrainableCombo(model, torch.optim.Adam,\n",
     "                     optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},\n",
     "                     validation_metrics=['EM'])\n",
@@ -481,15 +553,15 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:16:45.785538Z",
-     "start_time": "2023-11-13T12:16:45.365250Z"
+     "end_time": "2023-11-19T08:11:17.009697Z",
+     "start_time": "2023-11-19T08:11:16.794518Z"
     }
    },
    "id": "cefc5173154d1605"
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 27,
    "outputs": [
     {
      "name": "stderr",
@@ -503,7 +575,7 @@
       "12.1 M    Trainable params\n",
       "124 M     Non-trainable params\n",
       "136 M     Total params\n",
-      "546.106   Total estimated model params size (MB)\n"
+      "546.099   Total estimated model params size (MB)\n"
      ]
     },
     {
@@ -512,29 +584,19 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "f2dd3228246843428b8fcb8ae932c1f1"
+       "model_id": "bbbe0d3e735b4fd39ab8781bc53aa29f"
       }
      },
      "metadata": {},
      "output_type": "display_data"
     },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/majajablonska/miniconda/envs/combo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
-      "  warning_cache.warn(\n",
-      "/Users/majajablonska/miniconda/envs/combo/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
-      "  rank_zero_warn(\n"
-     ]
-    },
     {
      "data": {
       "text/plain": "Training: 0it [00:00, ?it/s]",
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "0bcdd388df664784ba19667c6a0593a1"
+       "model_id": "ae7d24da413145839bfe44dd07da1917"
       }
      },
      "metadata": {},
@@ -546,7 +608,7 @@
       "application/vnd.jupyter.widget-view+json": {
        "version_major": 2,
        "version_minor": 0,
-       "model_id": "a8203342bf454c22b292548d64f085a9"
+       "model_id": "ce585ba3c2614e889d385228037ca757"
       }
      },
      "metadata": {},
@@ -566,60 +628,62 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:17:47.659618Z",
-     "start_time": "2023-11-13T12:16:45.706948Z"
+     "end_time": "2023-11-19T08:11:55.217182Z",
+     "start_time": "2023-11-19T08:11:19.767347Z"
     }
    },
    "id": "e5af131bae4b1a33"
   },
   {
-   "cell_type": "code",
-   "execution_count": 12,
-   "outputs": [],
+   "cell_type": "markdown",
    "source": [
-    "predictor = COMBO(model, dataset_reader)"
+    "# Prediction\n",
+    "\n",
+    "The ```COMBO``` class combines a data parser (```DatasetReader```) and the trained model."
    ],
    "metadata": {
-    "collapsed": false,
-    "ExecuteTime": {
-     "end_time": "2023-11-13T12:17:47.975345Z",
-     "start_time": "2023-11-13T12:17:47.644327Z"
-    }
+    "collapsed": false
    },
-   "id": "3e23413c86063183"
+   "id": "dc11e1be6b4dce0c"
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 28,
    "outputs": [],
    "source": [
-    "a = predictor(\"Cześć, jestem psem.\")"
+    "from predict import COMBO\n",
+    "\n",
+    "predictor = COMBO(model, dataset_reader)"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:17:47.989681Z",
-     "start_time": "2023-11-13T12:17:47.665490Z"
+     "end_time": "2023-11-19T08:12:07.916052Z",
+     "start_time": "2023-11-19T08:12:07.851797Z"
     }
    },
-   "id": "d555d7f0223a624b"
+   "id": "3e23413c86063183"
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 30,
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "TOKEN           LEMMA           UPOS       HEAD       DEPREL    \n",
-      "Cześć,          ??????          NOUN                0 root      \n",
-      "jestem          ?????a          NOUN                1 punct     \n",
-      "psem.           ?????           NOUN                1 punct     \n"
+      "Cześć           ?????           NOUN                5 punct     \n",
+      ",               ???             NOUN                0 root      \n",
+      "jestem          ??????          NOUN                2 punct     \n",
+      "psem            ????            NOUN                3 punct     \n",
+      ".               ???             NOUN                4 punct     \n"
      ]
     }
    ],
    "source": [
+    "a = predictor(\"Cześć, jestem psem.\")\n",
+    "\n",
     "print(\"{:15} {:15} {:10} {:10} {:10}\".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))\n",
     "for token in a.tokens:\n",
     "    print(\"{:15} {:15} {:10} {:10} {:10}\".format(token.text, token.lemma, token.upostag, token.head, token.deprel))"
@@ -627,15 +691,29 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:17:48.005229Z",
-     "start_time": "2023-11-13T12:17:47.923055Z"
+     "end_time": "2023-11-19T08:12:10.492815Z",
+     "start_time": "2023-11-19T08:12:10.442528Z"
     }
    },
    "id": "a68cd3861e1ceb67"
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    " # Model archivisation\n",
+    " \n",
+    "The model can be archived and used for further training and prediction.\n",
+    "\n",
+    "Dataset Readers and Data Loaders don't have to be archived with the model, but this can greatly simplify using the model later and is therefore recommended."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "21fe424c920956d5"
+  },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 31,
    "outputs": [],
    "source": [
     "from modules.archival import archive"
@@ -643,8 +721,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:17:48.008545Z",
-     "start_time": "2023-11-13T12:17:47.928808Z"
+     "end_time": "2023-11-19T08:12:13.159955Z",
+     "start_time": "2023-11-19T08:12:13.129800Z"
     }
    },
    "id": "d0f43f4493218b5"
@@ -663,32 +741,16 @@
     }
    ],
    "source": [
-    "archive(model, '/Users/majajablonska/Documents/combo', data_loader, val_data_loader, dataset_reader)"
+    "archive(model, SERIALIZATION_DIR, data_loader, val_data_loader, dataset_reader)"
    ],
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2023-11-13T12:19:17.944519Z",
-     "start_time": "2023-11-13T12:17:47.965095Z"
+     "end_time": "2023-11-19T05:03:47.155352Z",
+     "start_time": "2023-11-19T05:02:17.972017Z"
     }
    },
    "id": "ec92aa5bb5bb3605"
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "outputs": [],
-   "source": [
-    "\n"
-   ],
-   "metadata": {
-    "collapsed": false,
-    "ExecuteTime": {
-     "end_time": "2023-11-13T12:19:17.954324Z",
-     "start_time": "2023-11-13T12:19:17.920401Z"
-    }
-   },
-   "id": "5ad8a827586f65e3"
   }
  ],
  "metadata": {
diff --git a/tests/data/tokenizers/test_spacy_tokenizer.py b/tests/data/tokenizers/test_spacy_tokenizer.py
index 67aaead..e0e1c6e 100644
--- a/tests/data/tokenizers/test_spacy_tokenizer.py
+++ b/tests/data/tokenizers/test_spacy_tokenizer.py
@@ -17,9 +17,3 @@ class SpacyTokenizerTest(unittest.TestCase):
         tokens = self.spacy_tokenizer.tokenize('')
         self.assertEqual(len(tokens), 0)
 
-    # def test_batch_tokenize_sentence(self):
-    #     tokens = self.spacy_tokenizer.batch_tokenize(['First sentence!', 'This is the second sentence.'])
-    #     self.assertListEqual([t.text for t in tokens[0]],
-    #                          ['First', 'sentence', '!'])
-    #     self.assertListEqual([t.text for t in tokens[1]],
-    #                          ['This', 'is', 'the', 'second', 'sentence', '.'])
-- 
GitLab