diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index 77254f6d3dd9ce4828c7ef54d947f17bf5c46d8d..d724a1ecb9c22610fc6ac56493929178d7a6cd5a 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -1,5 +1,107 @@
-from combo.models.base import Predictor
+from typing import Optional, Dict, List, Union
+
+import torch
+import torch.nn as nn
+
+from combo import data
+from combo.models import dilated_cnn, base, utils
+from combo.models.base import Predictor, TimeDistributed
+from combo.models.combo_nn import Activation
+from combo.utils import ConfigurationError
 
 
 class LemmatizerModel(Predictor):
-    pass
+    """Lemmatizer model."""
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
+                 input_projection_layer: base.Linear):
+        super().__init__()
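+        # Character embedding table; index 0 is assumed to be the padding
+        # character (see the gt(0) mask in _loss below).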
+        self.char_embed = nn.Embedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+        )
+        self.dilated_cnn_encoder = TimeDistributed(dilated_cnn_encoder)
+        self.input_projection_layer = input_projection_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
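+        # x is a pair: word-level encoder output [batch, sentence, features]
+        # and character indices [batch, sentence, word_length].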
+        encoder_emb, chars = x
+
+        encoder_emb = self.input_projection_layer(encoder_emb)
+        char_embeddings = self.char_embed(chars)
+
+        BATCH_SIZE, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
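+        # Broadcast each word's contextualised embedding over its character
+        # positions: [B, S, H] -> [B, S, 1, H] -> [B, S, MAX_WORD_LENGTH, H].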
+        encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)
+
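+        # The CNN runs over character positions, so move features to the
+        # channel axis before encoding and back afterwards; the encoder's
+        # last layer emits one channel per lemma character class.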
+        x = torch.cat((char_embeddings, encoder_emb), dim=-1).transpose(2, 3)
+        x = self.dilated_cnn_encoder(x).transpose(2, 3)
+        output = {
+            "prediction": x.argmax(-1),
+            # Per-character scores straight from the CNN; no softmax is
+            # applied here.
+            "probability": x
+        }
+
+        if labels is not None:
+            if mask is None:
+                # Default word mask: treat every sentence position as valid.
+                mask = encoder_emb.new_ones(encoder_emb.size()[:-2])
+            if sample_weights is None:
+                sample_weights = labels.new_ones(BATCH_SIZE)
+            # Expand the word-level mask over character positions.
+            mask = mask.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH).bool()
+            output["loss"] = self._loss(x, labels, mask, sample_weights)
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        BATCH_SIZE, _, _, CHAR_CLASSES = pred.size()
+        pred = pred.reshape(-1, CHAR_CLASSES)
+
+        true = true.reshape(-1)
+        # Index 0 is assumed to be the padding character, so this mask also
+        # covers fully padded words and effectively replaces the word-level
+        # mask passed in from forward().
+        mask = true.gt(0)
+        loss = utils.masked_cross_entropy(pred, true, mask)
+        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        # Average over valid (non-padding) character positions only.
+        valid_positions = mask.sum()
+        return loss.sum() / valid_positions
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   char_vocab_namespace: str,
+                   lemma_vocab_namespace: str,
+                   embedding_dim: int,
+                   input_projection_layer: base.Linear,
+                   filters: List[int],
+                   kernel_size: List[int],
+                   stride: List[int],
+                   padding: List[int],
+                   dilation: List[int],
+                   activations: List[Activation],
+                   ):
+        assert char_vocab_namespace in vocab.get_namespaces()
+        assert lemma_vocab_namespace in vocab.get_namespaces()
+
+        if len(filters) + 1 != len(kernel_size):
+            raise ConfigurationError(
+                f"len(filters) + 1 ({len(filters) + 1:d}) != len(kernel_size) ({len(kernel_size):d})"
+            )
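+        # The final convolution must score every character position against
+        # the lemma character vocabulary, so its size is appended as the last
+        # filter count.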
+        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]
+
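+        # Each CNN input position is a character embedding concatenated with
+        # the projected word embedding, hence the summed input dimension.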
+        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
+            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
+            filters=filters,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            activations=activations,
+        )
+        return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
+                   embedding_dim=embedding_dim,
+                   dilated_cnn_encoder=dilated_cnn_encoder,
+                   input_projection_layer=input_projection_layer)
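+
+
+# A minimal usage sketch (illustrative only; the namespace names and layer
+# hyper-parameters below are assumptions, not values taken from this
+# repository):
+#
+#   model = LemmatizerModel.from_vocab(
+#       vocab=vocab,
+#       char_vocab_namespace="token_characters",
+#       lemma_vocab_namespace="lemma_characters",
+#       embedding_dim=256,
+#       input_projection_layer=projection,  # a base.Linear instance
+#       filters=[256, 256],                 # len(filters) + 1 == len(kernel_size)
+#       kernel_size=[3, 3, 3],
+#       stride=[1, 1, 1],
+#       padding=[1, 1, 1],
+#       dilation=[1, 1, 1],
+#       activations=activations,            # one Activation per convolution
+#   )
+#   output = model(x=[encoder_emb, chars], mask=word_mask, labels=gold_chars)
+#   loss = output["loss"]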