diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..5edff14362062be0adccc65474067c1f43a7d791
--- /dev/null
+++ b/combo/config.multitask.template.jsonnet
@@ -0,0 +1,404 @@
+# Configuration file for jointly training a model using CoNLL-U and IOB.
+local shared_config = import "config.shared.libsonnet";
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CoNLL-U format (or its extended version with a semantic relation field).
+# Accepts multiple paths concatenated with ',', e.g. "path1,path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Accepts multiple paths concatenated with ',', e.g. "path1,path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to a pretrained token embeddings file, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
+# Cuda device id, -1 for cpu, int
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
+# Minimum number of words in batch, int
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Word embedding dimension, int
+# If pretrained_tokens is not null, this must match their dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the predictors on top of the encoder use this dropout rate
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 32;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 32;
+# Feats embedding dimension, int
+# (discarded if feats not in features)
+local feats_dim = 32;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, float]
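+# (each target's partial loss is scaled by its weight before summation)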
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+    semrel: 0.05,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer ones are truncated
+local word_length = 30;
+# Whether to use tensorboard, bool
+local use_tensorboard = std.extVar("use_tensorboard") == "True";
+
+# Helper functions
+local in_features(name) = std.length(std.find(name, features)) > 0;
+local in_targets(name) = std.length(std.find(name, targets)) > 0;
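+# e.g. in_features("lemma") is true iff "lemma" occurs in the `features` extVar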
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
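+
+# Example: render this template with the jsonnet CLI (all values below are
+# placeholders; any variable may be left empty where null is allowed):
+#   jsonnet --ext-str training_data_path=train.conllu \
+#           --ext-str validation_data_path= \
+#           --ext-str pretrained_tokens= \
+#           --ext-str pretrained_transformer_name= \
+#           --ext-str num_epochs=400 --ext-str cuda_device=-1 \
+#           --ext-str word_batch_size=2500 --ext-str embedding_dim=300 \
+#           --ext-str features="token char" --ext-str targets="deprel head" \
+#           --ext-str use_tensorboard=False --ext-str type=default \
+#           combo/config.multitask.template.jsonnet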
+
+########################################################################################
+#                              ADVANCED configuration                                  #
+########################################################################################
+
+# Detailed dataset, training, vocabulary and model configuration.
+{
+    # Configuration type (default or finetuning), str
+    type: std.extVar('type'),
+    # Datasets used for vocab creation, list of str
+    # Choice "train", "valid"
+    datasets_for_vocab_creation: ['train'],
+    # Paths to training data per task, dict[str, str]
+    train_data_path: {
+        conllu: training_data_path,
+        # TODO Add configuration
+        iob: "./data/nkjp-nested-simplified-v2.fixed.iob",
+    },
+    # Path to validation data, str
+    validation_data_path: validation_data_path,
+    # Dataset reader configuration (conllu format)
+    dataset_reader: {
+        type: "multitask",
+        readers: {
+            "conllu": {
+                type: "conllu",
+                features: features,
+                targets: targets,
+                # Whether data contains semantic relation field, bool
+                use_sem: in_targets("semrel"),
+                token_indexers: {
+                    token: if use_transformer then {
+                        type: "pretrained_transformer_mismatched",
+                        model_name: pretrained_transformer_name,
+                    } else {
+                        # SingleIdTokenIndexer, token as single int
+                        type: "single_id",
+                    },
+                    upostag: {
+                        type: "single_id",
+                        namespace: "upostag",
+                        feature_name: "pos_",
+                    },
+                    xpostag: {
+                        type: "single_id",
+                        namespace: "xpostag",
+                        feature_name: "tag_",
+                    },
+                    lemma: {
+                        type: "characters_const_padding",
+                        character_tokenizer: {
+                            start_tokens: ["__START__"],
+                            end_tokens: ["__END__"],
+                        },
+                        # +2 for start and end token
+                        min_padding_length: word_length + 2,
+                    },
+                    char: {
+                        type: "characters_const_padding",
+                        character_tokenizer: {
+                            start_tokens: ["__START__"],
+                            end_tokens: ["__END__"],
+                        },
+                        # +2 for start and end token
+                        min_padding_length: word_length + 2,
+                    },
+                    feats: {
+                        type: "feats_indexer",
+                    },
+                },
+                lemma_indexers: {
+                    char: {
+                        type: "characters_const_padding",
+                        namespace: "lemma_characters",
+                        character_tokenizer: {
+                            start_tokens: ["__START__"],
+                            end_tokens: ["__END__"],
+                        },
+                        # +2 for start and end token
+                        min_padding_length: word_length + 2,
+                    },
+                },
+            },
+            iob: {
+                type: "iob",
+                tag_label: "ner",
+                label_namespace: "ner_labels",
+                token_indexers: {
+                    token: if use_transformer then {
+                        type: "pretrained_transformer_mismatched",
+                        model_name: pretrained_transformer_name,
+                    } else {
+                        # SingleIdTokenIndexer, token as single int
+                        type: "single_id",
+                    },
+                    char: {
+                        type: "characters_const_padding",
+                        character_tokenizer: {
+                            start_tokens: ["__START__"],
+                            end_tokens: ["__END__"],
+                        },
+                        # +2 for start and end token
+                        min_padding_length: word_length + 2,
+                    },
+                },
+            },
+        },
+    },
+    # Data loader configuration
+    data_loader: {
+        type: "multitask",
+        scheduler: {
+            batch_size: 10
+        },
+        shuffle: true,
+//        batch_sampler: {
+//            type: "token_count",
+//            word_batch_size: word_batch_size,
+//        },
+    },
+    # Vocabulary configuration
+    vocabulary: std.prune({
+        type: "from_instances_extended",
+        only_include_pretrained_words: true,
+        pretrained_files: {
+            tokens: pretrained_tokens,
+        },
+        oov_token: "_",
+        padding_token: "__PAD__",
+        non_padded_namespaces: ["head_labels"],
+    }),
+    model: std.prune({
+        type: "multitask_extended",
+        backbone: {
+            type: "combo_backbone",
+            text_field_embedder: {
+                type: "basic",
+                token_embedders: {
+                    xpostag: if in_features("xpostag") then {
+                        type: "embedding",
+                        padding_index: 0,
+                        embedding_dim: xpostag_dim,
+                        vocab_namespace: "xpostag",
+                    },
+                    upostag: if in_features("upostag") then {
+                        type: "embedding",
+                        padding_index: 0,
+                        embedding_dim: upostag_dim,
+                        vocab_namespace: "upostag",
+                    },
+                    token: if use_transformer then {
+                        type: "transformers_word_embeddings",
+                        model_name: pretrained_transformer_name,
+                        projection_dim: projected_embedding_dim,
+                    } else {
+                        type: "embeddings_projected",
+                        embedding_dim: embedding_dim,
+                        projection_layer: {
+                            in_features: embedding_dim,
+                            out_features: projected_embedding_dim,
+                            dropout_rate: 0.25,
+                            activation: "tanh"
+                        },
+                        vocab_namespace: "tokens",
+                        pretrained_file: pretrained_tokens,
+                        trainable: pretrained_tokens == null,
+                    },
+                    char: {
+                        type: "char_embeddings_from_config",
+                        embedding_dim: char_dim,
+                        dilated_cnn_encoder: {
+                            input_dim: char_dim,
+                            filters: [512, 256, char_dim],
+                            kernel_size: [3, 3, 3],
+                            stride: [1, 1, 1],
+                            padding: [1, 2, 4],
+                            dilation: [1, 2, 4],
+                            activations: ["relu", "relu", "linear"],
+                        },
+                    },
+                    lemma: if in_features("lemma") then {
+                        type: "char_embeddings_from_config",
+                        embedding_dim: lemma_char_dim,
+                        dilated_cnn_encoder: {
+                            input_dim: lemma_char_dim,
+                            filters: [512, 256, lemma_char_dim],
+                            kernel_size: [3, 3, 3],
+                            stride: [1, 1, 1],
+                            padding: [1, 2, 4],
+                            dilation: [1, 2, 4],
+                            activations: ["relu", "relu", "linear"],
+                        },
+                    },
+                    feats: if in_features("feats") then {
+                        type: "feats_embedding",
+                        padding_index: 0,
+                        embedding_dim: feats_dim,
+                        vocab_namespace: "feats",
+                    },
+                },
+            },
+            seq_encoder: {
+                type: "combo_encoder",
+                layer_dropout_probability: 0.33,
+                stacked_bilstm: {
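+                    # Encoder input is the concatenation of all active embeddings;
+                    # with only the required "token" and "char" features and the
+                    # defaults above this is 100 + 64 = 164.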
+                    input_size:
+                    (char_dim + projected_embedding_dim +
+                    (if in_features('xpostag') then xpostag_dim else 0) +
+                    (if in_features('lemma') then lemma_char_dim else 0) +
+                    (if in_features('upostag') then upostag_dim else 0) +
+                    (if in_features('feats') then feats_dim else 0)),
+                    hidden_size: hidden_size,
+                    num_layers: num_layers,
+                    recurrent_dropout_probability: 0.33,
+                    layer_dropout_probability: 0.33
+                },
+            }
+        },
+        heads: {
+            conllu: {
+                type: "semantic_multitask_head",
+                loss_weights: loss_weights,
+                dependency_relation: {
+                    type: "combo_dependency_parsing_from_vocab",
+                    vocab_namespace: 'deprel_labels',
+                    head_predictor: {
+                        local projection_dim = 512,
+                        cycle_loss_n: cycle_loss_n,
+                        head_projection_layer: {
+                            in_features: hidden_size * 2,
+                            out_features: projection_dim,
+                            activation: "tanh",
+                        },
+                        dependency_projection_layer: {
+                            in_features: hidden_size * 2,
+                            out_features: projection_dim,
+                            activation: "tanh",
+                        },
+                    },
+                    local projection_dim = 128,
+                    head_projection_layer: {
+                        in_features: hidden_size * 2,
+                        out_features: projection_dim,
+                        dropout_rate: predictors_dropout,
+                        activation: "tanh"
+                    },
+                    dependency_projection_layer: {
+                        in_features: hidden_size * 2,
+                        out_features: projection_dim,
+                        dropout_rate: predictors_dropout,
+                        activation: "tanh"
+                    },
+                },
+                morphological_feat: if in_targets("feats") then {
+                    type: "combo_morpho_from_vocab",
+                    vocab_namespace: "feats_labels",
+                    input_dim: hidden_size * 2,
+                    hidden_dims: [128],
+                    activations: ["tanh", "linear"],
+                    dropout: [predictors_dropout, 0.0],
+                    num_layers: 2,
+                },
+                lemmatizer: if in_targets("lemma") then shared_config.Model.lemma(hidden_size, predictors_dropout),
+                upos_tagger: if in_targets("upostag") then {
+                    input_dim: hidden_size * 2,
+                    hidden_dims: [64],
+                    activations: ["tanh", "linear"],
+                    dropout: [predictors_dropout, 0.0],
+                    num_layers: 2,
+                    vocab_namespace: "upostag_labels"
+                },
+                xpos_tagger: if in_targets("xpostag") then {
+                    input_dim: hidden_size * 2,
+                    hidden_dims: [128],
+                    activations: ["tanh", "linear"],
+                    dropout: [predictors_dropout, 0.0],
+                    num_layers: 2,
+                    vocab_namespace: "xpostag_labels"
+                },
+                semantic_relation: if in_targets("semrel") then {
+                    input_dim: hidden_size * 2,
+                    hidden_dims: [64],
+                    activations: ["tanh", "linear"],
+                    dropout: [predictors_dropout, 0.0],
+                    num_layers: 2,
+                    vocab_namespace: "semrel_labels"
+                },
+                regularizer: {
+                    regexes: [
+                        [".*conv1d.*", {type: "l2", alpha: 1e-6}],
+                        [".*forward.*", {type: "l2", alpha: 1e-6}],
+                        [".*backward.*", {type: "l2", alpha: 1e-6}],
+                        [".*char_embed.*", {type: "l2", alpha: 1e-5}],
+                    ],
+                },
+            },
+            iob: {
+                type: "ner_head",
+                feedforward_predictor: {
+                    type: "feedforward_predictor_from_vocab",
+                    input_dim: hidden_size * 2,
+                    hidden_dims: [128],
+                    activations: ["tanh", "linear"],
+                    dropout: [predictors_dropout, 0.0],
+                    num_layers: 2,
+                    vocab_namespace: "ner_labels"
+                }
+            },
+        },
+    }),
+    trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
+    random_seed: 8787,
+    pytorch_seed: 8787,
+    numpy_seed: 8787,
+}
diff --git a/combo/config.shared.libsonnet b/combo/config.shared.libsonnet
index 23d3f9a8717dc35bf3a33bd112d13de6daab13dc..bb262f6a64cd55e00df567ef5484930e679dd9c9 100644
--- a/combo/config.shared.libsonnet
+++ b/combo/config.shared.libsonnet
@@ -25,5 +25,27 @@
             },
             validation_metric: "+EM",
         }),
-    Trainer: trainer
+
+    local lemma(hidden_size, dropout) = {
+        type: "combo_lemma_predictor_from_vocab",
+        char_vocab_namespace: "token_characters",
+        lemma_vocab_namespace: "lemma_characters",
+        embedding_dim: 256,
+        input_projection_layer: {
+            in_features: hidden_size * 2,
+            out_features: 32,
+            dropout_rate: dropout,
+            activation: "tanh"
+        },
+        filters: [256, 256, 256],
+        kernel_size: [3, 3, 3, 1],
+        stride: [1, 1, 1, 1],
+        padding: [1, 2, 4, 0],
+        dilation: [1, 2, 4, 1],
+        activations: ["relu", "relu", "relu", "linear"],
+    },
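+    # `lemma` is consumed from templates via `shared_config.Model.lemma(hidden_size, dropout)`.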
+    Trainer: trainer,
+    Model: {
+        lemma: lemma,
+    }
 }
\ No newline at end of file
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 870659fe27857d157e5061f77690f61f166330a7..6dc30b66315956de03495ad942c68ce0a3a33cf9 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import logging
 import pathlib
 from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
@@ -6,8 +7,9 @@ from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 import conllu
 import torch
 from allennlp import data as allen_data
-from allennlp.common import checks, util
+from allennlp.common import checks, util, file_utils
 from allennlp.data import fields as allen_fields, vocabulary
+from allennlp.data.dataset_readers import dataset_reader, conll2003
 from conllu import parser
 from dataclasses import dataclass
 from overrides import overrides
@@ -17,6 +19,36 @@ from combo.data import fields
 logger = logging.getLogger(__name__)
 
 
+@allen_data.DatasetReader.register("iob")
+class IOBDatasetReader(conll2003.Conll2003DatasetReader):
+    """Extension of the AllenNLP Conll2003DatasetReader with tab as a separator."""
+
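+    # Expected input: four tab-separated columns per token
+    # (token, POS tag, chunk tag, NER tag), with a blank line between
+    # sentences. Illustrative row (tags are placeholders):
+    #   Jan<TAB>subst<TAB>B-NP<TAB>B-persName
+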
+    def _read(self, file_path: dataset_reader.PathOrStr) -> Iterable[allen_data.Instance]:
+        # if `file_path` is a URL, redirect to the cache
+        file_path = file_utils.cached_path(file_path)
+
+        with open(file_path, "r") as data_file:
+            logger.info("Reading instances from lines in file at: %s", file_path)
+
+            # Group lines into sentence chunks based on the divider.
+            line_chunks = (
+                lines
+                for is_divider, lines in itertools.groupby(data_file, conll2003._is_divider)
+                # Ignore the divider chunks, so that `lines` corresponds to the words
+                # of a single sentence.
+                if not is_divider
+            )
+            for lines in self.shard_iterable(line_chunks):
+                fields = [line.strip().split("\t") for line in lines]
+                # unzipping trick returns tuples, but our Fields need lists
+                fields = [list(field) for field in zip(*fields)]
+                tokens_, pos_tags, chunk_tags, ner_tags = fields
+                # TextField requires `Token` objects
+                tokens = [allen_data.Token(token) for token in tokens_]
+
+                yield self.text_to_instance(tokens, pos_tags, chunk_tags, ner_tags)
+
+
 @allen_data.DatasetReader.register("conllu")
 class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
 
@@ -99,7 +131,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
 
         # features
         text_field = allen_fields.TextField(tokens, self._token_indexers)
-        fields_["sentence"] = text_field
+        fields_["tokens"] = text_field
 
         # targets
         if self.generate_labels:
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index a1754985f115d95995f48a2f99d5a4485d5faa40..2dbc4221cf05902274767b6e3337ee0a90856099 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -18,7 +18,7 @@ class TokenCountBatchSampler(allen_data.BatchSampler):
         batches = []
         batch = []
         words_count = 0
-        lengths = [len(instance.fields["sentence"].tokens) for instance in dataset]
+        lengths = [len(instance.fields["tokens"].tokens) for instance in dataset]
         argsorted_lengths = np.argsort(lengths)
         for idx in argsorted_lengths:
             words_count += lengths[idx]
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index ec7a1380e1cfc80b0302806e46cca4e5fc2d3568..04079101024eb7f7031fac63c37cceba2f322388 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -7,3 +7,4 @@ from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
 from .model import ComboModel
 from .morpho import MorphologicalFeatures
+from .multitask import MultiTaskModel
diff --git a/combo/models/base.py b/combo/models/base.py
index a5cb5fe61f85a98f78d143a54695d01948aa8dda..786ef245e1f7bb4860d94288de3dcb49f2b1afb1 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -106,7 +106,8 @@ class FeedForwardPredictor(Predictor):
             )
 
         assert vocab_namespace in vocab.get_namespaces(), \
-            f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
+            f"There is not {vocab_namespace} in created vocabs ({','.join(vocab.get_namespaces())}), " \
+            f"check if this field has any values to predict!"
         hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
 
         return cls(feedforward.FeedForward(
diff --git a/combo/models/model.py b/combo/models/model.py
index 9866bcb4fba41ed2506b2d33290e6cd0fe237d29..4c9d1bee43c6460c2ed636c3fa015ddeeafecce1 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -1,25 +1,79 @@
 """Main COMBO model."""
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, Union
 
 import torch
-from allennlp import data, modules, models as allen_models, nn as allen_nn
+from allennlp import data, modules, nn as allen_nn
+from allennlp.models import heads
 from allennlp.modules import text_field_embedders
 from allennlp.nn import util
+from allennlp.training import metrics as allen_metrics
 from overrides import overrides
 
 from combo.models import base
 from combo.utils import metrics
 
 
-@allen_models.Model.register("semantic_multitask")
-class ComboModel(allen_models.Model):
+@modules.Backbone.register("combo_backbone")
+class ComboBackbone(modules.Backbone):
+
+    def __init__(self, text_field_embedder: text_field_embedders.TextFieldEmbedder,
+                 seq_encoder: modules.Seq2SeqEncoder):
+        super().__init__()
+        self.text_field_embedder = text_field_embedder
+        self.seq_encoder = seq_encoder
+
+    def forward(self, tokens: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
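+        # Embed the (multi-source) token representation and run the encoder
+        # once per batch; every head then consumes `encoder_emb` together
+        # with the word/char masks.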
+        char_mask = tokens["char"]["token_characters"].gt(0)
+        word_mask = util.get_text_field_mask(tokens)
+        encoder_input = self.text_field_embedder(tokens, char_mask=char_mask)
+        return dict(encoder_emb=self.seq_encoder(encoder_input, word_mask),
+                    word_mask=word_mask,
+                    char_mask=char_mask)
+
+
+@heads.Head.register("ner_head")
+class NERModel(heads.Head):
+
+    def __init__(self, feedforward_predictor: base.Predictor, vocab: data.Vocabulary):
+        super().__init__(vocab)
+        self.feedforward_predictor = feedforward_predictor
+        self._accuracy_metric = allen_metrics.CategoricalAccuracy()
+        # self._f1_metric = allen_metrics.SpanBasedF1Measure(vocab, tag_namespace="ner_labels", label_encoding="IOB1",
+        #                                                    ignore_classes=["_"])
+        self._loss = 0.0
+
+    def forward(self,
+                encoder_emb: Union[torch.Tensor, List[torch.Tensor]],
+                word_mask: Optional[torch.BoolTensor] = None,
+                tags: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
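+        # The shared FeedForwardPredictor returns "probability" scores and,
+        # when gold tags are provided, a "loss" entry.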
+        output = self.feedforward_predictor(
+            x=encoder_emb,
+            mask=word_mask,
+            labels=tags,
+        )
+
+        if tags is not None:
+            self._loss = output["loss"]
+            self._accuracy_metric(output["probability"], tags, word_mask)
+            # self._f1_metric(output["probability"], tags, word_mask)
+
+        return output
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        return {
+            **{"accuracy": self._accuracy_metric.get_metric(reset), "loss": self._loss},
+            # **self._f1_metric.get_metric(reset)
+        }
+
+
+@heads.Head.register("semantic_multitask_head")
+class ComboModel(heads.Head):
     """Main COMBO model."""
 
     def __init__(self,
                  vocab: data.Vocabulary,
                  loss_weights: Dict[str, float],
-                 text_field_embedder: text_field_embedders.TextFieldEmbedder,
-                 seq_encoder: modules.Seq2SeqEncoder,
                  use_sample_weight: bool = True,
                  lemmatizer: Optional[base.Predictor] = None,
                  upos_tagger: Optional[base.Predictor] = None,
@@ -30,10 +84,8 @@ class ComboModel(allen_models.Model):
                  enhanced_dependency_relation: Optional[base.Predictor] = None,
                  regularizer: allen_nn.RegularizerApplicator = None) -> None:
         super().__init__(vocab, regularizer)
-        self.text_field_embedder = text_field_embedder
         self.loss_weights = loss_weights
         self.use_sample_weight = use_sample_weight
-        self.seq_encoder = seq_encoder
         self.lemmatizer = lemmatizer
         self.upos_tagger = upos_tagger
         self.xpos_tagger = xpos_tagger
@@ -41,13 +93,16 @@ class ComboModel(allen_models.Model):
         self.morphological_feat = morphological_feat
         self.dependency_relation = dependency_relation
         self.enhanced_dependency_relation = enhanced_dependency_relation
-        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
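+        # NOTE: assumed equal to the backbone encoder output dim
+        # (hidden_size * 2 = 1024 with the template defaults); the encoder
+        # now lives in the backbone, so its output dim cannot be queried here.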
+        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, 1024]))
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
 
     @overrides
     def forward(self,
-                sentence: Dict[str, Dict[str, torch.Tensor]],
+                tokens: Dict[str, Dict[str, torch.Tensor]],
+                encoder_emb: torch.Tensor,
+                word_mask: torch.Tensor,
+                char_mask: torch.Tensor,
                 metadata: List[Dict[str, Any]],
                 upostag: torch.Tensor = None,
                 xpostag: torch.Tensor = None,
@@ -58,19 +113,11 @@ class ComboModel(allen_models.Model):
                 semrel: torch.Tensor = None,
                 enhanced_heads: torch.Tensor = None,
                 enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
-
-        # Prepare masks
-        char_mask = sentence["char"]["token_characters"].gt(0)
-        word_mask = util.get_text_field_mask(sentence)
-
-        device = word_mask.device
+        device = encoder_emb.device
 
         # If enabled weight samples loss by log(sentence_length)
         sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
 
-        encoder_input = self.text_field_embedder(sentence, char_mask=char_mask)
-        encoder_emb = self.seq_encoder(encoder_input, word_mask)
-
         batch_size, _, encoding_dim = encoder_emb.size()
 
         # Concatenate the head sentinel (ROOT) onto the sentence representation.
@@ -99,8 +146,8 @@ class ComboModel(allen_models.Model):
                                        labels=feats,
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
-                                      (encoder_emb, sentence.get("char").get("token_characters")
-                                      if sentence.get("char") else None),
+                                      (encoder_emb, tokens.get("char").get("token_characters")
+                                      if tokens.get("char") else None),
                                       mask=word_mask,
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
diff --git a/combo/models/multitask.py b/combo/models/multitask.py
new file mode 100644
index 0000000000000000000000000000000000000000..d557591903baec157b6ebf3c6e184c01ba30daed
--- /dev/null
+++ b/combo/models/multitask.py
@@ -0,0 +1,66 @@
+import collections
+from typing import Mapping, List, Dict, Union
+
+import torch
+from allennlp import models
+
+
+@models.Model.register("multitask_extended")
+class MultiTaskModel(models.MultiTaskModel):
+    """Extension of the AllenNLP MultiTaskModel to handle dictionary inputs."""
+
+    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore
+        if "task" not in kwargs:
+            raise ValueError(
+                "Instances for multitask training need to contain a MetadataField with "
+                "the name 'task' to indicate which task they belong to. Usually the "
+                "MultitaskDataLoader provides this field and you don't have to do anything."
+            )
+
+        task_indices_just_for_mypy: Mapping[str, List[int]] = collections.defaultdict(list)
+        for i, task in enumerate(kwargs["task"]):
+            task_indices_just_for_mypy[task].append(i)
+        task_indices: Dict[str, torch.LongTensor] = {
+            task: torch.LongTensor(indices) for task, indices in task_indices_just_for_mypy.items()
+        }
+
+        def make_inputs_for_task(task: str, whole_batch_input: Union[torch.Tensor, List, Dict]):
+            if isinstance(whole_batch_input, torch.Tensor):
+                task_indices[task] = task_indices[task].to(whole_batch_input.device)
+                return torch.index_select(whole_batch_input, 0, task_indices[task])
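+            # Extension over the stock AllenNLP implementation: recurse into
+            # dict inputs (e.g. TextField tensors keyed by indexer name).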
+            if isinstance(whole_batch_input, dict):
+                return {k: make_inputs_for_task(task, v) for k, v in whole_batch_input.items()}
+            else:
+                return [whole_batch_input[i] for i in task_indices[task]]
+
+        backbone_arguments = self._get_arguments(kwargs, "backbone")
+        backbone_outputs = self._backbone(**backbone_arguments)
+        combined_arguments = {**backbone_outputs, **kwargs}
+
+        outputs = {**backbone_outputs}
+        loss = None
+        for head_name in self._heads:
+            if head_name not in task_indices:
+                continue
+
+            head_arguments = self._get_arguments(combined_arguments, head_name)
+            head_arguments = {
+                key: make_inputs_for_task(head_name, value) for key, value in head_arguments.items()
+            }
+
+            head_outputs = self._heads[head_name](**head_arguments)
+            for key in head_outputs:
+                outputs[f"{head_name}_{key}"] = head_outputs[key]
+
+            if "loss" in head_outputs:
+                self._heads_called.add(head_name)
+                head_loss = self._loss_weights[head_name] * head_outputs["loss"]
+                if loss is None:
+                    loss = head_loss
+                else:
+                    loss += head_loss
+
+        if loss is not None:
+            outputs["loss"] = loss
+
+        return outputs
diff --git a/scripts/fix_iob.py b/scripts/fix_iob.py
new file mode 100644
index 0000000000000000000000000000000000000000..976e70472feec9fdb409402122f063b4f42039f2
--- /dev/null
+++ b/scripts/fix_iob.py
@@ -0,0 +1,35 @@
+"""Script which fixes -DOCSTART- misspellings."""
+import pathlib
+
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="input_path", default=None,
+                    help="Path to IOB file.")
+flags.DEFINE_string(name="output_path", default=None,
+                    help="Path to store fixed IOB file.")
+flags.mark_flag_as_required("input_path")
+flags.mark_flag_as_required("output_path")
+
+
+def run(_):
+    input_path = pathlib.Path(FLAGS.input_path)
+    assert input_path.exists() and input_path.is_file(), "Input doesn't exist or is not a file."
+    with input_path.open("r") as input_fh:
+        with pathlib.Path(FLAGS.output_path).open("w") as output_fh:
+            for line in input_fh:
+
+                # Replace -DOCSTART with -DOCSTART-
+                if "-DOCSTART" in line and "-DOCSTART-" not in line:
+                    line = line.replace("-DOCSTART", "-DOCSTART-")
+
+                output_fh.write(line)
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/data/fields/test_samplers.py b/tests/data/fields/test_samplers.py
index f13a26449c17da7810b811352afb930615020bcd..6a228b2a0892b67ebc6b5369899bdd291d069882 100644
--- a/tests/data/fields/test_samplers.py
+++ b/tests/data/fields/test_samplers.py
@@ -16,7 +16,7 @@ class TokenCountBatchSamplerTest(unittest.TestCase):
             tokens = [data.Token(t)
                       for t in sentence.split()]
             text_field = fields.TextField(tokens, {})
-            self.dataset.append(data.Instance({"sentence": text_field}))
+            self.dataset.append(data.Instance({"tokens": text_field}))
 
     def test_batches(self):
         # given