diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet
index a977819eadb339d94f2f59f2421c103c6d82667e..d7699796e2b9f154b8b02e050845cad5ba97e83a 100644
--- a/combo/config.multitask.template.jsonnet
+++ b/combo/config.multitask.template.jsonnet
@@ -199,7 +199,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     data_loader: {
         type: "multitask",
         scheduler: {
-            batch_size: 20
+            batch_size: 100
         },
         shuffle: true,
 //        batch_sampler: {
@@ -384,16 +384,13 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 },
             },
             iob: {
-                type: "ner_head",
-                feedforward_predictor: {
-                    type: "feedforward_predictor_from_vocab",
-                    input_dim: hidden_size * 2,
-                    hidden_dims: [128],
-                    activations: ["tanh", "linear"],
-                    dropout: [predictors_dropout, 0.0],
-                    num_layers: 2,
-                    vocab_namespace: "ner_labels"
-                }
+                type: "ner_head_from_vocab",  // Built through NERModel.from_vocab (combo/models/model.py).
+                input_dim: hidden_size * 2,
+                hidden_dims: [128],  // from_vocab appends the output layer size taken from the vocabulary.
+                activations: ["tanh", "linear"],
+                dropout: [predictors_dropout, 0.0],
+                num_layers: 2,  // from_vocab requires len(hidden_dims) + 1 == num_layers.
+                vocab_namespace: "ner_labels"  // Also passed on as the head's label_namespace (CRF and metrics).
             },
         },
     }),
diff --git a/combo/models/model.py b/combo/models/model.py
index dd3629ad803fdb79662a7346340248df85230c1e..a3db919d6f7cc2fd93942ccc4a8df6eb3503bc85 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -1,10 +1,11 @@
 """Main COMBO model."""
-from typing import Optional, Dict, Any, List, Union
+from typing import Optional, Dict, Any, List, Union, cast
 
 import torch
 from allennlp import data, modules, nn as allen_nn
+from allennlp.common import checks
 from allennlp.models import heads
-from allennlp.modules import text_field_embedders
+from allennlp.modules import text_field_embedders, conditional_random_field as crf
 from allennlp.nn import util
 from allennlp.training import metrics as allen_metrics
 from overrides import overrides
@@ -31,33 +32,86 @@ class ComboBackbone(modules.Backbone):
                     char_mask=char_mask)
 
 
-@heads.Head.register("ner_head")
+@heads.Head.register("ner_head_from_vocab", constructor="from_vocab")
 class NERModel(heads.Head):
+    """Based on AllenNLP-models CrfTagger."""
 
-    def __init__(self, feedforward_predictor: base.Predictor, vocab: data.Vocabulary):
+    def __init__(self, feedforward_network: modules.FeedForward, vocab: data.Vocabulary,
+                 label_namespace: str = "ner_labels",  # Vocabulary namespace that holds the tag labels.
+                 include_start_end_transitions: bool = True,  # Forwarded unchanged to ConditionalRandomField.
+                 label_encoding: str = "IOB1"):  # Drives both the CRF transition constraints and the span F1 metric.
         super().__init__(vocab)
-        self.feedforward_predictor = feedforward_predictor
+        self._feedforward_network = feedforward_network  # Its output is used as the CRF logits in forward().
+
+        labels = self.vocab.get_index_to_token_vocabulary(label_namespace)  # Index -> label mapping of the namespace.
+        constraints = crf.allowed_transitions(label_encoding, labels)  # Transitions permitted under label_encoding.
+        self.include_start_end_transitions = include_start_end_transitions
+        self.num_tags = self.vocab.get_vocab_size(label_namespace)
+        self._crf = modules.ConditionalRandomField(  # CRF over num_tags labels, restricted by the constraints above.
+            self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions
+        )
+
         self._accuracy_metric = allen_metrics.CategoricalAccuracy()
-        self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1")
+        self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab,  # Same namespace/encoding as the CRF constraints.
+                                                           tag_namespace=label_namespace,
+                                                           label_encoding=label_encoding)
         self._loss = 0.0
 
     def forward(self,
                 encoder_emb: Union[torch.Tensor, List[torch.Tensor]],
                 word_mask: Optional[torch.BoolTensor] = None,
                 tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        output = self.feedforward_predictor(
-            x=encoder_emb,
-            mask=word_mask,
-            labels=tags,
-        )
+        batch_size = encoder_emb.size(0)  # NOTE(review): assumes a single tensor although the annotation also allows a list - confirm.
+        logits = self._feedforward_network(encoder_emb)  # Per-token scores handed to the CRF.
+        best_paths = self._crf.viterbi_tags(logits, word_mask, top_k=1)  # Viterbi decoding; top_k=1 asks for the single best path.
+        predicted_tags = cast(List[List[int]], [x[0][0] for x in best_paths])  # First element of each instance's top path, i.e. its tag ids.
 
+        output = {"tags": predicted_tags}  # NOTE(review): a list of int lists, not a tensor as the return annotation states.
         if tags is not None:
+            log_likelihood: torch.Tensor = self._crf(logits, tags, word_mask) / batch_size  # Mean instead of sum.
+            output["loss"] = -log_likelihood  # Negated so that the loss is minimised.
+            class_probabilities = logits * 0.0  # All-zero tensor shaped like logits; filled one-hot below.
+            for i, instance_tags in enumerate(predicted_tags):  # One-hot encode the decoded tags so the metrics score the CRF output.
+                for j, tag_id in enumerate(instance_tags):
+                    class_probabilities[i, j, tag_id] = 1
+
             self._loss = output["loss"].cpu().item()
-            self._accuracy_metric(output["probability"], tags, word_mask)
-            self._f1_metric(output["probability"], tags, word_mask)
+            self._accuracy_metric(class_probabilities, tags, word_mask)
+            self._f1_metric(class_probabilities, tags, word_mask)
 
         return output
 
+    @classmethod  # Alternate constructor; registered as the constructor for "ner_head_from_vocab".
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,  # Label namespace; must already be present in vocab.
+                   input_dim: int,
+                   num_layers: int,  # Counts the output layer that is appended to hidden_dims below.
+                   hidden_dims: List[int],
+                   activations: Union[allen_nn.Activation, List[allen_nn.Activation]],
+                   dropout: Union[float, List[float]] = 0.0,
+                   ):
+        if len(hidden_dims) + 1 != num_layers:  # hidden_dims lists every layer except the final output layer.
+            raise checks.ConfigurationError(
+                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
+            )
+
+        assert vocab_namespace in vocab.get_namespaces(), \
+            f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), " \
+            f"check if this field has any values to predict!"  # NOTE(review): assert is skipped under python -O.
+        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]  # New list (argument not mutated); last layer size = label count.
+        return cls(
+            feedforward_network=modules.FeedForward(
+                input_dim=input_dim,
+                num_layers=num_layers,
+                hidden_dims=hidden_dims,
+                activations=activations,
+                dropout=dropout,
+            ),
+            label_namespace=vocab_namespace,  # label_encoding and include_start_end_transitions keep their __init__ defaults.
+            vocab=vocab
+        )
+
     @overrides
     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
         metrics_ = {"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss}