diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000000000000000000000000000000000000..88d1c8dbb6d083104d27efd4a7833759b6bd52e0
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,588 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-whitelist=
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS,venv
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=data
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=pylint_quotes
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Specify a configuration file.
+#rcfile=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=print-statement,
+        parameter-unpacking,
+        unpacking-in-except,
+        old-raise-syntax,
+        backtick,
+        long-suffix,
+        old-ne-operator,
+        old-octal-literal,
+        import-star-module-level,
+        non-ascii-bytes-literal,
+        raw-checker-failed,
+        bad-inline-option,
+        locally-disabled,
+        file-ignored,
+        suppressed-message,
+        useless-suppression,
+        deprecated-pragma,
+        use-symbolic-message-instead,
+        apply-builtin,
+        basestring-builtin,
+        buffer-builtin,
+        cmp-builtin,
+        coerce-builtin,
+        execfile-builtin,
+        file-builtin,
+        long-builtin,
+        raw_input-builtin,
+        reduce-builtin,
+        standarderror-builtin,
+        unicode-builtin,
+        xrange-builtin,
+        coerce-method,
+        delslice-method,
+        getslice-method,
+        setslice-method,
+        no-absolute-import,
+        old-division,
+        dict-iter-method,
+        dict-view-method,
+        next-method-called,
+        metaclass-assignment,
+        indexing-exception,
+        raising-string,
+        reload-builtin,
+        oct-method,
+        hex-method,
+        nonzero-method,
+        cmp-method,
+        input-builtin,
+        round-builtin,
+        intern-builtin,
+        unichr-builtin,
+        map-builtin-not-iterating,
+        zip-builtin-not-iterating,
+        range-builtin-not-iterating,
+        filter-builtin-not-iterating,
+        using-cmp-argument,
+        eq-without-hash,
+        div-method,
+        idiv-method,
+        rdiv-method,
+        exception-message-attribute,
+        invalid-str-codec,
+        sys-max-int,
+        bad-python3-import,
+        deprecated-string-function,
+        deprecated-str-translate-call,
+        deprecated-itertools-function,
+        deprecated-types-field,
+        next-method-defined,
+        dict-items-not-iterating,
+        dict-keys-not-iterating,
+        dict-values-not-iterating,
+        deprecated-operator-function,
+        deprecated-urllib-function,
+        xreadlines-attribute,
+        deprecated-sys-function,
+        exception-escape,
+        comprehension-escape,
+        arguments-differ,
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'error', 'warning', 'refactor', and 'convention'
+# which contain the number of messages in each category, as well as 'statement'
+# which is the total number of statements analyzed. This score is used by the
+# global evaluation report (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style.
+#class-attribute-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           x,
+           _
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_|test.*|^.*Test
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style.
+#variable-rgx=
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+# Maximum number of characters on a single line.
+max-line-length=120
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# List of optional constructs for which whitespace checking is disabled. `dict-
+# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\n222: 2}.
+# `trailing-comma` allows a space between comma and closing bracket: (a, ).
+# `empty-line` allows space-only lines.
+no-space-check=trailing-comma,
+               dict-separator
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX
+
+
+[LOGGING]
+
+# Format style used to check logging format string. `old` means using %
+# formatting, `new` is for `{}` formatting,and `fstr` is for f-strings.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[STRING]
+
+# This flag controls whether the implicit-str-concat-in-sequence should
+# generate a warning on implicit string concatenation in sequences defined over
+# several lines.
+check-str-concat-over-line-jumps=no
+
+
+[STRING_QUOTES]
+
+string-quote=single
+triple-quote=double
+docstring-quote=double
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=numpy.*, torch.*
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method.
+max-args=7
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=optparse,tkinter.tix
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled).
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled).
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp,
+                      __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=cls
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "BaseException, Exception".
+overgeneral-exceptions=BaseException,
+                       Exception
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..df0c4836a738746e334ca86483674a5049cb7a5d
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,14 @@
+clean:
+	rm -rf COMBO.egg-info
+	rm -rf .eggs
+	rm -rf .pytest_cache
+
+develop:
+	python setup.py develop
+
+install:
+	python setup.py install
+
+test:
+	python setup.py test
+	pylint --rcfile=.pylintrc tests combo
\ No newline at end of file
diff --git a/README.md b/README.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0e4c660aba94e8403d5c01ac0363db82892b8476 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,83 @@
+## Installation
+
+Clone this repository and run:
+```bash
+python setup.py develop
+```
+
+## Training
+
+Command:
+```bash
+combo --mode train \
+      --training_data_path your_training_path \
+      --validation_data_path your_validation_path
+```
+
+Options:
+```bash
+combo --helpfull
+```
+
+Examples (for clarity without training/validation data paths):
+
+* train on gpu 0
+    ```bash
+    combo --mode train --cuda_device 0
+    ```
+* use pretrained embeddings:
+    ```bash
+    combo --mode train --pretrained_tokens your_pretrained_embeddings_path --embedding_dim your_embeddings_dim
+    ```
+* use pretrained transformer embeddings:
+    ```bash
+    combo --mode train --pretrained_transformer_name your_chosen_pretrained_transformer
+    ```
+* predict only dependency tree:
+    ```bash
+    combo --mode train --targets head --targets deprel
+    ```
+* use part-of-speech tags for predicting only dependency tree
+    ```bash
+    combo --mode train --targets head --targets deprel --features token --features char --features upostag
+    ```
+Advanced configuration: [Configuration](#configuration)
+
+## Prediction
+
+### ConLLU file prediction:
+Input and output are both in `*.conllu` format.
+```bash
+combo --mode predict --model_path your_model_tar_gz --input_file your_conllu_file --output_file your_output_file --silent
+```
+
+### Console
+Works for models where input was text-based only.
+
+Interactive testing in console (load model and just type sentence in console).
+
+```bash
+combo --mode predict --model_path your_model_tar_gz --input_file "-"
+```
+### Raw text
+Works for models where input was text-based only. 
+
+Input: one sentence per line.
+
+Output: List of token jsons.
+
+```bash
+combo --mode predict --model_path your_model_tar_gz --input_file your_text_file --output_file your_output_file --silent
+```
+#### Advanced
+
+There are 2 tokenizers: whitespace and spacy-based (`en_core_web_sm` model).
+
+Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name semantic-multitask-predictor-spacy`.
+
+## Configuration
+
+### Advanced
+Config template [config.template.jsonnet](config.template.jsonnet) is formed in `allennlp` format so you can freely modify it.
+There is configuration for all the training/model parameters (learning rates, epochs number etc.).
+Some of them use `jsonnet` syntax to get values from configuration flags, however most of them can be modified directly there.
diff --git a/combo/__init__.py b/combo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/combo/commands/__init__.py b/combo/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32071442cc22f05bede1d4533fba4480a7db8563
--- /dev/null
+++ b/combo/commands/__init__.py
@@ -0,0 +1 @@
+from .train import FinetuningTrainModel
diff --git a/combo/commands/train.py b/combo/commands/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..44c291596c3dfba29811a35ec316dbaeee5910c0
--- /dev/null
+++ b/combo/commands/train.py
@@ -0,0 +1,171 @@
+import os
+from typing import List
+
+from allennlp import data, models, common, training
+from allennlp.commands import train
+from allennlp.common import checks
+from allennlp.common import util as common_util
+from allennlp.training import util as training_util
+
+
+@train.TrainModel.register("finetuning", constructor="from_partial_objects_finetuning")
+class FinetuningTrainModel(train.TrainModel):
+    """Class made only for finetuning, the only difference is saving vocab from concatenated
+    (archive and current) datasets."""
+
+    @classmethod
+    def from_partial_objects_finetuning(
+        cls,
+        serialization_dir: str,
+        local_rank: int,
+        batch_weight_key: str,
+        dataset_reader: data.DatasetReader,
+        train_data_path: str,
+        model: common.Lazy[models.Model],
+        data_loader: common.Lazy[data.DataLoader],
+        trainer: common.Lazy[training.Trainer],
+        vocabulary: common.Lazy[data.Vocabulary] = None,
+        datasets_for_vocab_creation: List[str] = None,
+        validation_dataset_reader: data.DatasetReader = None,
+        validation_data_path: str = None,
+        validation_data_loader: common.Lazy[data.DataLoader] = None,
+        test_data_path: str = None,
+        evaluate_on_test: bool = False,
+    ) -> "train.TrainModel":
+        """
+        This method is intended for use with our `FromParams` logic, to construct a `TrainModel`
+        object from a config file passed to the `allennlp train` command.  The arguments to this
+        method are the allowed top-level keys in a configuration file (except for the first three,
+        which are obtained separately).
+
+        You *could* use this outside of our `FromParams` logic if you really want to, but there
+        might be easier ways to accomplish your goal than instantiating `Lazy` objects.  If you are
+        writing your own training loop, we recommend that you look at the implementation of this
+        method for inspiration and possibly some utility functions you can call, but you very likely
+        should not use this method directly.
+
+        The `Lazy` type annotations here are a mechanism for building dependencies to an object
+        sequentially - the `TrainModel` object needs data, a model, and a trainer, but the model
+        needs to see the data before it's constructed (to create a vocabulary) and the trainer needs
+        the data and the model before it's constructed.  Objects that have sequential dependencies
+        like this are labeled as `Lazy` in their type annotations, and we pass the missing
+        dependencies when we call their `construct()` method, which you can see in the code below.
+
+        # Parameters
+        serialization_dir: `str`
+            The directory where logs and model archives will be saved.
+        local_rank: `int`
+            The process index that is initialized using the GPU device id.
+        batch_weight_key: `str`
+            The name of metric used to weight the loss on a per-batch basis.
+        dataset_reader: `DatasetReader`
+            The `DatasetReader` that will be used for training and (by default) for validation.
+        train_data_path: `str`
+            The file (or directory) that will be passed to `dataset_reader.read()` to construct the
+            training data.
+        model: `Lazy[Model]`
+            The model that we will train.  This is lazy because it depends on the `Vocabulary`;
+            after constructing the vocabulary we call `model.construct(vocab=vocabulary)`.
+        data_loader: `Lazy[DataLoader]`
+            The data_loader we use to batch instances from the dataset reader at training and (by
+            default) validation time. This is lazy because it takes a dataset in its constructor.
+        trainer: `Lazy[Trainer]`
+            The `Trainer` that actually implements the training loop.  This is a lazy object because
+            it depends on the model that's going to be trained.
+        vocabulary: `Lazy[Vocabulary]`, optional (default=None)
+            The `Vocabulary` that we will use to convert strings in the data to integer ids (and
+            possibly set sizes of embedding matrices in the `Model`).  By default we construct the
+            vocabulary from the instances that we read.
+        datasets_for_vocab_creation: `List[str]`, optional (default=None)
+            If you pass in more than one dataset but don't want to use all of them to construct a
+            vocabulary, you can pass in this key to limit it.  Valid entries in the list are
+            "train", "validation" and "test".
+        validation_dataset_reader: `DatasetReader`, optional (default=None)
+            If given, we will use this dataset reader for the validation data instead of
+            `dataset_reader`.
+        validation_data_path: `str`, optional (default=None)
+            If given, we will use this data for computing validation metrics and early stopping.
+        validation_data_loader: `Lazy[DataLoader]`, optional (default=None)
+            If given, the data_loader we use to batch instances from the dataset reader at
+            validation and test time. This is lazy because it takes a dataset in its constructor.
+        test_data_path: `str`, optional (default=None)
+            If given, we will use this as test data.  This makes it available for vocab creation by
+            default, but nothing else.
+        evaluate_on_test: `bool`, optional (default=False)
+            If given, we will evaluate the final model on this data at the end of training.  Note
+            that we do not recommend using this for actual test data in every-day experimentation;
+            you should only very rarely evaluate your model on actual test data.
+        """
+
+        datasets = training_util.read_all_datasets(
+            train_data_path=train_data_path,
+            dataset_reader=dataset_reader,
+            validation_dataset_reader=validation_dataset_reader,
+            validation_data_path=validation_data_path,
+            test_data_path=test_data_path,
+        )
+
+        if datasets_for_vocab_creation:
+            for key in datasets_for_vocab_creation:
+                if key not in datasets:
+                    raise checks.ConfigurationError(f"invalid 'dataset_for_vocab_creation' {key}")
+
+        # Materialized as a list: both construct() below and the fallback
+        # from_instances() must iterate it (a generator would be exhausted).
+        instances = [
+            instance
+            for key, dataset in datasets.items()
+            if not datasets_for_vocab_creation or key in datasets_for_vocab_creation
+            for instance in dataset
+        ]
+
+        vocabulary_ = vocabulary.construct(instances=instances)
+        model_ = model.construct(vocab=vocabulary_)
+
+        # Initializing the model can have side effect of expanding the vocabulary.
+        # Save the vocab only in the master. In the degenerate non-distributed
+        # case, we're trivially the master.
+        if common_util.is_master():
+            vocabulary_path = os.path.join(serialization_dir, "vocabulary")
+            # Only difference compared to TrainModel!
+            model_.vocab.save_to_files(vocabulary_path)
+
+        for dataset in datasets.values():
+            dataset.index_with(model_.vocab)
+
+        data_loader_ = data_loader.construct(dataset=datasets["train"])
+        validation_data = datasets.get("validation")
+        if validation_data is not None:
+            # Because of the way Lazy[T] works, we can't check its existence
+            # _before_ we've tried to construct it. It returns None if it is not
+            # present, so we try to construct it first, and then afterward back off
+            # to the data_loader configuration used for training if it returns None.
+            validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
+            if validation_data_loader_ is None:
+                validation_data_loader_ = data_loader.construct(dataset=validation_data)
+        else:
+            validation_data_loader_ = None
+
+        test_data = datasets.get("test")
+        if test_data is not None:
+            test_data_loader = validation_data_loader.construct(dataset=test_data)
+            if test_data_loader is None:
+                test_data_loader = data_loader.construct(dataset=test_data)
+        else:
+            test_data_loader = None
+
+        # We don't need to pass serialization_dir and local_rank here, because they will have been
+        # passed through the trainer by from_params already, because they were keyword arguments to
+        # construct this class in the first place.
+        trainer_ = trainer.construct(
+            model=model_, data_loader=data_loader_, validation_data_loader=validation_data_loader_,
+        )
+
+        return cls(
+            serialization_dir=serialization_dir,
+            model=model_,
+            trainer=trainer_,
+            evaluation_data_loader=test_data_loader,
+            evaluate_on_test=evaluate_on_test,
+            batch_weight_key=batch_weight_key,
+        )
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7abffa2b3206b29c93e14f86ae76101a6a859cf3
--- /dev/null
+++ b/combo/data/__init__.py
@@ -0,0 +1,2 @@
+from .samplers import TokenCountBatchSampler
+from .token_indexers import TokenCharactersIndexer
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a4f07a074bca03c452fbf3b3ce682e678bcd83
--- /dev/null
+++ b/combo/data/dataset.py
@@ -0,0 +1,227 @@
+import logging
+from typing import Union, List, Dict, Iterable, Optional
+
+import conllu
+from allennlp import data as allen_data
+from allennlp.common import checks
+from allennlp.data import fields as allen_fields, vocabulary
+from conllu import parser
+from overrides import overrides
+
+from combo.data import fields
+
+logger = logging.getLogger(__name__)
+
+
@allen_data.DatasetReader.register('conllu')
class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
    """Reads CoNLL-U treebank files into AllenNLP instances.

    # Parameters

    token_indexers : indexers for the text-based features; entries whose name
        is not listed in ``features`` are removed.
    lemma_indexers : indexers used for the character-level lemma target.
    features : token-level model inputs; at least one of 'token'/'char' is required.
    targets : predicted CoNLL-U columns; must be disjoint from ``features``.
    use_sem : when True an extra 'semrel' column is parsed.
    """

    def __init__(
            self,
            token_indexers: Dict[str, allen_data.TokenIndexer] = None,
            lemma_indexers: Dict[str, allen_data.TokenIndexer] = None,
            features: List[str] = None,
            targets: List[str] = None,
            use_sem: bool = False,
            **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if features is None:
            features = ['token', 'char']
        if targets is None:
            targets = ['head', 'deprel', 'upostag', 'xpostag', 'lemma', 'feats']

        if 'token' not in features and 'char' not in features:
            raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")

        intersection = set(features).intersection(set(targets))
        if len(intersection) != 0:
            raise checks.ConfigurationError(
                "Features and targets cannot share elements! "
                "Remove {} from either features or targets.".format(intersection)
            )
        self._use_sem = use_sem

        # *.conllu readers configuration
        fields = list(parser.DEFAULT_FIELDS)
        fields[1] = 'token'  # use 'token' instead of 'form'
        # Copy before mutating: DEFAULT_FIELD_PARSERS is a module-level dict
        # shared by every reader instance.
        field_parsers = dict(parser.DEFAULT_FIELD_PARSERS)
        if self._use_sem:
            fields.append('semrel')
            # BUG FIX: the original line ended with a trailing comma, which
            # stored a 1-tuple instead of a callable and broke 'semrel' parsing.
            field_parsers['semrel'] = lambda line, i: parser.parse_nullable_value(line[i])
        self.field_parsers = field_parsers
        self.fields = tuple(fields)

        # Guard against the declared `None` default before filtering below.
        self._token_indexers = token_indexers if token_indexers is not None else {}
        self._lemma_indexers = lemma_indexers
        self._targets = targets
        self._features = features
        self.generate_labels = True
        # Filter out not required token indexers to avoid
        # Mismatched token keys ConfigurationError
        for indexer_name in list(self._token_indexers.keys()):
            if indexer_name not in self._features:
                del self._token_indexers[indexer_name]

    @overrides
    def _read(self, file_path: str) -> Iterable[allen_data.Instance]:
        # Multiple treebanks can be passed as a single ':'-joined string.
        file_paths = file_path.split(':') if ':' in file_path else [file_path]

        for conllu_file in file_paths:
            with open(conllu_file, 'r') as f:
                for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
                    # CoNLLU annotations sometimes add back in words that have been elided
                    # in the original sentence; we remove these, as we're just predicting
                    # dependencies for the original sentence.
                    # We filter by integers here as elided words have a non-integer word id,
                    # as parsed by the conllu python library.
                    annotation = conllu.TokenList([x for x in annotation if isinstance(x['id'], int)])
                    yield self.text_to_instance(annotation)

    @overrides
    def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
        """Convert a single CoNLL-U sentence into an AllenNLP `Instance`."""
        fields_: Dict[str, allen_data.Field] = {}
        tokens = [allen_data.Token(t['token'],
                                   pos_=t.get('upostag'),
                                   tag_=t.get('xpostag'),
                                   lemma_=t.get('lemma'))
                  for t in tree]

        # features
        text_field = allen_fields.TextField(tokens, self._token_indexers)
        fields_['sentence'] = text_field

        # targets
        if self.generate_labels:
            for target_name in self._targets:
                if target_name != 'sent':
                    target_values = [t[target_name] for t in tree.tokens]
                    if target_name == 'lemma':
                        # Lemmas are themselves tokenized (character-level indexers).
                        target_values = [allen_data.Token(v) for v in target_values]
                        fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers)
                    elif target_name == 'feats':
                        # Morphological features are a multi-label per token.
                        target_values = self._feat_values(tree)
                        fields_[target_name] = fields.SequenceMultiLabelField(target_values,
                                                                              self._feats_to_index_multi_label,
                                                                              text_field,
                                                                              label_namespace="feats_labels")
                    elif target_name == 'head':
                        # '_' marks a missing head; map it to the root index (0).
                        target_values = [0 if v == '_' else int(v) for v in target_values]
                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                               label_namespace=target_name + "_labels")
                    else:
                        fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                               label_namespace=target_name + "_labels")

        # metadata
        fields_["metadata"] = allen_fields.MetadataField({'input': tree, 'field_names': self.fields})
        return allen_data.Instance(fields_)

    @staticmethod
    def _feat_values(tree: conllu.TokenList):
        """Flatten each token's feats dict into 'Name=Value' strings."""
        features = []
        for token in tree:
            token_features = []
            if token['feats'] is not None:
                for feat, value in token['feats'].items():
                    # Skip placeholder categories.
                    if feat in ['_', '__ROOT__']:
                        pass
                    else:
                        token_features.append(feat + '=' + value)
            features.append(token_features)
        return features

    @staticmethod
    def _feats_to_index_multi_label(vocab: allen_data.Vocabulary):
        """Return a closure mapping a token's labels to an M-from-N one-hot vector.

        For every feature category exactly one index is set: the matching
        'Cat=Value' label if present, otherwise the auxiliary 'Cat=None'.
        """
        label_namespace = "feats_labels"
        vocab_size = vocab.get_vocab_size(label_namespace)
        slices = get_slices_if_not_provided(vocab)

        def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]:
            one_hot_encoding = [0] * vocab_size
            for cat, cat_indices in slices.items():
                if cat not in ['__PAD__', '_']:
                    label_from_cat = [label for label in multi_label if cat == label.split('=')[0]]
                    if label_from_cat:
                        label_from_cat = label_from_cat[0]
                        index = vocab.get_token_index(label_from_cat, label_namespace)
                    else:
                        # Get Cat=None index
                        index = vocab.get_token_index(cat + "=None", label_namespace)
                    one_hot_encoding[index] = 1
            return one_hot_encoding

        return _m_from_n_ones_encoding
+
+
@allen_data.Vocabulary.register('from_instances_extended', constructor='from_instances_extended')
class Vocabulary(allen_data.Vocabulary):
    """Vocabulary that can be built with the 'from_instances_extended' constructor.

    Differs from the stock `from_instances` in two ways: pretrained token
    embeddings are added wholesale (not intersected with the dataset), and
    auxiliary 'Cat=None' labels are added to the 'feats_labels' namespace.
    """

    @classmethod
    def from_instances_extended(
            cls,
            instances: Iterable[allen_data.Instance],
            min_count: Dict[str, int] = None,
            max_vocab_size: Union[int, Dict[str, int]] = None,
            non_padded_namespaces: Iterable[str] = vocabulary.DEFAULT_NON_PADDED_NAMESPACES,
            pretrained_files: Optional[Dict[str, str]] = None,
            only_include_pretrained_words: bool = False,
            min_pretrained_embeddings: Dict[str, int] = None,
            padding_token: Optional[str] = vocabulary.DEFAULT_PADDING_TOKEN,
            oov_token: Optional[str] = vocabulary.DEFAULT_OOV_TOKEN,
    ) -> "Vocabulary":
        """
        Extension to manually fill gaps in missing 'feats_labels'.
        """
        # Load manually tokens from pretrained file (using different strategy
        # - only words add all embedding file, without checking if were seen
        # in any dataset.
        tokens_to_add = None
        if pretrained_files and 'tokens' in pretrained_files:
            # Every token from the embedding file is added via `tokens_to_add`;
            # `pretrained_files` is then cleared so the parent constructor does
            # not apply its own intersection logic.
            pretrained_set = set(vocabulary._read_pretrained_tokens(pretrained_files['tokens']))
            tokens_to_add = {'tokens': list(pretrained_set)}
            pretrained_files = None

        vocab = super().from_instances(
            instances=instances,
            min_count=min_count,
            max_vocab_size=max_vocab_size,
            non_padded_namespaces=non_padded_namespaces,
            pretrained_files=pretrained_files,
            only_include_pretrained_words=only_include_pretrained_words,
            tokens_to_add=tokens_to_add,
            min_pretrained_embeddings=min_pretrained_embeddings,
            padding_token=padding_token,
            oov_token=oov_token
        )
        # Extending vocab with features that does not show up explicitly.
        # To know all features we need to read full dataset first.
        # Adding auxiliary '=None' feature for each category is needed
        # to perform classification.
        get_slices_if_not_provided(vocab)
        return vocab
+
+
def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
    """Return (and cache on ``vocab``) the feature-category index slices.

    Ensures every category in the 'feats_labels' namespace has an auxiliary
    'Cat=None' label, then groups all label indices by category name.
    The result is memoized as ``vocab.slices``.
    """
    if hasattr(vocab, 'slices'):
        return vocab.slices

    if 'feats_labels' not in vocab.get_namespaces():
        # BUG FIX: the original fell through and implicitly returned None;
        # callers iterate over the result, so return an empty mapping instead.
        # Not cached, so the slices are computed once the namespace appears.
        return {}

    idx2token = vocab.get_index_to_token_vocabulary('feats_labels')
    # Iterate over a copy: we extend the namespace while reading it.
    for k, v in dict(idx2token).items():
        if v not in ['_', '__PAD__']:
            empty_value = v.split("=")[0] + "=None"
            vocab.add_token_to_namespace(empty_value, 'feats_labels')

    slices = {}
    for idx, name in vocab.get_index_to_token_vocabulary('feats_labels').items():
        # There are 2 types features: with (Case=Acc) or without assigment (None).
        # Here we group their indices by name (before assigment sign).
        name = name.split('=')[0]
        if name in slices:
            slices[name].append(idx)
        else:
            slices[name] = [idx]
    vocab.slices = slices
    return vocab.slices
diff --git a/combo/data/fields/__init__.py b/combo/data/fields/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f8d3a265f506aab43e0cda78ec931cd5c44c92e
--- /dev/null
+++ b/combo/data/fields/__init__.py
@@ -0,0 +1 @@
+from .sequence_multilabel_field import SequenceMultiLabelField
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e98a148aee35e42af0b4828a031368fe0eafc12
--- /dev/null
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -0,0 +1,138 @@
+"""Sequence multilabel field implementation."""
+import logging
+import textwrap
+from typing import Set, List, Callable, Iterator, Union, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import checks, util
+from allennlp.data import fields
+from overrides import overrides
+
+logger = logging.getLogger(__name__)
+
+
class SequenceMultiLabelField(data.Field[torch.Tensor]):
    """
    A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
    while keeping sequence dimension.

    This field will get converted into a sequence of vectors of length equal to the vocabulary size with
    M from N encoding for the labels (all zeros, and ones for the labels).

    # Parameters

    multi_labels : `List[List[str]]`
    multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
        Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings
        to indexed, int values.
    sequence_field : `SequenceField`
        A field containing the sequence that this `SequenceMultiLabelField` is labeling.  Most often, this is a
        `TextField`, for tagging individual tokens in a sentence.
    label_namespace : `str`, optional (default="labels")
        The namespace to use for converting label strings into integers.  We map label strings to
        integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...),
        and this namespace tells the `Vocabulary` object which mapping from strings to integers
        to use (so "entailment" as a label doesn't get the same integer id as "entailment" as a
        word).  If you have multiple different label fields in your data, you should make sure you
        use different namespaces for each one, always using the suffix "labels" (e.g.,
        "passage_labels" and "question_labels").
    """
    # Shared across instances: warn once per namespace.
    _already_warned_namespaces: Set[str] = set()

    def __init__(
            self,
            multi_labels: List[List[str]],
            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]],
            sequence_field: fields.SequenceField,
            label_namespace: str = "labels",
    ) -> None:
        self.multi_labels = multi_labels
        self.sequence_field = sequence_field
        self.multi_label_indexer = multi_label_indexer
        self._label_namespace = label_namespace
        # Filled by `index`: one List[int] encoding per token.
        self._indexed_multi_labels = None
        self._maybe_warn_for_namespace(label_namespace)
        if len(multi_labels) != sequence_field.sequence_length():
            raise checks.ConfigurationError(
                "Label length and sequence length "
                "don't match: %d and %d" % (len(multi_labels), sequence_field.sequence_length())
            )

        if not all([isinstance(x, str) for multi_label in multi_labels for x in multi_label]):
            raise checks.ConfigurationError(
                "SequenceMultiLabelField must be passed either all "
                "strings or all ints. Found labels {} with "
                "types: {}.".format(multi_labels, [type(x) for multi_label in multi_labels for x in multi_label])
            )

    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
            if label_namespace not in self._already_warned_namespaces:
                logger.warning(
                    "Your label namespace was '%s'. We recommend you use a namespace "
                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                    "default to your vocabulary.  See documentation for "
                    "`non_padded_namespaces` parameter in Vocabulary.",
                    self._label_namespace,
                )
                self._already_warned_namespaces.add(label_namespace)

    # Sequence methods
    def __iter__(self) -> Iterator[Union[List[str], int]]:
        return iter(self.multi_labels)

    def __getitem__(self, idx: int) -> Union[List[str], int]:
        return self.multi_labels[idx]

    def __len__(self) -> int:
        return len(self.multi_labels)

    @overrides
    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        if self._indexed_multi_labels is None:
            for multi_label in self.multi_labels:
                for label in multi_label:
                    counter[self._label_namespace][label] += 1  # type: ignore

    @overrides
    def index(self, vocab: data.Vocabulary):
        indexer = self.multi_label_indexer(vocab)

        indexed = []
        for multi_label in self.multi_labels:
            indexed.append(indexer(multi_label))
        self._indexed_multi_labels = indexed

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        desired_num_tokens = padding_lengths["num_tokens"]
        # Empty fields (see `empty_field`) have no rows to infer the number of
        # classes from; fall back to zero-width rows instead of asserting.
        classes_count = len(self._indexed_multi_labels[0]) if self._indexed_multi_labels else 0
        default_value = [0.0] * classes_count
        padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value)
        tensor = torch.LongTensor(padded_tags)
        return tensor

    @overrides
    def empty_field(self) -> "SequenceMultiLabelField":
        # BUG FIX: use [] (length 0, matching the empty sequence field's
        # length) — the original's [[]] failed the constructor length check.
        # The empty_list here is needed for mypy
        empty_list: List[List[str]] = []
        sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
                                                       self.sequence_field.empty_field())
        # BUG FIX: the original assigned `_indexed_labels` (a typo), which left
        # `_indexed_multi_labels` as None and broke `as_tensor` on empty fields.
        sequence_label_field._indexed_multi_labels = empty_list
        return sequence_label_field

    def __str__(self) -> str:
        length = self.sequence_field.sequence_length()
        formatted_labels = "".join(
            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100)
        )
        return (
            f"SequenceMultiLabelField of length {length} with "
            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
        )
diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab003fe4341da38983553324b0454681a715eda1
--- /dev/null
+++ b/combo/data/samplers/__init__.py
@@ -0,0 +1 @@
+from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c86f4953d59efcc4b6c5246ce16242a48f3f77
--- /dev/null
+++ b/combo/data/samplers/samplers.py
@@ -0,0 +1,54 @@
+from typing import List
+
+import numpy as np
+
+from allennlp import data as allen_data
+
+
@allen_data.BatchSampler.register("token_count")
class TokenCountBatchSampler(allen_data.BatchSampler):
    """Batches instances by (approximate) total token count.

    Instances are sorted by sentence length so batches contain similarly
    sized sentences; a batch is closed once it exceeds ``word_batch_size``
    tokens. Batches are optionally reshuffled after each full pass.
    """

    def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
        self._index = 0
        self.shuffle_dataset = shuffle_dataset
        self.batch_dataset = self._batchify(dataset, word_batch_size)
        if shuffle_dataset:
            self._shuffle()

    @staticmethod
    def _batchify(dataset, word_batch_size) -> List[List[int]]:
        """Group instance indices (sorted by length) into token-count-bounded batches."""
        dataset = list(dataset)
        batches = []
        batch = []
        words_count = 0
        lengths = [len(instance.fields['sentence'].tokens) for instance in dataset]
        argsorted_lengths = np.argsort(lengths)
        for idx in argsorted_lengths:
            words_count += lengths[idx]
            # Cast to plain int: np.argsort yields numpy integer scalars.
            batch.append(int(idx))
            if words_count > word_batch_size:
                batches.append(batch)
                words_count = 0
                batch = []
        # BUG FIX: keep the trailing partial batch — the original dropped it,
        # silently excluding the longest sentences from every epoch.
        if batch:
            batches.append(batch)
        return batches

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= len(self.batch_dataset):
            # Reset (and reshuffle) for the next epoch before signalling the
            # end of this one.
            if self.shuffle_dataset:
                self._index = 0
                self._shuffle()
            raise StopIteration()

        batch = self.batch_dataset[self._index]
        self._index += 1
        return batch

    def _shuffle(self):
        # BUG FIX: shuffle via list indexing — np.array() on ragged nested
        # lists raises in recent NumPy versions.
        permutation = np.random.permutation(len(self.batch_dataset))
        self.batch_dataset = [self.batch_dataset[i] for i in permutation]

    def __len__(self):
        return len(self.batch_dataset)
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2a7b7510477a68af50c60d40cc21eef3d863bc
--- /dev/null
+++ b/combo/data/token_indexers/__init__.py
@@ -0,0 +1 @@
+from .token_characters_indexer import TokenCharactersIndexer
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea7a3eaab32a79758573fe6f778c59060d9bce6c
--- /dev/null
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -0,0 +1,62 @@
+"""Custom character token indexer."""
+import itertools
+from typing import List, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import util
+from allennlp.data import tokenizers
+from allennlp.data.token_indexers import token_characters_indexer
+from overrides import overrides
+
+
@data.TokenIndexer.register("characters_const_padding")
class TokenCharactersIndexer(token_characters_indexer.TokenCharactersIndexer):
    """Wrapper around allennlp token indexer with const padding.

    Unlike the parent class, the character dimension is always padded (and
    truncated) to exactly ``min_padding_length``, so every batch has a
    constant token width.
    """

    def __init__(self,
                 namespace: str = "token_characters",
                 character_tokenizer: tokenizers.CharacterTokenizer = tokenizers.CharacterTokenizer(),
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None,
                 min_padding_length: int = 0,
                 token_min_padding_length: int = 0):
        # NOTE(review): with the default min_padding_length=0 every token is
        # truncated to zero characters in as_padded_tensor_dict — callers are
        # expected to pass a positive value; confirm against configs.
        super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length,
                         token_min_padding_length)

    @overrides
    def get_padding_lengths(self, indexed_tokens: data.IndexedTokenList) -> Dict[str, int]:
        # Constant character width: report min_padding_length instead of the
        # longest token in the batch (which is what the parent class does).
        padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]),
                           "num_token_characters": self._min_padding_length}
        return padding_lengths

    @overrides
    def as_padded_tensor_dict(
            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
    ) -> Dict[str, torch.Tensor]:
        # Pad the tokens.
        padded_tokens = util.pad_sequence_to_length(
            tokens["token_characters"],
            padding_lengths["token_characters"],
            default_value=lambda: [],
        )

        # Pad the characters within the tokens.
        desired_token_length = padding_lengths["num_token_characters"]
        longest_token: List[int] = max(tokens["token_characters"], key=len, default=[])  # type: ignore
        padding_value = 0
        if desired_token_length > len(longest_token):
            # Since we want to pad to greater than the longest token, we add a
            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
            padded_tokens.append([padding_value] * desired_token_length)
        # pad the list of lists to the longest sublist, appending 0's
        padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value)))
        if desired_token_length > len(longest_token):
            # Removes the "dummy token".
            padded_tokens.pop()
        # Truncates all the tokens to the desired length, and return the result.
        return {
            "token_characters": torch.LongTensor(
                [list(token[:desired_token_length]) for token in padded_tokens]
            )
        }
diff --git a/combo/main.py b/combo/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f37661ac42fb1ea34aea756da1dc966ec662909
--- /dev/null
+++ b/combo/main.py
@@ -0,0 +1,189 @@
+"""Main entry point."""
+import logging
+import os
+import tempfile
+from typing import Dict
+
+import torch
+from absl import app
+from absl import flags
+from allennlp import common, models, predictors
+from allennlp.commands import train, predict as allen_predict
+from allennlp.common import checks as allen_checks, util
+from allennlp.models import archival
+
+from combo import predict
+from combo.data import dataset
+from combo.utils import checks
+
+logger = logging.getLogger(__name__)
+
FLAGS = flags.FLAGS
# Typo fix in help text: "precit" -> "predict".
flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'],
                  help="Specify COMBO mode: train or predict")

# Common flags
flags.DEFINE_integer(name='cuda_device', default=-1,
                     help="Cuda device id (default -1 cpu)")

# Training flags
flags.DEFINE_string(name='training_data_path', default="./tests/fixtures/example.conllu",
                    help='Training data path')
flags.DEFINE_string(name='validation_data_path', default='',
                    help='Validation data path')
flags.DEFINE_string(name='pretrained_tokens', default='',
                    help='Pretrained tokens embeddings path')
flags.DEFINE_integer(name='embedding_dim', default=300,
                     help='Embeddings dim')
flags.DEFINE_string(name='pretrained_transformer_name', default='',
                    help='Pretrained transformer model name (see transformers from HuggingFace library for list of'
                         'available models) for transformers based embeddings.')
flags.DEFINE_multi_enum(name='features', default=['token', 'char'],
                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma'],
                        help='Features used to train model (required \'token\' and \'char\')')
flags.DEFINE_multi_enum(name='targets', default=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag'],
                        enum_values=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag', 'semrel', 'sent'],
                        help='Targets of the model (required \'deprel\' and \'head\')')
flags.DEFINE_string(name='serialization_dir', default=None,
                    help='Model serialization directory (default - system temp dir).')

# Finetune after training flags
flags.DEFINE_string(name='finetuning_training_data_path', default='',
                    help='Training data path')
flags.DEFINE_string(name='finetuning_validation_data_path', default='',
                    help='Validation data path')
flags.DEFINE_string(name='config_path', default='config.template.jsonnet',
                    help='Config file path.')

# Test after training flags
flags.DEFINE_string(name='result', default='result.conll',
                    help='Test result path file')
flags.DEFINE_string(name='test_path', default=None,
                    help='Test path file.')

# Experimental
flags.DEFINE_boolean(name='use_pure_config', default=False,
                     help='Ignore ext flags (experimental).')

# Prediction flags
flags.DEFINE_string(name='model_path', default=None,
                    help='Pretrained model path.')
flags.DEFINE_string(name='input_file', default=None,
                    help='File to predict path')
flags.DEFINE_string(name='output_file', default='output.log',
                    help='Predictions result file.')
flags.DEFINE_integer(name='batch_size', default=1,
                     help='Prediction batch size.')
flags.DEFINE_boolean(name='silent', default=True,
                     help='Silent prediction to file (without printing to console).')
flags.DEFINE_enum(name='predictor_name', default='semantic-multitask-predictor-spacy',
                  enum_values=['semantic-multitask-predictor', 'semantic-multitask-predictor-spacy'],
                  help='Use predictor with whitespace or spacy tokenizer.')
+
+
def run(_):
    """Run model.

    Dispatches on FLAGS.mode: 'train' (with optional finetuning and a
    test-set prediction pass), anything else falls through to prediction.
    """
    # Imports are required to make Registrable modules visible without passing parameter
    util.import_module_and_submodules('combo.commands')
    util.import_module_and_submodules('combo.models')
    util.import_module_and_submodules('combo.training')

    if FLAGS.mode == 'train':
        checks.file_exists(FLAGS.config_path)
        # Resolve the jsonnet config with ext vars derived from the flags.
        params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())
        # Keep a copy of the raw model definition; it is restored after
        # finetuning so the archived config stands alone (see below).
        model_params = params.get('model').as_ordered_dict()
        serialization_dir = tempfile.mkdtemp(prefix='allennlp', dir=FLAGS.serialization_dir)
        model = train.train_model(params, serialization_dir=serialization_dir, file_friendly_logging=True)
        logger.info(f'Training model stored in: {serialization_dir}')

        if FLAGS.finetuning_training_data_path:
            checks.file_exists(FLAGS.finetuning_training_data_path)

            # Loading will be performed from stored model.tar.gz
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars(finetuning=True))
            # Replace model definition with pretrained archive
            params['model'] = {
                'type': 'from_archive',
                'archive_file': serialization_dir + '/model.tar.gz',
            }
            serialization_dir = tempfile.mkdtemp(prefix='allennlp', suffix='-finetuning', dir=FLAGS.serialization_dir)
            model = train.train_model(params.duplicate(), serialization_dir=serialization_dir,
                                      file_friendly_logging=True)

            # Make finetuning model serialization independent from training serialization
            # Storing model definition instead of archive
            params['model'] = model_params
            params.to_file(os.path.join(serialization_dir, archival.CONFIG_NAME))
            archival.archive_model(serialization_dir)

            logger.info(f'Finetuned model stored in: {serialization_dir}')

        if FLAGS.test_path and FLAGS.result:
            # Predict the test set with the freshly trained (or finetuned) model
            # and serialize trees to FLAGS.result.
            checks.file_exists(FLAGS.test_path)
            params = common.Params.from_file(FLAGS.config_path)['dataset_reader']
            params.pop('type')
            dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
            predictor = predict.SemanticMultitaskPredictor(
                model=model,
                dataset_reader=dataset_reader
            )
            test_path = FLAGS.test_path
            test_trees = dataset_reader.read(test_path)
            with open(FLAGS.result, 'w') as f:
                for tree in test_trees:
                    f.writelines(predictor.predict_instance_as_tree(tree).serialize())
    else:
        # Prediction mode: use the dataset reader only for *.conllu inputs.
        use_dataset_reader = ".conllu" in FLAGS.input_file.lower()
        manager = allen_predict._PredictManager(
            _get_predictor(),
            FLAGS.input_file,
            FLAGS.output_file,
            FLAGS.batch_size,
            FLAGS.silent,
            use_dataset_reader,
        )
        manager.run()
+
+
def _get_predictor() -> predictors.Predictor:
    """Load the archived model from FLAGS.model_path and wrap it in the
    predictor selected by FLAGS.predictor_name."""
    allen_checks.check_for_gpu(FLAGS.cuda_device)
    checks.file_exists(FLAGS.model_path)

    loaded_archive = models.load_archive(FLAGS.model_path, cuda_device=FLAGS.cuda_device)
    return predictors.Predictor.from_archive(loaded_archive, FLAGS.predictor_name)
+
+
def _get_ext_vars(finetuning: bool = False) -> Dict:
    """Build the jsonnet ext-var mapping from command-line flags.

    Returns an empty mapping when --use_pure_config is set; otherwise the
    data paths switch to their finetuning variants when `finetuning` is True.
    """
    if FLAGS.use_pure_config:
        return {}

    if finetuning:
        training_path = FLAGS.finetuning_training_data_path
        validation_path = FLAGS.finetuning_validation_data_path
    else:
        training_path = FLAGS.training_data_path
        validation_path = FLAGS.validation_data_path

    return {
        'training_data_path': training_path,
        'validation_data_path': validation_path,
        'pretrained_tokens': FLAGS.pretrained_tokens,
        'pretrained_transformer_name': FLAGS.pretrained_transformer_name,
        'features': ' '.join(FLAGS.features),
        'targets': ' '.join(FLAGS.targets),
        'type': 'finetuning' if finetuning else 'default',
        'embedding_dim': str(FLAGS.embedding_dim),
    }
+
+
def main():
    """Console entry point: enforce required flags, then hand off to absl."""
    flags.mark_flag_as_required('mode')
    app.run(run)


if __name__ == '__main__':
    main()
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aa7b283c70af76a27d97f63783368bdbe4ffa3f
--- /dev/null
+++ b/combo/models/__init__.py
@@ -0,0 +1,8 @@
+"""Models module."""
+from .base import FeedForwardPredictor
+from .parser import DependencyRelationModel
+from .embeddings import CharacterBasedWordEmbeddings
+from .encoder import ComboEncoder
+from .lemma import LemmatizerModel
+from .model import SemanticMultitaskModel
+from .morpho import MorphologicalFeatures
diff --git a/combo/models/base.py b/combo/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc49f90dc02dc0553cb536b0e09dc2a510866069
--- /dev/null
+++ b/combo/models/base.py
@@ -0,0 +1,114 @@
+from typing import Dict, Optional, List, Union
+
+import torch
+import torch.nn as nn
+from allennlp import common, data
+from allennlp import nn as allen_nn
+from allennlp.common import checks
+from allennlp.modules import feedforward
+from allennlp.nn import Activation
+
+from combo.models import utils
+
+
+class Predictor(nn.Module, common.Registrable):
+    """Base class for COMBO task-specific predictors (AllenNLP-registrable).
+
+    Subclasses implement forward() and return a dict of tensors; the
+    implementations in this package use the keys 'prediction', 'probability'
+    and (when labels are given) 'loss'.
+    """
+
+    # Used when a config refers to 'Predictor' without an explicit type.
+    default_implementation = 'feedforward_predictor_from_vocab'
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError()
+
+
+class Linear(nn.Linear, common.FromParams):
+    """nn.Linear followed by an optional activation and optional dropout."""
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 activation: Optional[allen_nn.Activation] = lambda x: x,
+                 dropout_rate: Optional[float] = 0.0):
+        super().__init__(in_features, out_features)
+        self.activation = activation
+        # dropout_rate == 0.0 (or None) is falsy, so dropout becomes the
+        # identity instead of nn.Dropout(p=0) — same behaviour, cheaper.
+        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        # linear -> activation -> dropout
+        x = super().forward(x)
+        x = self.activation(x)
+        return self.dropout(x)
+
+    def get_output_dim(self) -> int:
+        # Mirrors the AllenNLP encoder/embedder API.
+        return self.out_features
+
+
+@Predictor.register('feedforward_predictor')
+@Predictor.register('feedforward_predictor_from_vocab', constructor='from_vocab')
+class FeedForwardPredictor(Predictor):
+    """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
+
+    def __init__(self, feedforward_network: feedforward.FeedForward):
+        super().__init__()
+        self.feedforward_network = feedforward_network
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        """Score each position; returns 'prediction' (argmax over the last
+        dim), 'probability' (raw network output) and, when labels are given,
+        'loss'."""
+        if mask is None:
+            # Default: every position (all dims except the feature dim) is valid.
+            mask = x.new_ones(x.size()[:-1])
+
+        x = self.feedforward_network(x)
+        output = {
+            'prediction': x.argmax(-1),
+            'probability': x
+        }
+
+        if labels is not None:
+            if sample_weights is None:
+                # One weight per batch element.
+                sample_weights = labels.new_ones([mask.size(0)])
+            output['loss'] = self._loss(x, labels, mask, sample_weights)
+
+        return output
+
+    def _loss(self,
+              pred: torch.Tensor,
+              true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        # Masked cross-entropy averaged over valid (unmasked) positions,
+        # with each sentence's loss additionally scaled by its sample weight.
+        BATCH_SIZE, _, CLASSES = pred.size()
+        valid_positions = mask.sum()
+        pred = pred.reshape(-1, CLASSES)
+        true = true.reshape(-1)
+        mask = mask.reshape(-1)
+        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   input_dim: int,
+                   num_layers: int,
+                   hidden_dims: List[int],
+                   activations: Union[Activation, List[Activation]],
+                   dropout: Union[float, List[float]] = 0.0,
+                   ):
+        """Build the predictor with an output layer sized to `vocab_namespace`.
+
+        `hidden_dims` lists only intermediate layers; the final layer (vocab
+        size) is appended here, hence len(hidden_dims) + 1 must equal
+        num_layers.
+        """
+        if len(hidden_dims) + 1 != num_layers:
+            raise checks.ConfigurationError(
+                "len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers)
+            )
+
+        assert vocab_namespace in vocab.get_namespaces()
+        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
+
+        return cls(feedforward.FeedForward(
+            input_dim=input_dim,
+            num_layers=num_layers,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            dropout=dropout)
+        )
diff --git a/combo/models/dilated_cnn.py b/combo/models/dilated_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d25f3e3dbd76ec29d4c5f704f851469713c44d
--- /dev/null
+++ b/combo/models/dilated_cnn.py
@@ -0,0 +1,39 @@
+from typing import List
+
+import torch
+import torch.nn as nn
+
+from allennlp import common
+from allennlp import nn as allen_nn
+
+
+class DilatedCnnEncoder(nn.Module, common.FromParams):
+    """Stack of 1-D (optionally dilated) convolutions, one per activation.
+
+    All list parameters are indexed per layer; len(activations) determines
+    the number of layers, and the other lists must be at least that long.
+    """
+
+    def __init__(self,
+                 input_dim: int,
+                 filters: List[int],
+                 kernel_size: List[int],
+                 stride: List[int],
+                 padding: List[int],
+                 dilation: List[int],
+                 activations: List[allen_nn.Activation]):
+        super().__init__()
+        conv1d_layers = []
+        # Layer i consumes the previous layer's filter count (or input_dim).
+        input_dims = [input_dim] + filters[:-1]
+        output_dims = filters
+        for idx in range(len(activations)):
+            conv1d_layers.append(nn.Conv1d(
+                in_channels=input_dims[idx],
+                out_channels=output_dims[idx],
+                kernel_size=kernel_size[idx],
+                stride=stride[idx],
+                padding=padding[idx],
+                dilation=dilation[idx]))
+        self.conv1d_layers = nn.ModuleList(conv1d_layers)
+        self.activations = activations
+        assert len(self.activations) == len(self.conv1d_layers)
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        # x is channels-first (batch, channels, length), as nn.Conv1d expects.
+        for layer, activation in zip(self.conv1d_layers, self.activations):
+            x = activation(layer(x))
+        return x
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5bf26262fe2bf119d4ad89336ef2291508516a0
--- /dev/null
+++ b/combo/models/embeddings.py
@@ -0,0 +1,150 @@
+"""Embeddings."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from allennlp import nn as allen_nn, data
+from allennlp.modules import token_embedders
+from overrides import overrides
+from transformers import modeling_auto
+
+from combo.models import base, dilated_cnn
+
+
+@token_embedders.TokenEmbedder.register('char_embeddings')
+@token_embedders.TokenEmbedder.register('char_embeddings_from_config', constructor='from_config')
+class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
+    """Character-based word embeddings."""
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder):
+        super().__init__()
+        self.char_embed = nn.Embedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+        )
+        self.dilated_cnn_encoder = dilated_cnn_encoder
+        # NOTE(review): assumes the CNN encoder's final filter count equals
+        # embedding_dim, otherwise get_output_dim() misreports — confirm.
+        self.output_dim = embedding_dim
+
+    def forward(self,
+                x: torch.Tensor,
+                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
+        """Embed characters, run each word through the CNN, max-pool over
+        character positions; returns one vector per word."""
+        if char_mask is None:
+            char_mask = x.new_ones(x.size())
+
+        # Zero out embeddings of padding characters.
+        x = self.char_embed(x)
+        x = x * char_mask.unsqueeze(-1).float()
+
+        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_EMB = x.size()
+
+        # Process one sentence position (word) at a time; transpose to
+        # channels-first layout for Conv1d, then max-pool over characters.
+        words = []
+        for i in range(SENTENCE_LENGTH):
+            word = x[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB).transpose(1, 2)
+            word = self.dilated_cnn_encoder(word)
+            word, _ = torch.max(word, dim=2)
+            words.append(word)
+        return torch.stack(words, dim=1)
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
+
+    @classmethod
+    def from_config(cls,
+                    embedding_dim: int,
+                    vocab: data.Vocabulary,
+                    dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
+                    vocab_namespace: str = 'token_characters'):
+        """Alternate constructor: size the char embedding table from `vocab`."""
+        assert vocab_namespace in vocab.get_namespaces()
+        return cls(
+            embedding_dim=embedding_dim,
+            num_embeddings=vocab.get_vocab_size(vocab_namespace),
+            dilated_cnn_encoder=dilated_cnn_encoder
+        )
+
+
+@token_embedders.TokenEmbedder.register('embeddings_projected')
+class ProjectedWordEmbedder(token_embedders.Embedding):
+    """Word embeddings with an optional trailing linear projection.
+
+    All constructor arguments except `projection_layer` are forwarded
+    unchanged to allennlp's Embedding.
+    """
+
+    def __init__(self,
+                 embedding_dim: int,
+                 num_embeddings: int = None,
+                 weight: torch.FloatTensor = None,
+                 padding_index: int = None,
+                 trainable: bool = True,
+                 max_norm: float = None,
+                 norm_type: float = 2.0,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 vocab_namespace: str = "tokens",
+                 pretrained_file: str = None,
+                 vocab: data.Vocabulary = None,
+                 projection_layer: Optional[base.Linear] = None):
+        super().__init__(
+            embedding_dim=embedding_dim,
+            num_embeddings=num_embeddings,
+            weight=weight,
+            padding_index=padding_index,
+            trainable=trainable,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            vocab_namespace=vocab_namespace,
+            pretrained_file=pretrained_file,
+            vocab=vocab
+        )
+        self._projection = projection_layer
+        # Output dim is the projection's width when a projection is attached.
+        self.output_dim = embedding_dim if projection_layer is None else projection_layer.out_features
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
+
+
+@token_embedders.TokenEmbedder.register('transformers_word_embeddings')
+class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
+    """
+    Transformers word embeddings as last hidden state + optional projection layers.
+
+    Tested with Bert (but should work for other models as well).
+    """
+
+    def __init__(self,
+                 model_name: str,
+                 projection_dim: int,
+                 projection_activation: Optional[allen_nn.Activation] = lambda x: x,
+                 projection_dropout_rate: Optional[float] = 0.0):
+        super().__init__(model_name)
+        # NOTE(review): this loads the transformer a second time (the parent
+        # class already holds one) apparently only to read config.hidden_size —
+        # confirm whether the parent's model config could be reused instead.
+        self.transformers_encoder = modeling_auto.AutoModel.from_pretrained(model_name)
+        self.output_dim = self.transformers_encoder.config.hidden_size
+        if projection_dim:
+            self.projection_layer = base.Linear(in_features=self.output_dim,
+                                                out_features=projection_dim,
+                                                dropout_rate=projection_dropout_rate,
+                                                activation=projection_activation)
+            self.output_dim = projection_dim
+        else:
+            self.projection_layer = None
+
+    def forward(
+        self,
+        token_ids: torch.LongTensor,
+        mask: torch.BoolTensor,
+        offsets: torch.LongTensor,
+        wordpiece_mask: torch.BoolTensor,
+        type_ids: Optional[torch.LongTensor] = None,
+        segment_concat_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        """Embed wordpieces via the parent class, then apply the optional
+        projection (linear + activation + dropout)."""
+        x = super().forward(token_ids=token_ids, mask=mask, offsets=offsets, wordpiece_mask=wordpiece_mask,
+                            type_ids=type_ids, segment_concat_mask=segment_concat_mask)
+        if self.projection_layer:
+            x = self.projection_layer(x)
+        return x
+
+    @overrides
+    def get_output_dim(self):
+        return self.output_dim
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d13e7e5dd96bce0719e572e68bb4afc9a144ef3e
--- /dev/null
+++ b/combo/models/encoder.py
@@ -0,0 +1,76 @@
+"""Encoder."""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.utils.rnn as rnn
+from allennlp import common, modules
+from allennlp.modules import input_variational_dropout, stacked_bidirectional_lstm, seq2seq_encoders
+from overrides import overrides
+
+
+class StackedBiLSTM(stacked_bidirectional_lstm.StackedBidirectionalLstm, common.FromParams):
+    """Stacked bidirectional LSTM with modified dropout/state handling
+    (see forward() for the differences vs. the allennlp base class)."""
+
+    def __init__(self, input_size: int, hidden_size: int, num_layers: int, recurrent_dropout_probability: float,
+                 layer_dropout_probability: float, use_highway: bool = False):
+        super().__init__(input_size=input_size,
+                         hidden_size=hidden_size,
+                         num_layers=num_layers,
+                         recurrent_dropout_probability=recurrent_dropout_probability,
+                         layer_dropout_probability=layer_dropout_probability,
+                         use_highway=use_highway)
+
+    @overrides
+    def forward(self,
+                inputs: rnn.PackedSequence,
+                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+                ) -> Tuple[rnn.PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
+        """Changes when compared to stacked_bidirectional_lstm.StackedBidirectionalLstm
+        * dropout also on last layer
+        * accepts BxTxD tensor
+        * state from n-1 layer used as n layer initial state
+
+        :param inputs:
+        :param initial_state:
+        :return:
+        """
+        # NOTE(review): initial_state is accepted but never used — the first
+        # layer always starts from None state; confirm this is intentional.
+        output_sequence = inputs
+        state_fwd = None
+        state_bwd = None
+        for i in range(self.num_layers):
+            # Per-layer LSTMs are attributes created by the parent class.
+            forward_layer = getattr(self, 'forward_layer_{}'.format(i))
+            backward_layer = getattr(self, 'backward_layer_{}'.format(i))
+
+            # Both directions read the same layer input; each carries its own
+            # state forward into the next layer.
+            forward_output, state_fwd = forward_layer(output_sequence, state_fwd)
+            backward_output, state_bwd = backward_layer(output_sequence, state_bwd)
+
+            # Unpack to concatenate direction outputs, apply dropout, repack.
+            forward_output, lengths = rnn.pad_packed_sequence(forward_output, batch_first=True)
+            backward_output, _ = rnn.pad_packed_sequence(backward_output, batch_first=True)
+
+            output_sequence = torch.cat([forward_output, backward_output], -1)
+
+            # layer_dropout is created by the parent class; applied after
+            # every layer, including the last.
+            output_sequence = self.layer_dropout(output_sequence)
+            output_sequence = rnn.pack_padded_sequence(output_sequence, lengths, batch_first=True)
+
+        # Returned states are those of the final layer only.
+        return output_sequence, (state_fwd, state_bwd)
+
+
+@modules.Seq2SeqEncoder.register('combo_encoder')
+class ComboEncoder(seq2seq_encoders.PytorchSeq2SeqWrapper):
+    """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
+
+    This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
+    (instead of used Gaussian Dropout and Gaussian Noise).
+    """
+
+    def __init__(self, stacked_bilstm: StackedBiLSTM, layer_dropout_probability: float):
+        super().__init__(stacked_bilstm, stateful=False)
+        self.layer_dropout = input_variational_dropout.InputVariationalDropout(p=layer_dropout_probability)
+
+    @overrides
+    def forward(self,
+                inputs: torch.Tensor,
+                mask: torch.BoolTensor,
+                hidden_state: torch.Tensor = None) -> torch.Tensor:
+        # Variational dropout applied both before and after the wrapped LSTM.
+        # NOTE(review): hidden_state is accepted for API compatibility but
+        # ignored — confirm intended.
+        x = self.layer_dropout(inputs)
+        x = super().forward(x, mask)
+        return self.layer_dropout(x)
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ab4030245ceff6931a7bd13049ce92cc2a0ccb6
--- /dev/null
+++ b/combo/models/lemma.py
@@ -0,0 +1,116 @@
+"""Lemmatizer models."""
+from typing import Optional, Dict, List, Union
+
+import torch
+import torch.nn as nn
+from allennlp import data, nn as allen_nn
+from allennlp.common import checks
+
+from combo.models import base, dilated_cnn, utils
+
+
+@base.Predictor.register('combo_lemma_predictor_from_vocab', constructor='from_vocab')
+class LemmatizerModel(base.Predictor):
+    """Lemmatizer model."""
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
+                 input_projection_layer: base.Linear):
+        super().__init__()
+        self.char_embed = nn.Embedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+        )
+        self.dilated_cnn_encoder = dilated_cnn_encoder
+        self.input_projection_layer = input_projection_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        """Predict a character sequence (the lemma) per word.
+
+        `x` is a pair (encoder word embeddings, character ids); each word's
+        characters are scored jointly with the projected word embedding.
+        """
+        encoder_emb, chars = x
+
+        encoder_emb = self.input_projection_layer(encoder_emb)
+        char_embeddings = self.char_embed(chars)
+
+        BATCH_SIZE, SENTENCE_LENGTH, WORD_EMB = encoder_emb.size()
+        _, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
+
+
+        # Repeat each word embedding across its character positions so it can
+        # be concatenated with every character embedding.
+        encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)
+
+        pred = []
+        for i in range(SENTENCE_LENGTH):
+            word_emb = (encoder_emb[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, -1))
+            char_sent_emb = char_embeddings[:, i, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB)
+            # Channels-first for the Conv1d stack.
+            x = torch.cat((char_sent_emb, word_emb), -1).transpose(2, 1)
+            x = self.dilated_cnn_encoder(x)
+            pred.append(x)
+        # (batch, sentence, word_len, char_classes)
+        x = torch.stack(pred, dim=1).transpose(2, 3)
+        output = {
+            'prediction': x.argmax(-1),
+            'probability': x
+        }
+
+        if labels is not None:
+            if mask is None:
+                mask = encoder_emb.new_ones(encoder_emb.size()[:-2])
+            if sample_weights is None:
+                sample_weights = labels.new_ones(BATCH_SIZE)
+            # Broadcast the word mask over character positions.
+            mask = mask.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH).bool()
+            output['loss'] = self._loss(x, labels, mask, sample_weights)
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        # Masked cross-entropy over characters, averaged over valid positions
+        # and weighted per sentence.
+        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
+        pred = pred.reshape(-1, CHAR_CLASSES)
+
+        valid_positions = mask.sum()
+        mask = mask.reshape(-1)
+        true = true.reshape(-1)
+        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   char_vocab_namespace: str,
+                   lemma_vocab_namespace: str,
+                   embedding_dim: int,
+                   input_projection_layer: base.Linear,
+                   filters: List[int],
+                   kernel_size: List[int],
+                   stride: List[int],
+                   padding: List[int],
+                   dilation: List[int],
+                   activations: List[allen_nn.Activation],
+                   ):
+        """Build the lemmatizer; the final CNN filter (lemma-char vocab size)
+        is appended here, hence len(filters) + 1 == len(kernel_size)."""
+        assert char_vocab_namespace in vocab.get_namespaces()
+        assert lemma_vocab_namespace in vocab.get_namespaces()
+
+        if len(filters) + 1 != len(kernel_size):
+            # NOTE(review): message says "kernel_size" where it means
+            # "len(kernel_size)" — cosmetic, but slightly misleading.
+            raise checks.ConfigurationError(
+                "len(filters) (%d) + 1 != kernel_size (%d)" % (len(filters), len(kernel_size))
+            )
+        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]
+
+        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
+            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
+            filters=filters,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            activations=activations,
+        )
+        return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
+                   embedding_dim=embedding_dim,
+                   dilated_cnn_encoder=dilated_cnn_encoder,
+                   input_projection_layer=input_projection_layer)
diff --git a/combo/models/model.py b/combo/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4a6249dd81b3b5d56688ce61dc7028755929ee6
--- /dev/null
+++ b/combo/models/model.py
@@ -0,0 +1,195 @@
+"""Main COMBO model."""
+from typing import Optional, Dict, Any, List
+
+import torch
+from allennlp import data, modules, models as allen_models, nn as allen_nn
+from allennlp.modules import text_field_embedders
+from allennlp.nn import util
+from overrides import overrides
+
+from combo.models import base
+from combo.utils import metrics
+
+
+@allen_models.Model.register('semantic_multitask')
+class SemanticMultitaskModel(allen_models.Model):
+    """Main COMBO model.
+
+    Shared text-field embedder + Seq2Seq encoder feeding a set of optional
+    task heads (lemma, upos/xpos, feats, semrel, head/deprel); the total loss
+    is a weighted sum of the per-task losses.
+    """
+
+    def __init__(self,
+                 vocab: data.Vocabulary,
+                 loss_weights: Dict[str, float],
+                 text_field_embedder: text_field_embedders.TextFieldEmbedder,
+                 seq_encoder: modules.Seq2SeqEncoder,
+                 use_sample_weight: bool = True,
+                 lemmatizer: Optional[base.Predictor] = None,
+                 upos_tagger: Optional[base.Predictor] = None,
+                 xpos_tagger: Optional[base.Predictor] = None,
+                 semantic_relation: Optional[base.Predictor] = None,
+                 morphological_feat: Optional[base.Predictor] = None,
+                 dependency_relation: Optional[base.Predictor] = None,
+                 regularizer: allen_nn.RegularizerApplicator = None) -> None:
+        super().__init__(vocab, regularizer)
+        self.text_field_embedder = text_field_embedder
+        self.loss_weights = loss_weights
+        self.use_sample_weight = use_sample_weight
+        self.seq_encoder = seq_encoder
+        self.lemmatizer = lemmatizer
+        self.upos_tagger = upos_tagger
+        self.xpos_tagger = xpos_tagger
+        self.semantic_relation = semantic_relation
+        self.morphological_feat = morphological_feat
+        self.dependency_relation = dependency_relation
+        # Learnt embedding for the artificial ROOT token used by the parser.
+        self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
+        self.scores = metrics.SemanticMetrics()
+        # Last computed per-task losses, exposed through get_metrics().
+        self._partial_losses = None
+
+    @overrides
+    def forward(self,
+                sentence: Dict[str, Dict[str, torch.Tensor]],
+                metadata: List[Dict[str, Any]],
+                upostag: torch.Tensor = None,
+                xpostag: torch.Tensor = None,
+                lemma: Dict[str, Dict[str, torch.Tensor]] = None,
+                feats: torch.Tensor = None,
+                head: torch.Tensor = None,
+                deprel: torch.Tensor = None,
+                semrel: torch.Tensor = None,) -> Dict[str, torch.Tensor]:
+
+        # Prepare masks
+        char_mask: torch.BoolTensor = sentence['char']['token_characters'] > 0
+        word_mask = util.get_text_field_mask(sentence)
+
+        # If enabled weight samples loss by log(sentence_length)
+        sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
+
+        encoder_input = self.text_field_embedder(sentence, char_mask=char_mask)
+        encoder_emb = self.seq_encoder(encoder_input, word_mask)
+
+        batch_size, _, encoding_dim = encoder_emb.size()
+
+        # Concatenate the head sentinel (ROOT) onto the sentence representation.
+        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
+        encoder_emb = torch.cat([head_sentinel, encoder_emb], 1)
+        word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1)
+
+        # Tagging heads see the sentence WITHOUT the ROOT sentinel ([:, 1:]);
+        # only the dependency parser receives the full (ROOT-prefixed) input.
+        upos_output = self._optional(self.upos_tagger,
+                                     encoder_emb[:, 1:],
+                                     mask=word_mask[:, 1:],
+                                     labels=upostag,
+                                     sample_weights=sample_weights)
+        xpos_output = self._optional(self.xpos_tagger,
+                                     encoder_emb[:, 1:],
+                                     mask=word_mask[:, 1:],
+                                     labels=xpostag,
+                                     sample_weights=sample_weights)
+        semrel_output = self._optional(self.semantic_relation,
+                                       encoder_emb[:, 1:],
+                                       mask=word_mask[:, 1:],
+                                       labels=semrel,
+                                       sample_weights=sample_weights)
+        morpho_output = self._optional(self.morphological_feat,
+                                       encoder_emb[:, 1:],
+                                       mask=word_mask[:, 1:],
+                                       labels=feats,
+                                       sample_weights=sample_weights)
+        lemma_output = self._optional(self.lemmatizer,
+                                      (encoder_emb[:, 1:], sentence.get('char').get('token_characters')
+                                       if sentence.get('char') else None),
+                                      mask=word_mask[:, 1:],
+                                      labels=lemma.get('char').get('token_characters') if lemma else None,
+                                      sample_weights=sample_weights)
+        parser_output = self._optional(self.dependency_relation,
+                                       encoder_emb,
+                                       returns_tuple=True,
+                                       mask=word_mask,
+                                       labels=(deprel, head),
+                                       sample_weights=sample_weights)
+        relations_pred, head_pred = parser_output['prediction']
+        output = {
+            'upostag': upos_output['prediction'],
+            'xpostag': xpos_output['prediction'],
+            'semrel': semrel_output['prediction'],
+            'feats': morpho_output['prediction'],
+            'lemma': lemma_output['prediction'],
+            'head': head_pred,
+            'deprel': relations_pred,
+            'sentence_embedding': torch.max(encoder_emb[:, 1:], dim=1)[0],
+        }
+
+        if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]):
+
+            # Feats mapping
+            if self.morphological_feat:
+                # Collapse multi-hot feats to one gold index per category slice.
+                mapped_gold_labels = []
+                for cat, cat_indices in self.morphological_feat.slices.items():
+                    mapped_gold_labels.append(feats[:, :, cat_indices].argmax(dim=-1))
+
+                feats = torch.stack(mapped_gold_labels, dim=-1)
+
+            labels = {
+                'upostag': upostag,
+                'xpostag': xpostag,
+                'semrel': semrel,
+                'feats': feats,
+                'lemma': lemma.get('char').get('token_characters') if lemma else None,
+                'head': head,
+                'deprel': deprel,
+            }
+            self.scores(output,  labels, word_mask[:, 1:])
+            relations_loss, head_loss = parser_output['loss']
+            losses = {
+                'upostag_loss': upos_output['loss'],
+                'xpostag_loss': xpos_output['loss'],
+                'semrel_loss': semrel_output['loss'],
+                'feats_loss': morpho_output['loss'],
+                'lemma_loss': lemma_output['loss'],
+                'head_loss': head_loss,
+                'deprel_loss': relations_loss,
+                # Cycle loss is only for the metrics purposes.
+                'cycle_loss': parser_output.get('cycle_loss')
+            }
+            self._partial_losses = losses.copy()
+            losses['loss'] = self._calculate_loss(losses)
+            output.update(losses)
+
+        return self._clean(output)
+
+    @staticmethod
+    def _has_labels(labels):
+        # True when at least one gold annotation is present (training/eval).
+        return any(x is not None for x in labels)
+
+    def _calculate_loss(self, output):
+        # Weighted sum of the per-task losses present in `output`.
+        # NOTE(review): the truthiness test also skips a loss whose tensor
+        # value is exactly 0.0, not just missing/None entries — confirm
+        # whether `is not None` was intended.
+        losses = []
+        for name, value in self.loss_weights.items():
+            if output.get(f'{name}_loss'):
+                losses.append(output[f'{name}_loss'] * value)
+        return torch.stack(losses).sum()
+
+    @staticmethod
+    def _optional(callable_model: Optional[torch.nn.Module],
+                  *args,
+                  returns_tuple: bool = False,
+                  **kwargs):
+        # Run the head if configured; otherwise return a placeholder result
+        # whose shape matches what callers destructure.
+        if callable_model:
+            return callable_model(*args, **kwargs)
+        else:
+            if returns_tuple:
+                return {'prediction': (None, None), 'loss': (None, None)}
+            else:
+                return {'prediction': None, 'loss': None}
+
+    @staticmethod
+    def _clean(output):
+        # Drop None entries (heads that are not configured).
+        for k, v in dict(output).items():
+            if v is None:
+                del output[k]
+        return output
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        # Task metrics plus the most recent per-task losses as scalars.
+        metrics = self.scores.get_metric(reset)
+        if self._partial_losses:
+            losses = self._clean(self._partial_losses)
+            losses = {f"partial_loss/{k}": v.detach().item() for k, v in losses.items()}
+            metrics.update(losses)
+        return metrics
diff --git a/combo/models/morpho.py b/combo/models/morpho.py
new file mode 100644
index 0000000000000000000000000000000000000000..238d4284ce97d5508ccb6ef663a25c317efc77bf
--- /dev/null
+++ b/combo/models/morpho.py
@@ -0,0 +1,102 @@
+"""Morphological features models."""
+from typing import Dict, List, Optional, Union
+
+import torch
+from allennlp import data
+from allennlp.common import checks
+from allennlp.modules import feedforward
+from allennlp.nn import Activation
+
+from combo.data import dataset
+from combo.models import base, utils
+
+
@base.Predictor.register('combo_morpho_from_vocab', constructor='from_vocab')
class MorphologicalFeatures(base.Predictor):
    """Morphological features predicting model.

    Projects token embeddings through a feedforward network whose output is a
    concatenation of per-category one-hot slices (one slice per morphological
    category, e.g. Case, Number); each slice is predicted independently.
    """

    def __init__(self, feedforward_network: feedforward.FeedForward, slices: Dict[str, List[int]]):
        super().__init__()
        self.feedforward_network = feedforward_network
        # Mapping: category name -> indices of that category's slice in the
        # flattened output vector of the feedforward network.
        self.slices = slices

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Predict one value per morphological category for every token.

        Returns a dict with 'prediction' (per-category argmax, stacked on the
        last dim), 'probability' (raw feedforward scores) and, when ``labels``
        are given, 'loss'.
        """
        # assumes x is (batch, sequence, emb_dim) — TODO confirm against caller
        if mask is None:
            mask = x.new_ones(x.size()[:-1])

        x = self.feedforward_network(x)

        # One independent argmax per category slice.
        prediction = []
        for cat, cat_indices in self.slices.items():
            prediction.append(x[:, :, cat_indices].argmax(dim=-1))

        output = {
            'prediction': torch.stack(prediction, dim=-1),
            'probability': x
        }

        if labels is not None:
            if sample_weights is None:
                # Uniform per-sentence weights when none are supplied.
                sample_weights = labels.new_ones([mask.size(0)])
            output['loss'] = self._loss(x, labels, mask, sample_weights)

        return output

    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        """Sum of per-category masked cross-entropy losses, sample-weighted
        and normalized by the number of valid (unmasked) positions."""
        assert pred.size() == true.size()
        BATCH_SIZE, SENTENCE_LENGTH, MORPHOLOGICAL_FEATURES = pred.size()

        valid_positions = mask.sum()

        # Flatten (batch, sentence) so each row is a single token.
        pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
        true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
        mask = mask.reshape(-1)
        loss = None
        loss_func = utils.masked_cross_entropy
        # Skip auxiliary padding/underscore categories; the gold label for a
        # category slice is the argmax of its one-hot segment.
        for cat, cat_indices in self.slices.items():
            if cat not in ['__PAD__', '_']:
                if loss is None:
                    loss = loss_func(pred[:, cat_indices],
                                     true[:, cat_indices].argmax(dim=1),
                                     mask) * mask
                else:
                    loss += loss_func(pred[:, cat_indices],
                                      true[:, cat_indices].argmax(dim=1),
                                      mask) * mask
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   input_dim: int,
                   num_layers: int,
                   hidden_dims: List[int],
                   activations: Union[Activation, List[Activation]],
                   dropout: Union[float, List[float]] = 0.0,
                   ):
        """Alternate constructor: sizes the final layer from the vocabulary.

        ``hidden_dims`` must have ``num_layers - 1`` entries; the last layer's
        width is taken from the vocabulary namespace size.
        """
        if len(hidden_dims) + 1 != num_layers:
            raise checks.ConfigurationError(
                "len(hidden_dims) (%d) + 1 != num_layers (%d)" % (len(hidden_dims), num_layers)
            )

        assert vocab_namespace in vocab.get_namespaces()
        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]

        # Category slices are derived from (and cached on) the vocabulary.
        slices = dataset.get_slices_if_not_provided(vocab)

        return cls(
            feedforward_network=feedforward.FeedForward(
                input_dim=input_dim,
                num_layers=num_layers,
                hidden_dims=hidden_dims,
                activations=activations,
                dropout=dropout),
            slices=slices
        )
diff --git a/combo/models/parser.py b/combo/models/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f4c8af6641c9aca49ca683bbe257bd3a0b1b6f
--- /dev/null
+++ b/combo/models/parser.py
@@ -0,0 +1,187 @@
+"""Dependency parsing models."""
+from typing import Tuple, Dict, Optional, Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from allennlp import data
+from allennlp.nn import chu_liu_edmonds
+
+from combo.models import base, utils
+
+
class HeadPredictionModel(base.Predictor):
    """Head (arc) prediction model for dependency parsing.

    Scores every (dependent, head) token pair with a bilinear product of head
    and dependent projections; at inference time a well-formed tree is decoded
    with the Chu-Liu-Edmonds maximum spanning tree algorithm.
    """

    def __init__(self,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 cycle_loss_n: int = 0):
        super().__init__()
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        # Number of matrix-power terms in the optional cycle penalty (0 disables it).
        self.cycle_loss_n = cycle_loss_n

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Score arcs and optionally compute the head-prediction loss.

        Returns a dict with 'prediction' (heads, ROOT row dropped),
        'probability' (raw pairwise scores) and, when ``labels`` are given,
        'loss' and 'cycle_loss'.
        """
        if mask is None:
            # BUG FIX: was `x.new_ones(x.size()[-1])`, which produced a 1-D mask
            # of length equal to the embedding dimension. The eval branch below
            # (`mask.data.sum(dim=1)`) and the sibling models require a
            # (batch, sentence_length) mask, i.e. `x.size()[:-1]`.
            mask = x.new_ones(x.size()[:-1])

        head_arc_emb = self.head_projection_layer(x)
        dep_arc_emb = self.dependency_projection_layer(x)
        # Pairwise arc scores: entry (b, i, j) scores token i choosing head j.
        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))

        if self.training:
            # Greedy argmax decoding is sufficient during training.
            pred = x.argmax(-1)
        else:
            pred = []
            # Adding non existing in mask ROOT to lengths
            lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
            for idx, length in enumerate(lengths):
                probs = x[idx, :].softmax(dim=-1).cpu().numpy()
                # Forbid ROOT (index 0) from being anyone's dependent target
                # before MST decoding.
                probs[:, 0] = 0
                heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False)
                heads[0] = 0
                pred.append(heads)
            pred = torch.from_numpy(np.stack(pred)).to(x.device)

        output = {
            'prediction': pred[:, 1:],
            'probability': x
        }

        if labels is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            output['loss'], output['cycle_loss'] = self._loss(x, labels, mask, sample_weights)

        return output

    def _cycle_loss(self, pred: torch.Tensor):
        """Differentiable cycle penalty: traces of powers of the score matrix."""
        BATCH_SIZE, _, _ = pred.size()
        loss = pred.new_zeros(BATCH_SIZE)
        # 1: as using non __ROOT__ tokens
        yn = pred[:, 1:, 1:]
        for i in range(self.cycle_loss_n):
            loss += self._batch_trace(yn) / BATCH_SIZE
            yn = yn.bmm(pred[:, 1:, 1:])

        return loss

    def _batch_trace(self, x: torch.Tensor) -> torch.Tensor:
        """Sum of matrix traces over a batch of square matrices."""
        assert len(x.size()) == 3
        BATCH_SIZE, N, M = x.size()
        assert N == M
        # Build the identity directly on x's device/dtype; the previous
        # `x.new_tensor(torch.eye(N))` made an extra copy (and warns on
        # newer PyTorch versions).
        identity = torch.eye(N, device=x.device, dtype=x.dtype)
        identity = identity.reshape((1, N, N))
        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
        return (x * batch_identity).sum()

    def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Masked cross-entropy over heads plus the mean cycle penalty.

        Returns (total_loss, cycle_loss); cycle_loss is also reported alone
        for metric purposes by the caller.
        """
        BATCH_SIZE, N, M = pred.size()
        assert N == M
        SENTENCE_LENGTH = N

        valid_positions = mask.sum()

        result = []
        # Ignore first pred dimension as it is ROOT token prediction
        for i in range(SENTENCE_LENGTH - 1):
            pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
            true_i = true[:, i].reshape(-1)
            mask_i = mask[:, i]
            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i) * mask_i
            result.append(cross_entropy_loss)
        cycle_loss = self._cycle_loss(pred)
        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
+
+
@base.Predictor.register('combo_dependency_parsing_from_vocab', constructor='from_vocab')
class DependencyRelationModel(base.Predictor):
    """Dependency relation (label) prediction model, layered on head prediction.

    First predicts heads with ``head_predictor``; the soft head distribution is
    then used to mix head embeddings, which are concatenated with the dependent
    embedding and classified into dependency relations.
    """

    def __init__(self,
                 head_predictor: HeadPredictionModel,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 relation_prediction_layer: base.Linear):
        super().__init__()
        self.head_predictor = head_predictor
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        self.relation_prediction_layer = relation_prediction_layer

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """Predict (relations, heads); ``labels`` is a (relations, heads) pair.

        Returns the head predictor's output with 'prediction' replaced by a
        (relation_prediction, head_prediction) tuple and, when labels are
        given, 'loss' replaced by a (relation_loss, head_loss) tuple.
        """
        if mask is not None:
            # Drop the ROOT position from the mask.
            mask = mask[:, 1:]
        relations_labels, head_labels = None, None
        if labels is not None and labels[0] is not None:
            relations_labels, head_labels = labels
            if mask is None:
                mask = head_labels.new_ones(head_labels.size())

        head_output = self.head_predictor(x, mask, head_labels, sample_weights)
        head_pred = head_output['probability']
        head_pred_soft = F.softmax(head_pred, dim=-1)

        head_rel_emb = self.head_projection_layer(x)

        dep_rel_emb = self.dependency_projection_layer(x)

        # Soft mixture of head embeddings, concatenated with the dependent's.
        dep_rel_pred = head_pred_soft.bmm(head_rel_emb)
        dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1)
        relation_prediction = self.relation_prediction_layer(dep_rel_pred)
        output = head_output

        output['prediction'] = (relation_prediction.argmax(-1)[:, 1:], head_output['prediction'])

        if labels is not None and labels[0] is not None:
            if sample_weights is None:
                # BUG FIX: `labels` is a (relations, heads) sequence here, which
                # has no `new_ones`; default weights must be built from one of
                # the label tensors instead.
                sample_weights = relations_labels.new_ones([mask.size(0)])
            loss = self._loss(relation_prediction[:, 1:], relations_labels, mask, sample_weights)
            output['loss'] = (loss, head_output['loss'])

        return output

    def _loss(self,
              pred: torch.Tensor,
              true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        """Masked, sample-weighted cross-entropy over dependency relations."""

        valid_positions = mask.sum()

        BATCH_SIZE, _, DEPENDENCY_RELATIONS = pred.size()
        pred = pred.reshape(-1, DEPENDENCY_RELATIONS)
        true = true.reshape(-1)
        mask = mask.reshape(-1)
        loss = utils.masked_cross_entropy(pred, true, mask) * mask
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   head_predictor: HeadPredictionModel,
                   head_projection_layer: base.Linear,
                   dependency_projection_layer: base.Linear
                   ):
        """Alternate constructor: sizes the relation classifier from the vocab."""
        assert vocab_namespace in vocab.get_namespaces()
        relation_prediction_layer = base.Linear(
            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
            out_features=vocab.get_vocab_size(vocab_namespace)
        )
        return cls(
            head_predictor=head_predictor,
            head_projection_layer=head_projection_layer,
            dependency_projection_layer=dependency_projection_layer,
            relation_prediction_layer=relation_prediction_layer
        )
diff --git a/combo/models/utils.py b/combo/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f97913303415552f6c594ac857b563cab52b6c3b
--- /dev/null
+++ b/combo/models/utils.py
@@ -0,0 +1,8 @@
+import torch
+import torch.nn.functional as F
+
+
def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """Per-element cross-entropy with a log-mask added to the logits.

    NOTE(review): the additive term shifts every logit of a masked-out row by
    the same constant, so callers additionally multiply the result by the mask
    to zero masked positions — confirm this matches the intended masking.
    """
    log_mask = torch.log(mask.float().unsqueeze(-1) + 1e-45)
    shifted_logits = pred + log_mask
    return F.cross_entropy(shifted_logits, true, reduction='none')
diff --git a/combo/predict.py b/combo/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4f8326702724efcdcdefb629d567353606b976
--- /dev/null
+++ b/combo/predict.py
@@ -0,0 +1,141 @@
+import collections
+import logging
+import time
+from typing import List
+
+import conllu
+from allennlp import data as allen_data, common, models
+from allennlp.common import util
+from allennlp.data import tokenizers
+from allennlp.predictors import predictor
+from overrides import overrides
+
+logger = logging.getLogger(__name__)
+
+
@predictor.Predictor.register('semantic-multitask-predictor')
@predictor.Predictor.register('semantic-multitask-predictor-spacy', constructor='with_spacy_tokenizer')
class SemanticMultitaskPredictor(predictor.Predictor):
    """Predictor turning raw sentences into CoNLL-U dependency trees."""

    def __init__(self,
                 model: models.Model,
                 dataset_reader: allen_data.DatasetReader,
                 tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
        super().__init__(model, dataset_reader)
        self.vocab = model.vocab
        # Prediction-time instances carry no gold labels.
        self._dataset_reader.generate_labels = False
        self._tokenizer = tokenizer

    @overrides
    def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
        """Tokenize the 'sentence' field and wrap it as a dataset instance."""
        tokens = self._tokenizer.tokenize(json_dict['sentence'])
        tree = self._sentence_to_tree([t.text for t in tokens])
        return self._dataset_reader.text_to_instance(tree)

    @overrides
    def load_line(self, line: str) -> common.JsonDict:
        """One input line = one sentence; newlines are flattened to spaces."""
        return {'sentence': line.replace("\n", " ").strip()}

    @overrides
    def dump_line(self, outputs: common.JsonDict) -> str:
        # Check whether serialized (str) tree or token's list
        # Serialized tree has already separators between lines
        if isinstance(outputs['tree'], str):
            return str(outputs['tree'])
        return str(outputs['tree']) + "\n"

    @overrides
    def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
        """Predict a tree for one instance and return it serialized."""
        start_time = time.time()
        tree = self.predict_instance_as_tree(instance)
        tree_json = util.sanitize(tree.serialize())
        result = collections.OrderedDict([
            ('tree', tree_json),
        ])
        end_time = time.time()
        logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
        return result

    @overrides
    def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
        """Predict a tree for a raw JSON input (unserialized tree in result)."""
        start_time = time.time()
        instance = self._json_to_instance(inputs)
        tree = self.predict_instance_as_tree(instance)
        tree_json = util.sanitize(tree)
        result = collections.OrderedDict([
            ('tree', tree_json),
        ])
        end_time = time.time()
        logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
        return result

    def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
        """Run the model on ``instance`` and map predictions back onto a tree."""
        predictions = super().predict_instance(instance)
        return self._predictions_as_tree(predictions, instance)

    @staticmethod
    def _sentence_to_tree(sentence: List[str]):
        # NOTE(review): ids start at 0 here while CoNLL-U ids are conventionally
        # 1-based — confirm the dataset reader expects 0-based input ids.
        d = collections.OrderedDict
        return conllu.TokenList(
            [d({'id': idx, 'token': token}) for
             idx, token
             in enumerate(sentence)]
        )

    def _predictions_as_tree(self, predictions, instance):
        """Write model predictions back into the instance's CoNLL-U tree.

        Handles each annotation field by its type: categorical labels are
        mapped through the vocabulary, heads are cast to int, morphological
        features are decoded per-category, lemmas are decoded char-by-char.
        """
        tree = instance.fields['metadata']['input']
        field_names = instance.fields['metadata']['field_names']
        for idx, token in enumerate(tree):
            for field_name in field_names:
                if field_name in predictions:
                    if field_name in ['xpostag', 'upostag', 'semrel', 'deprel']:
                        value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + '_labels')
                        token[field_name] = value
                    elif field_name in ['head']:
                        token[field_name] = int(predictions[field_name][idx])
                    elif field_name in ['feats']:
                        slices = self._model.morphological_feat.slices
                        features = []
                        prediction = predictions[field_name][idx]
                        for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
                            if cat not in ['__PAD__', '_']:
                                value = self.vocab.get_token_from_index(cat_indices[pred_idx],
                                                                        field_name + '_labels')
                                # Exclude auxiliary values
                                if '=None' not in value:
                                    features.append(value)
                        # Empty feature set is spelled '_' in CoNLL-U.
                        if not features:
                            field_value = '_'
                        else:
                            field_value = '|'.join(sorted(features))

                        token[field_name] = field_value
                    elif field_name == 'lemma':
                        prediction = predictions[field_name][idx]
                        word_chars = []
                        # Skip BOS/EOS positions; stop at __END__, skip padding,
                        # and map any remaining special symbols to '?'.
                        for char_idx in prediction[1:-1]:
                            pred_char = self.vocab.get_token_from_index(char_idx, 'lemma_characters')

                            if pred_char == '__END__':
                                break
                            elif pred_char == '__PAD__':
                                continue
                            elif '_' in pred_char:
                                pred_char = '?'

                            word_chars.append(pred_char)
                        token[field_name] = ''.join(word_chars)
                    else:
                        raise NotImplementedError(f'Unknown field name {field_name}!')

        # NOTE(review): reaches into the reader's private `_targets` — consider
        # exposing it publicly on the dataset reader.
        if self._dataset_reader and 'sent' in self._dataset_reader._targets:
            tree.metadata = {'sentence_embedding': str(predictions['sentence_embedding'])}
        return tree

    @classmethod
    def with_spacy_tokenizer(cls, model: models.Model,
                             dataset_reader: allen_data.DatasetReader):
        """Alternate constructor using the spaCy tokenizer."""
        return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
diff --git a/combo/training/__init__.py b/combo/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9980d5fda85246e3f8b38b434257defa165413be
--- /dev/null
+++ b/combo/training/__init__.py
@@ -0,0 +1,3 @@
+"""Training tools."""
+from .scheduler import Scheduler
+from .trainer import GradientDescentTrainer
diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8752d739029f16289db94ee493d00bbd92158ce7
--- /dev/null
+++ b/combo/training/scheduler.py
@@ -0,0 +1,40 @@
+import torch.optim.lr_scheduler as lr_scheduler
+from allennlp.training.learning_rate_schedulers import learning_rate_scheduler
+from overrides import overrides
+
+
@learning_rate_scheduler.LearningRateScheduler.register("combo_scheduler")
class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper):
    """LR scheduler combining step-wise decay with metric-driven patience.

    The per-step LR factor decays as 1/(1 + step * 1e-4). When the validation
    metric stops improving by more than ``threshold`` for ``patience`` epochs,
    the base LR and the threshold are halved, up to ``decreases`` times; after
    that, patience reaching 0 signals the trainer to stop early.
    """

    def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3):
        # Pass the callable directly instead of a one-element list: LambdaLR
        # applies a single function to every param group, whereas a list must
        # match the number of groups exactly and fails for multi-group
        # optimizers.
        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=self._lr_lambda))
        self.threshold = threshold
        self.decreases = decreases
        self.patience = patience
        self.start_patience = patience
        self.best_score = 0.0

    @staticmethod
    def _lr_lambda(idx: int) -> float:
        """Multiplicative LR factor for training step ``idx``."""
        return 1.0 / (1.0 + idx * 1e-4)

    @overrides
    def step(self, metric: float = None) -> None:
        """Advance the step-wise schedule; update patience bookkeeping when a
        validation ``metric`` is provided."""
        self.lr_scheduler.step()

        if metric is not None:
            if metric - self.best_score > self.threshold:
                # Improvement: record the new best and reset patience.
                self.best_score = metric if metric > self.best_score else self.best_score
                self.patience = self.start_patience
            else:
                if self.patience <= 1:
                    if self.decreases == 0:
                        # This is condition for Trainer to trigger early stopping
                        self.patience = 0
                    else:
                        # Halve LR and tighten the improvement threshold.
                        self.patience = self.start_patience
                        self.decreases -= 1
                        self.threshold /= 2
                        self.lr_scheduler.base_lrs = [x / 2 for x in self.lr_scheduler.base_lrs]
                else:
                    self.patience -= 1
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..01773c394199c097214e9887e872e77580cfad54
--- /dev/null
+++ b/combo/training/trainer.py
@@ -0,0 +1,213 @@
+import datetime
+import logging
+import os
+import time
+import traceback
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+import torch.optim as optim
+import torch.optim.lr_scheduler
+import torch.utils.data as data
+from allennlp import training
+from allennlp.common import checks
+from allennlp.common import util as common_util
+from allennlp.models import model
+from allennlp.training import checkpointer
+from allennlp.training import learning_rate_schedulers
+from allennlp.training import momentum_schedulers
+from allennlp.training import moving_average
+from allennlp.training import tensorboard_writer
+from allennlp.training import util as training_util
+from overrides import overrides
+
+logger = logging.getLogger(__name__)
+
+
@training.EpochCallback.register('transfer_patience')
class TransferPatienceEpochCallback(training.EpochCallback):
    """Copy the LR scheduler's patience onto the trainer's metric tracker."""

    def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int) -> None:
        scheduler = trainer._learning_rate_scheduler
        if not (scheduler and scheduler.patience is not None):
            raise checks.ConfigurationError("Learning rate scheduler isn't properly setup!")
        trainer._metric_tracker._patience = scheduler.patience
        trainer._metric_tracker._epochs_with_no_improvement = 0
+
+
@training.Trainer.register("gradient_descent_validate_n", constructor="from_partial_objects")
class GradientDescentTrainer(training.GradientDescentTrainer):
    """AllenNLP gradient-descent trainer that validates only every n-th epoch.

    The only behavioural difference from the stock trainer visible here is in
    ``train``: the validation loop runs only when ``epoch % validate_every_n
    == 0``; constructor arguments are forwarded verbatim to the parent.
    """

    def __init__(self, model: model.Model, optimizer: optim.Optimizer, data_loader: data.DataLoader,
                 patience: Optional[int] = None, validation_metric: str = "-loss",
                 validation_data_loader: data.DataLoader = None, num_epochs: int = 20,
                 serialization_dir: Optional[str] = None, checkpointer: checkpointer.Checkpointer = None,
                 cuda_device: int = -1,
                 grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None,
                 learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None,
                 momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None,
                 tensorboard_writer: tensorboard_writer.TensorboardWriter = None,
                 moving_average: Optional[moving_average.MovingAverage] = None,
                 batch_callbacks: List[training.BatchCallback] = None,
                 epoch_callbacks: List[training.EpochCallback] = None, distributed: bool = False, local_rank: int = 0,
                 world_size: int = 1, num_gradient_accumulation_steps: int = 1,
                 opt_level: Optional[str] = None) -> None:

        super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs,
                         serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
                         learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
                         batch_callbacks, epoch_callbacks, distributed, local_rank, world_size,
                         num_gradient_accumulation_steps, opt_level)
        # TODO extract param to constructor (+ constructor method?)
        # Validation runs only on epochs divisible by this value.
        self.validate_every_n = 5

    @overrides
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise checks.ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        val_metrics: Dict[str, float] = {}
        # None until the first epoch that actually runs validation.
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        # Epoch -1 callback lets e.g. TransferPatienceEpochCallback initialise
        # state before the first real epoch.
        for callback in self._epoch_callbacks:
            callback(self, metrics={}, epoch=-1)

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
                )
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            if self._validation_data_loader is not None:
                val_metrics = {}
                this_epoch_val_metric = None
                # Key deviation from the stock trainer: validate only every
                # `validate_every_n` epochs.
                if epoch % self.validate_every_n == 0:
                    with torch.no_grad():
                        # We have a validation set, so compute all the metrics on it.
                        val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

                        # It is safe again to wait till the validation is done. This is
                        # important to get the metrics right.
                        if self._distributed:
                            dist.barrier()

                        val_metrics = training_util.get_metrics(
                            self.model,
                            val_loss,
                            val_reg_loss,
                            num_batches,
                            reset=True,
                            world_size=self._world_size,
                            cuda_device=[self.cuda_device],
                        )

                        # Check validation metric for early stopping
                        this_epoch_val_metric = val_metrics[self._validation_metric]
                        # self._metric_tracker.add_metric(this_epoch_val_metric)

                train_metrics['patience'] = self._metric_tracker._patience
                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience.  Stopping training.")
                    break

            if self._master:
                self._tensorboard.log_metrics(
                    train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
                )  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and self._master:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            if self._master:
                self._checkpointer.save_checkpoint(
                    epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                )

            # Wait for the master to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            for callback in self._epoch_callbacks:
                callback(self, metrics=metrics, epoch=epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * (
                        (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
                )
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/combo/utils/checks.py b/combo/utils/checks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d322c15192c6e82647cb51fba8ca9b2d78e5ee3a
--- /dev/null
+++ b/combo/utils/checks.py
@@ -0,0 +1,20 @@
+import os
+
+import torch
+from allennlp.common import checks as allen_checks
+
+
def file_exists(*paths):
    """Validate that every given path is non-None and exists on disk.

    Parameters are checked in order; the first offending path raises.

    Raises:
        allen_checks.ConfigurationError: if any path is ``None`` or does
            not exist on the filesystem.
    """
    for path in paths:
        if path is None:
            # Plain string literal: the original used an f-string with no
            # placeholders (pylint W1309, f-string-without-interpolation).
            raise allen_checks.ConfigurationError('File cannot be None')
        if not os.path.exists(path):
            raise allen_checks.ConfigurationError(f'Could not find the file at path: `{path}`')
+
+
def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str):
    """Raise a ConfigurationError unless the two tensor sizes are equal.

    The tensor names are only used to build a readable error message.
    """
    if size_1 == size_2:
        return
    message = (
        f"{tensor_1_name} must match {tensor_2_name}, but got {size_1} "
        f"and {size_2} instead"
    )
    raise allen_checks.ConfigurationError(message)
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc142ff244d081f63f7ae54fe203d6c8a785ffe0
--- /dev/null
+++ b/combo/utils/metrics.py
@@ -0,0 +1,246 @@
+from typing import Optional, List, Dict
+
+import torch
+from allennlp.training import metrics
+from overrides import overrides
+
+
class SequenceBoolAccuracy(metrics.Metric):
    """Element-wise equality accuracy between predictions and gold labels.

    With ``prod_last_dim=True`` a position only counts as correct when
    *every* element along the last dimension matches (used for multi-label
    targets such as FEATS or lemma characters).
    """

    def __init__(self, prod_last_dim: bool = False):
        # Running totals, accumulated across calls until reset().
        self._correct_count = 0.0
        self._total_count = 0.0
        self.prod_last_dim = prod_last_dim
        # Per-element correctness flags of the most recent call, flattened;
        # read externally (e.g. for exact-match computation).
        self.correct_indices = torch.ones([])

    @overrides
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.BoolTensor] = None):
        # Silently skip updates when this target is absent from the batch.
        if gold_labels is None:
            return
        predictions, gold_labels, mask = self.detach_tensors(
            predictions, gold_labels, mask)

        # Sanity checks (error messages match the original contract).
        if gold_labels.size() != predictions.size():
            raise ValueError(
                f"gold_labels must have shape == predictions.size() but "
                f"found tensor of shape: {gold_labels.size()}"
            )
        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
            raise ValueError(
                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
                f"found tensor of shape: {mask.size()}"
            )
        if mask is None:
            # Default: everything but the trailing dimension is unmasked.
            mask = predictions.new_ones(predictions.size()[:-1]).bool()
        if mask.dim() < predictions.dim():
            # Broadcast the mask over the trailing dimension.
            mask = mask.unsqueeze(-1)

        matches = predictions.eq(gold_labels) * mask
        if self.prod_last_dim:
            # Collapse the last dimension: all elements must match.
            matches = matches.prod(-1).unsqueeze(-1)
        matches = matches.float()

        self.correct_indices = matches.flatten().bool()
        self._correct_count += matches.sum()
        self._total_count += mask.sum()

    @overrides
    def get_metric(self, reset: bool) -> float:
        """Return accumulated accuracy (0.0 when nothing was counted)."""
        accuracy = 0.0
        if self._total_count > 0:
            accuracy = float(self._correct_count) / float(self._total_count)
        if reset:
            self.reset()
        return accuracy

    @overrides
    def reset(self) -> None:
        """Restore the initial, empty state."""
        self._correct_count = 0.0
        self._total_count = 0.0
        self.correct_indices = torch.ones([])
+
+
class AttachmentScores(metrics.Metric):
    """
    Computes labeled and unlabeled attachment scores for a
    dependency parse, as well as sentence level exact match
    for both labeled and unlabeled trees. Note that the input
    to this metric is the sampled predictions, not the distribution
    itself.

    # Parameters

    ignore_classes : `List[int]`, optional (default = None)
        A list of label ids to ignore when computing metrics.
    """

    def __init__(self, ignore_classes: List[int] = None) -> None:
        # Running counts, accumulated over calls until reset().
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        # Per-token flags (head AND label both correct) from the most recent
        # call; read externally (e.g. by SemanticMetrics) for exact match.
        self.correct_indices = torch.ones([])

        self._ignore_classes: List[int] = ignore_classes or []

    def __call__(  # type: ignore
            self,
            predicted_indices: torch.Tensor,
            predicted_labels: torch.Tensor,
            gold_indices: torch.Tensor,
            gold_labels: torch.Tensor,
            mask: Optional[torch.BoolTensor] = None,
    ):
        """
        # Parameters

        predicted_indices : `torch.Tensor`, required.
            A tensor of head index predictions of shape (batch_size, timesteps).
        predicted_labels : `torch.Tensor`, required.
            A tensor of arc label predictions of shape (batch_size, timesteps).
        gold_indices : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_indices`.
        gold_labels : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_labels`.
        mask : `torch.BoolTensor`, optional (default = None).
            A tensor of the same shape as `predicted_indices`.
        """
        detached = self.detach_tensors(
            predicted_indices, predicted_labels, gold_indices, gold_labels, mask
        )
        predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached

        if mask is None:
            # No mask supplied: score every position.
            mask = torch.ones_like(predicted_indices).bool()

        predicted_indices = predicted_indices.long()
        predicted_labels = predicted_labels.long()
        gold_indices = gold_indices.long()
        gold_labels = gold_labels.long()

        # Multiply by a mask denoting locations of
        # gold labels which we should ignore.
        for label in self._ignore_classes:
            label_mask = gold_labels.eq(label)
            mask = mask & ~label_mask

        # Head attachment correctness per token (0 at masked positions).
        correct_indices = predicted_indices.eq(gold_indices).long() * mask
        # Adding ~mask (bool promoted to long) makes ignored positions count
        # as "correct", so prod(dim=-1) yields sentence-level exact match
        # over the unmasked tokens only.
        unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
        correct_labels = predicted_labels.eq(gold_labels).long() * mask
        correct_labels_and_indices = correct_indices * correct_labels
        self.correct_indices = correct_labels_and_indices.flatten()
        labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)

        self._unlabeled_correct += correct_indices.sum()
        self._exact_unlabeled_correct += unlabeled_exact_match.sum()
        self._labeled_correct += correct_labels_and_indices.sum()
        self._exact_labeled_correct += labeled_exact_match.sum()
        self._total_sentences += correct_indices.size(0)
        # Only unmasked positions count as words.
        self._total_words += correct_indices.numel() - (~mask).sum()

    def get_metric(self, reset: bool = False):
        """
        # Returns

        The accumulated metrics as a dictionary.
        """
        unlabeled_attachment_score = 0.0
        labeled_attachment_score = 0.0
        unlabeled_exact_match = 0.0
        labeled_exact_match = 0.0
        if self._total_words > 0.0:
            unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
            labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
        if self._total_sentences > 0:
            unlabeled_exact_match = float(self._exact_unlabeled_correct) / float(
                self._total_sentences
            )
            labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
        if reset:
            self.reset()
        return {
            "UAS": unlabeled_attachment_score,
            "LAS": labeled_attachment_score,
            "UEM": unlabeled_exact_match,
            "LEM": labeled_exact_match,
        }

    @overrides
    def reset(self):
        # Restore the initial state from __init__.
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        self.correct_indices = torch.ones([])
+
+
class SemanticMetrics(metrics.Metric):
    """Aggregates COMBO task metrics plus a token-level exact-match score.

    Combines per-task accuracies (UPOS, XPOS, semantic relation, FEATS,
    lemma) with labeled/unlabeled attachment scores. ``em_score`` is the
    fraction of tokens for which *every* sub-task prediction is correct
    in the most recent batch.
    """

    def __init__(self) -> None:
        self.upos_score = SequenceBoolAccuracy()
        self.xpos_score = SequenceBoolAccuracy()
        self.semrel_score = SequenceBoolAccuracy()
        # FEATS and lemma are multi-element per token: every element along
        # the last dimension must match for the token to count as correct.
        self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
        self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
        self.attachment_scores = AttachmentScores()
        self.em_score = 0.0

    def __call__(  # type: ignore
            self,
            predictions: Dict[str, torch.Tensor],
            gold_labels: Dict[str, torch.Tensor],
            mask: torch.BoolTensor):
        """Update every sub-metric from one batch of predictions.

        Assumes both dicts contain all task keys ('upostag', 'xpostag',
        'semrel', 'feats', 'lemma', 'head', 'deprel') — TODO confirm that
        disabled targets arrive as None rather than missing keys.
        """
        self.upos_score(predictions['upostag'], gold_labels['upostag'], mask)
        self.xpos_score(predictions['xpostag'], gold_labels['xpostag'], mask)
        self.semrel_score(predictions['semrel'], gold_labels['semrel'], mask)
        self.feats_score(predictions['feats'], gold_labels['feats'], mask)
        self.lemma_score(predictions['lemma'], gold_labels['lemma'], mask)
        self.attachment_scores(predictions['head'],
                               predictions['deprel'],
                               gold_labels['head'],
                               gold_labels['deprel'],
                               mask)
        total = mask.sum()
        # A token counts towards EM only if every sub-task got it right.
        correct_indices = (self.upos_score.correct_indices *
                           self.xpos_score.correct_indices *
                           self.semrel_score.correct_indices *
                           self.feats_score.correct_indices *
                           self.lemma_score.correct_indices *
                           self.attachment_scores.correct_indices
                           )

        total, correct_indices = self.detach_tensors(total, correct_indices)
        self.em_score = (correct_indices.float().sum() / total).item()

    def get_metric(self, reset: bool) -> Dict[str, float]:
        """Return all task metrics flattened into a single dictionary."""
        # Named `metrics_dict` (not `metrics`) to avoid shadowing the
        # `allennlp.training.metrics` module imported at the top of the file.
        metrics_dict = {
            "UPOS_ACC": self.upos_score.get_metric(reset),
            "XPOS_ACC": self.xpos_score.get_metric(reset),
            "SEMREL_ACC": self.semrel_score.get_metric(reset),
            "LEMMA_ACC": self.lemma_score.get_metric(reset),
            "FEATS_ACC": self.feats_score.get_metric(reset),
            "EM": self.em_score
        }
        metrics_dict.update(self.attachment_scores.get_metric(reset))
        return metrics_dict

    def reset(self) -> None:
        """Reset every sub-metric and the exact-match score."""
        self.upos_score.reset()
        self.xpos_score.reset()
        self.semrel_score.reset()
        self.lemma_score.reset()
        self.feats_score.reset()
        self.attachment_scores.reset()
        self.em_score = 0.0
diff --git a/config.template.jsonnet b/config.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..4ec3ec4304c8bceb02a19be702d8c2e6bbfa0087
--- /dev/null
+++ b/config.template.jsonnet
@@ -0,0 +1,369 @@
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
# Must be in CONLL-U format (or its extended version with the semantic relation field).
# Can accept multiple paths when concatenated with ':', "path1:path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
# Can accept multiple paths when concatenated with ':', "path1:path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = 1;
+# Cuda device id, -1 for cpu, int
+local cuda_device = -1;
+# Minimum number of words in batch, int
+local word_batch_size = 1;
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
+# Word embedding dimension, int
# If pretrained_tokens is not null, it must match the provided dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the models on top of the encoder uses this dropout
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 100;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 100;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, int]
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer - truncated
+local word_length = 30;
+
+
+# Helper functions
+local in_features(name) = !(std.length(std.find(name, features)) == 0);
+local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
+
+########################################################################################
+#                              ADVANCED configuration                                  #
+########################################################################################
+
+# Detailed dataset, training, vocabulary and model configuration.
+{
+    # Configuration type (default or finetuning), str
+    type: std.extVar('type'),
+    # Datasets used for vocab creation, list of str
+    # Choice "train", "valid"
+    datasets_for_vocab_creation: ['train'],
+    # Path to training data, str
+    train_data_path: training_data_path,
+    # Path to validation data, str
+    validation_data_path: validation_data_path,
+    # Dataset reader configuration (conllu format)
+    dataset_reader: {
+        type: "conllu",
+        features: features,
+        targets: targets,
+        # Whether data contains semantic relation field, bool
+        use_sem: if in_features("semrel") then true else false,
+        token_indexers: {
+            token: if use_transformer then {
+                type: "pretrained_transformer_mismatched",
+                model_name: pretrained_transformer_name,
+            } else {
+                # SingleIdTokenIndexer, token as single int
+                type: "single_id",
+            },
+            upostag: {
+                type: "single_id",
+                namespace: "upostag",
+                feature_name: "pos_",
+            },
+            xpostag: {
+                type: "single_id",
+                namespace: "xpostag",
+                feature_name: "tag_",
+            },
+            lemma: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            char: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+        },
+        lemma_indexers: {
+            char: {
+                type: "characters_const_padding",
+                namespace: "lemma_characters",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+        },
+    },
+    # Data loader configuration
+    data_loader: {
+        batch_sampler: {
+            type: "token_count",
+            word_batch_size: word_batch_size,
+        },
+    },
+    # Vocabulary configuration
+    vocabulary: std.prune({
+        type: "from_instances_extended",
+        only_include_pretrained_words: true,
+        pretrained_files: {
+            tokens: pretrained_tokens,
+        },
+        oov_token: "_",
+        padding_token: "__PAD__",
+        non_padded_namespaces: ["head_labels"],
+    }),
+    model: std.prune({
+        type: "semantic_multitask",
+        text_field_embedder: {
+            type: "basic",
+            token_embedders: {
+                xpostag: if in_features("xpostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: xpostag_dim,
+                    vocab_namespace: "xpostag",
+                },
+                upostag: if in_features("upostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: upostag_dim,
+                    vocab_namespace: "upostag",
+                },
+                token: if use_transformer then {
+                    type: "transformers_word_embeddings",
+                    model_name: pretrained_transformer_name,
+                    projection_dim: projected_embedding_dim,
+                } else {
+                    type: "embeddings_projected",
+                    embedding_dim: embedding_dim,
+                    projection_layer: {
+                        in_features: embedding_dim,
+                        out_features: projected_embedding_dim,
+                        dropout_rate: 0.25,
+                        activation: "tanh"
+                    },
+                    vocab_namespace: "tokens",
+                    pretrained_file: pretrained_tokens,
+                    trainable: if pretrained_tokens == null then true else false,
+                },
+                char: {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: char_dim,
+                        filters: [512, 256, char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                lemma: if in_features("lemma") then {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: lemma_char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: lemma_char_dim,
+                        filters: [512, 256, lemma_char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+            },
+        },
+        loss_weights: loss_weights,
+        seq_encoder: {
+            type: "combo_encoder",
+            layer_dropout_probability: 0.33,
+            stacked_bilstm: {
+                input_size:
+                char_dim + projected_embedding_dim +
+                if in_features('xpostag') then xpostag_dim else 0 +
+                if in_features('lemma') then lemma_char_dim else 0 +
+                if in_features('upostag') then upostag_dim else 0,
+                hidden_size: hidden_size,
+                num_layers: num_layers,
+                recurrent_dropout_probability: 0.33,
+                layer_dropout_probability: 0.33
+            },
+        },
+        dependency_relation: {
+            type: "combo_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        morphological_feat: if in_targets("feats") then {
+            type: "combo_morpho_from_vocab",
+            vocab_namespace: "feats_labels",
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+        },
+        lemmatizer: if in_targets("lemma") then {
+            type: "combo_lemma_predictor_from_vocab",
+            char_vocab_namespace: "token_characters",
+            lemma_vocab_namespace: "lemma_characters",
+            embedding_dim: 256,
+            input_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: 32,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            filters: [256, 256, 256],
+            kernel_size: [3, 3, 3, 1],
+            stride: [1, 1, 1, 1],
+            padding: [1, 2, 4, 0],
+            dilation: [1, 2, 4, 1],
+            activations: ["relu", "relu", "relu", "linear"],
+        },
+        upos_tagger: if in_targets("upostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "upostag_labels"
+        },
+        xpos_tagger: if in_targets("xpostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "xpostag_labels"
+        },
+        semantic_relation: if in_targets("semrel") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "semrel_labels"
+        },
+        regularizer: {
+            regexes: [
+                [".*conv1d.*", {type: "l2", alpha: 1e-6}],
+                [".*forward.*", {type: "l2", alpha: 1e-6}],
+                [".*backward.*", {type: "l2", alpha: 1e-6}],
+                [".*char_embed.*", {type: "l2", alpha: 1e-5}],
+            ],
+        },
+    }),
+    trainer: {
+        type: "gradient_descent_validate_n",
+        cuda_device: cuda_device,
+        grad_clipping: 5.0,
+        num_epochs: num_epochs,
+        optimizer: {
+            type: "adam",
+            lr: learning_rate,
+            betas: [0.9, 0.9],
+        },
        patience: 1, # it will be overwritten by the callback
+        epoch_callbacks: [
+            { type: "transfer_patience" },
+        ],
+        learning_rate_scheduler: {
+            type: "combo_scheduler",
+        },
+        tensorboard_writer: {
+            serialization_dir: metrics_dir,
+            should_log_learning_rate: true,
+            summary_interval: 2,
+        },
+        validation_metric: "+EM",
+    },
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..68c7453553d2fd81d2a044b229a2ece7f0f2a45b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+pylint==2.4.4
+pylint-quotes==0.2.1
+pytest==5.3.5
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..b7e478982ccf9ab1963c74e1084dfccb6e42c583
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[aliases]
+test=pytest
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7caeb3b98e89a09d0737a809e2c8af865f65a94
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,25 @@
+"""Setup."""
+from setuptools import find_packages, setup
+
+REQUIREMENTS = [
+    'absl-py==0.9.0',
+    'allennlp==1.0.0rc4',
+    'conllu==2.3.2',
+    'joblib==0.14.1',
+    'jsonnet==0.15.0',
+    'overrides==2.8.0',
+    'tensorboard==2.1.0',
+    'torch==1.5.0',
+    'torchvision==0.6.0',
+    'transformers==2.9.1',
+]
+
+setup(
+    name='COMBO',
+    version='0.0.1',
+    install_requires=REQUIREMENTS,
+    packages=find_packages(exclude=['tests']),
+    setup_requires=['pytest-runner', 'pytest-pylint'],
+    tests_require=['pytest', 'pylint'],
+    entry_points={'console_scripts': ['combo = combo.main:main']},
+)
diff --git a/tests/data/fields/test_samplers.py b/tests/data/fields/test_samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..81689fc6ac573687708e3667d4ec734cb5f9b971
--- /dev/null
+++ b/tests/data/fields/test_samplers.py
@@ -0,0 +1,32 @@
+"""Sampler tests."""
+import unittest
+
+from allennlp import data
+from allennlp.data import fields
+
+from combo.data import TokenCountBatchSampler
+
+
class TokenCountBatchSamplerTest(unittest.TestCase):
    """Checks batching and ordering behaviour of TokenCountBatchSampler."""

    def setUp(self) -> None:
        self.sentences = ['First sentence makes full batch.', 'Short', 'This ends first batch']
        self.dataset = []
        for text in self.sentences:
            token_list = [data.Token(word) for word in text.split()]
            instance = data.Instance({'sentence': fields.TextField(token_list, {})})
            self.dataset.append(instance)

    def test_batches(self):
        # given
        sampler = TokenCountBatchSampler(self.dataset, word_batch_size=2, shuffle_dataset=False)

        # when
        length = len(sampler)
        values = list(sampler)

        # then
        self.assertEqual(2, length)
        # sort by lengths + word_batch_size makes 1, 2 first batch
        self.assertListEqual([[1, 2], [0]], values)
diff --git a/tests/data/fields/test_sequence_multilabel_field.py b/tests/data/fields/test_sequence_multilabel_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..85d28baf44bf5d05b9d209031280da10cade834f
--- /dev/null
+++ b/tests/data/fields/test_sequence_multilabel_field.py
@@ -0,0 +1,96 @@
+"""Sequence multilabel field tests."""
+import unittest
+from typing import List
+
+import torch
+from allennlp import data as allen_data
+from allennlp.data import fields as allen_fields
+
+from combo.data import fields
+
+
class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
    """Exercises indexing, tensorization and sequence protocol of
    SequenceMultiLabelField."""

    def setUp(self) -> None:
        self.namespace = 'test_labels'
        self.vocab = allen_data.Vocabulary()
        self.vocab.add_tokens_to_namespace(
            tokens=['t' + str(idx) for idx in range(3)],
            namespace=self.namespace
        )

        def _indexer(vocab: allen_data.Vocabulary):
            vocab_size = vocab.get_vocab_size(self.namespace)

            def _mapper(multi_label: List[str]) -> List[int]:
                # Encode the label set as a one-hot/multi-hot vector.
                encoded = [0] * vocab_size
                for label in multi_label:
                    encoded[vocab.get_token_index(label, self.namespace)] = 1
                return encoded

            return _mapper

        self.indexer = _indexer
        self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace))

    def _make_field(self) -> fields.SequenceMultiLabelField:
        # Shared fixture: all tests operate on the same three-token field.
        return fields.SequenceMultiLabelField(
            multi_labels=[['t1', 't2'], [], ['t0']],
            multi_label_indexer=self.indexer,
            sequence_field=self.sequence_field,
            label_namespace=self.namespace
        )

    def test_indexing(self):
        # given
        field = self._make_field()
        expected = [[0, 1, 1], [0, 0, 0], [1, 0, 0]]

        # when
        field.index(self.vocab)

        # then
        self.assertEqual(expected, field._indexed_multi_labels)

    def test_mapping_to_tensor(self):
        # given
        field = self._make_field()
        field.index(self.vocab)
        expected = torch.tensor([[0, 1, 1], [0, 0, 0], [1, 0, 0]])

        # when
        actual = field.as_tensor(field.get_padding_lengths())

        # then
        self.assertTrue(torch.all(expected.eq(actual)))

    def test_sequence_method(self):
        # given
        field = self._make_field()

        # when
        length = len(field)
        iter_length = len(list(iter(field)))
        middle_value = field[1]

        # then
        self.assertEqual(3, length)
        self.assertEqual(3, iter_length)
        self.assertEqual([], middle_value)
+
+
class _SequenceFieldTestWrapper(allen_fields.SequenceField):
    """Minimal SequenceField stub that only reports a fixed sequence length."""

    def __init__(self, length: int):
        self.length = length

    def sequence_length(self) -> int:
        return self.length
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
new file mode 100644
index 0000000000000000000000000000000000000000..0cf30f04351db8904513b7363a6e8414b27cd347
--- /dev/null
+++ b/tests/fixtures/example.conllu
@@ -0,0 +1,5 @@
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep	1	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	_	_
diff --git a/tests/fixtures/example.vec b/tests/fixtures/example.vec
new file mode 100644
index 0000000000000000000000000000000000000000..99d498691f6e086f954417062d7cff8d21b14b5f
--- /dev/null
+++ b/tests/fixtures/example.vec
@@ -0,0 +1,4 @@
+2 300
+Sentence 0.1271 0.0356 -0.0134 -0.0100 0.1353 0.0263 -0.0597 0.0231 0.0074 -0.0126 0.0209 -0.0309 0.2831 -0.0016 -0.0338 0.1002 0.0494 -0.0874 -0.0929 -0.0134 -0.1597 -0.0016 0.0273 0.0035 -0.1760 -0.0756 -0.0853 0.0095 0.0008 -0.0660 0.1774 0.0660 -0.0071 -0.0210 0.0229 -0.0173 -0.0379 -0.0682 0.0198 0.0254 -0.0108 -0.0778 -0.0094 0.0625 -0.0221 0.0631 0.0354 -0.0296 0.0617 0.0314 0.1582 0.1129 0.0281 0.0021 0.4368 -0.1094 0.0555 -0.0240 -0.1006 -0.0608 -0.0936 0.0184 -0.0265 0.0042 -0.1166 -0.0007 -0.0332 -0.0832 0.0454 -0.0446 -0.0499 -0.0852 -0.0684 -0.0598 -0.0065 -0.0321 0.0578 -0.0149 -0.0353 0.0319 0.0691 0.0843 0.8444 0.0066 -0.0465 -0.0164 0.0589 0.0603 -0.0589 -0.0497 -0.0184 0.0095 -0.1072 0.0254 -0.0094 -0.0271 0.0378 0.0048 -0.0015 -0.1009 -0.0182 -0.0513 0.0573 -0.0002 0.0518 -0.0759 0.0564 -0.0137 -0.0261 0.0656 0.0061 -0.0295 -0.0349 -0.1159 -0.0101 -0.0236 -0.0483 0.0343 0.0057 -0.1205 -0.0833 0.0393 0.0236 0.0129 -0.0899 -0.0360 -0.0380 -0.0386 0.0778 -0.0329 -0.0757 0.0701 -0.0246 -0.3195 -0.1065 -0.0229 0.2152 0.0918 -0.0843 -0.0342 -0.0281 0.0780 -0.0004 0.1064 0.1484 0.0727 -0.0202 -0.0514 -0.0393 0.0165 -0.0911 -0.0117 0.0395 -0.1155 -0.2148 -0.0061 0.0328 -0.0558 -0.1027 0.0284 -0.0193 -0.0006 -0.0815 0.0062 0.0224 0.0079 -0.1508 -0.2233 0.0603 -0.0200 0.0484 -0.0298 0.2781 0.0728 -0.0166 -0.0532 -0.1248 0.1290 -0.0349 0.0578 -0.0680 -0.0327 0.1507 -0.0220 -0.0769 -0.1918 0.0003 0.0308 -0.0105 0.0436 0.0709 -0.0622 0.0654 0.1635 -0.0525 0.0198 0.0719 0.0586 0.0509 -0.0101 -0.2630 0.0578 0.0433 -0.0168 0.1403 -0.0320 0.0609 0.0251 -0.0038 -0.0138 -0.0424 0.0864 -0.1150 -0.0703 -0.1763 0.0424 0.0629 -0.0049 -0.3170 0.1916 -0.0464 0.0751 0.0237 0.0864 -0.1301 0.0279 -0.2121 -0.0120 0.1025 0.0207 -0.0091 -0.0104 0.0493 -0.0151 -0.0463 -0.0199 0.0247 0.0271 0.0203 0.0210 0.0967 0.0239 0.0313 0.0170 -0.1281 0.0023 -0.0771 -0.0273 -0.0128 -0.0623 0.0169 0.0569 0.0365 0.0747 -0.0060 0.0538 0.0408 -0.0028 -0.0357 0.0390 0.0170 -0.1010 -0.0654 
-0.0552 0.0706 -0.0406 0.0005 -0.0368 -0.0214 -0.0939 -0.0396 0.0535 0.0406 -0.0293 0.0240 0.0699 -0.0381 -0.0011 -0.0323 0.0977 -0.0687 0.1055 0.0692 -0.0002 -0.0174 -0.1106 0.0914 -0.0234 0.2364 0.0245 -0.0698 0.1212 0.0032 -0.0884 -0.1106 -0.0209 0.0748 -0.0175 -0.0111 0.0092
+. -0.1028 0.0208 -0.0235 0.0794 0.0284 0.0097 0.0328 -0.0054 -0.0338 -0.0541 -0.0293 -0.0616 0.3285 -0.1766 0.0214 0.3172 0.1139 0.0546 0.0106 -0.0492 0.1317 0.0441 -0.0442 -0.0822 -0.2058 -0.3436 0.0357 0.0860 0.0631 -0.0236 0.0253 0.1204 -0.0011 0.0833 0.0457 0.0463 -0.0049 -0.0002 0.0356 -0.0221 -0.0241 0.1429 -0.0498 0.1081 -0.0569 0.1883 -0.0259 0.0110 0.0795 0.0906 0.0159 -0.1774 0.0466 0.1067 0.4010 0.0221 0.0473 0.0630 -0.0191 0.0152 -0.0745 0.0615 -0.0565 0.0168 0.0042 -0.0890 -0.0333 0.0493 -0.0694 0.0076 -0.1073 0.0070 0.0190 -0.0456 -0.1193 0.0096 -0.0013 -0.0110 -0.0602 0.0485 -0.0192 0.1404 1.0454 -0.0198 0.1328 0.0495 -0.0621 -0.0135 -0.0507 -0.0276 0.0010 0.0499 0.1893 -0.0062 0.0283 -0.0601 0.0179 0.0620 0.0073 -0.0410 -0.0922 -0.0439 -0.0612 0.0197 0.0297 0.0109 0.0535 0.0724 -0.0540 0.0939 0.0062 -0.0886 -0.0044 -0.0572 -0.0058 -0.0100 -0.0138 0.0796 -0.0190 0.0539 0.1392 0.2092 -0.0483 0.0948 0.0293 -0.1015 -0.0377 -0.0052 -0.1322 0.0353 -0.0000 0.0968 -0.0553 -0.1574 -0.2598 -0.0186 0.2303 -0.0370 0.0001 -0.0023 0.0240 0.0100 -0.0249 0.2826 -0.1605 -0.0762 -0.0917 -0.0243 -0.0795 -0.0109 -0.0247 -0.1053 0.1111 -0.0671 -0.2930 0.0796 0.0465 -0.0236 0.0477 0.0900 -0.0841 -0.0175 -0.0492 -0.0029 0.0029 -0.0543 0.3960 -0.0695 0.0380 0.0090 0.0191 -0.0685 0.3998 -0.2886 -0.0274 -0.0080 -0.0081 0.6646 -0.0150 -0.0635 -0.0385 -0.0017 0.0057 -0.0346 0.0091 -0.5501 0.0295 -0.0726 0.0424 -0.0347 0.0247 -0.0223 -0.0820 0.5079 0.0582 0.0771 0.0283 0.0328 0.0002 -0.0144 0.2214 0.0470 0.0793 0.0079 -0.0763 0.0095 -0.0511 -0.0454 -0.0095 0.0097 -0.0196 0.1779 0.0348 0.0653 0.0193 0.0236 0.0269 -0.0042 -0.0604 -0.1879 0.0400 0.0415 0.0368 -0.0349 0.1267 0.0685 0.2566 -0.3385 0.0788 -0.0458 0.0736 0.3825 0.0496 0.0342 0.0906 -0.0084 -0.0540 0.0018 -0.0042 0.0416 0.0251 -0.0075 0.0347 -0.0899 -0.0071 0.0186 0.0034 -0.0192 -0.0017 -0.1191 0.0037 -0.0363 0.0364 0.0664 -0.2065 -0.1264 0.0571 -0.0173 -0.1454 -0.0173 0.0830 -0.0438 0.0029 0.1918 0.2079 0.0393 0.0081 
0.0228 -0.0086 -0.1322 -0.0006 0.1892 0.0536 0.0155 -0.0371 -0.0056 0.0188 0.0864 -0.1048 0.2564 -0.1761 0.0433 0.0597 -0.0005 0.0409 -0.0803 0.0188 0.0211 -0.0255 -0.0026 -0.0438 -0.2800 -0.0682 -0.0544 -0.0386 -0.0479 0.1007 0.0435 0.0229 0.0632
+
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7628b1258a4eff70c58616e310736795e4ac3e2
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,52 @@
+import logging
+import os
+import pathlib
+import shutil
+import tempfile
+import unittest
+from allennlp.commands import train
+from allennlp.common import Params, util
+
+
class TrainingEndToEndTest(unittest.TestCase):
    """End-to-end check that a one-epoch training run produces a model."""

    PROJECT_ROOT = (pathlib.Path(__file__).parent / '..').resolve()
    MODULE_ROOT = PROJECT_ROOT / 'combo'
    TESTS_ROOT = PROJECT_ROOT / 'tests'
    FIXTURES_ROOT = TESTS_ROOT / 'fixtures'

    def setUp(self) -> None:
        # A fresh scratch directory per test.  The previous class-level
        # ``tempfile.mkdtemp`` ran at import time (a side effect even when no
        # test executes) and was deleted by the first ``tearDown``, which
        # would break any later test method reusing the shared path.
        self.TEST_DIR = pathlib.Path(tempfile.mkdtemp(prefix='allennlp_tests'))
        # Silence noisy allennlp loggers during the training run.
        for logger_name in ('allennlp.common.util',
                            'allennlp.training.tensorboard_writer',
                            'allennlp.common.params',
                            'allennlp.nn.initializers'):
            logging.getLogger(logger_name).disabled = True

    def test_training_produces_model(self):
        # given
        util.import_module_and_submodules('combo.models')
        util.import_module_and_submodules('combo.training')
        ext_vars = {
            'training_data_path': os.path.join(self.FIXTURES_ROOT, 'example.conllu'),
            'validation_data_path': os.path.join(self.FIXTURES_ROOT, 'example.conllu'),
            'features': 'token char',
            'targets': 'deprel head lemma feats upostag xpostag',
            'type': 'default',
            'pretrained_tokens': os.path.join(self.FIXTURES_ROOT, 'example.vec'),
            'pretrained_transformer_name': '',
            'embedding_dim': '300',
        }
        params = Params.from_file(os.path.join(self.PROJECT_ROOT, 'config.template.jsonnet'),
                                  ext_vars=ext_vars)
        params['trainer']['tensorboard_writer']['serialization_dir'] = os.path.join(self.TEST_DIR, 'metrics')
        # Keep the run tiny: single epoch, single-word batches.
        params['trainer']['num_epochs'] = 1
        params['data_loader']['batch_sampler']['word_batch_size'] = 1

        # when
        model = train.train_model(params, serialization_dir=str(self.TEST_DIR))

        # then
        self.assertIsNotNone(model)

    def tearDown(self) -> None:
        # ignore_errors: the run may have already cleaned up (or failed early).
        shutil.rmtree(self.TEST_DIR, ignore_errors=True)
diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py
new file mode 100644
index 0000000000000000000000000000000000000000..df9e624246d7a0ae243cae65ea5756f2933b5829
--- /dev/null
+++ b/tests/utils/test_checks.py
@@ -0,0 +1,38 @@
+"""Checks tests."""
+import unittest
+
+import torch
+from allennlp.common import checks as allen_checks
+
+from combo.utils import checks
+
+
class SizeCheckTest(unittest.TestCase):
    """Tests for ``checks.check_size_match``."""

    def test_equal_sizes(self):
        # given
        shape = (10, 2)
        first = torch.rand(shape)
        second = torch.rand(shape)

        # when -- matching sizes must pass silently
        checks.check_size_match(size_1=first.size(),
                                size_2=second.size(),
                                tensor_1_name='', tensor_2_name='')

        # then
        # nothing happens
        self.assertTrue(True)

    def test_different_sizes(self):
        # given
        first = torch.rand((10, 2))
        second = torch.rand((20, 1))

        # when/then
        with self.assertRaises(allen_checks.ConfigurationError):
            checks.check_size_match(size_1=first.size(),
                                    size_2=second.size(),
                                    tensor_1_name='', tensor_2_name='')
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..305939c032d344a2e7d9a1d0e772a0bc8acf19f0
--- /dev/null
+++ b/tests/utils/test_metrics.py
@@ -0,0 +1,152 @@
+"""Metrics tests."""
+import unittest
+
+import torch
+
+from combo.utils import metrics
+
+
class SemanticMetricsTest(unittest.TestCase):
    """Tests for the aggregated exact-match (``em_score``) of SemanticMetrics."""

    def setUp(self) -> None:
        self.mask: torch.BoolTensor = torch.tensor([
            [True, True, True, True],
            [True, True, True, False],
            [True, True, True, False],
        ])
        pred = torch.tensor([
            [0, 1, 2, 3],
            [0, 1, 2, 3],
            [0, 1, 2, 3],
        ])
        pred_seq = pred.reshape(3, 4, 1)
        gold = pred.clone()
        gold_seq = pred_seq.clone()
        # NOTE: all flat targets alias the same `pred`/`gold` tensors and both
        # sequence targets alias `pred_seq`/`gold_seq`, so an in-place edit to
        # one prediction tensor is visible through every flat target.
        self.upostag, self.upostag_l = ('upostag', pred), ('upostag', gold)
        self.xpostag, self.xpostag_l = ('xpostag', pred), ('xpostag', gold)
        self.semrel, self.semrel_l = ('semrel', pred), ('semrel', gold)
        self.head, self.head_l = ('head', pred), ('head', gold)
        self.deprel, self.deprel_l = ('deprel', pred), ('deprel', gold)
        self.feats, self.feats_l = ('feats', pred_seq), ('feats', gold_seq)
        self.lemma, self.lemma_l = ('lemma', pred_seq), ('lemma', gold_seq)
        self.predictions = dict([self.upostag, self.xpostag, self.semrel,
                                 self.feats, self.lemma, self.head, self.deprel])
        self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l,
                                 self.feats_l, self.lemma_l, self.head_l,
                                 self.deprel_l])
        self.eps = 1e-6

    def test_every_prediction_correct(self):
        # given
        semantic_metric = metrics.SemanticMetrics()

        # when
        semantic_metric(self.predictions, self.gold_labels, self.mask)

        # then
        self.assertEqual(1.0, semantic_metric.em_score)

    def test_missing_predictions_for_one_target(self):
        # given
        semantic_metric = metrics.SemanticMetrics()
        self.predictions['upostag'] = None
        self.gold_labels['upostag'] = None

        # when
        semantic_metric(self.predictions, self.gold_labels, self.mask)

        # then
        self.assertEqual(1.0, semantic_metric.em_score)

    def test_missing_predictions_for_two_targets(self):
        # given
        semantic_metric = metrics.SemanticMetrics()
        self.predictions['upostag'] = None
        self.gold_labels['upostag'] = None
        self.predictions['lemma'] = None
        self.gold_labels['lemma'] = None

        # when
        semantic_metric(self.predictions, self.gold_labels, self.mask)

        # then
        self.assertEqual(1.0, semantic_metric.em_score)

    def test_one_classification_in_one_target_is_wrong(self):
        # given
        semantic_metric = metrics.SemanticMetrics()
        self.predictions['upostag'][0][0] = 100

        # when
        semantic_metric(self.predictions, self.gold_labels, self.mask)

        # then
        self.assertAlmostEqual(0.9, semantic_metric.em_score, delta=self.eps)

    def test_classification_errors_and_target_without_predictions(self):
        # given
        semantic_metric = metrics.SemanticMetrics()
        self.predictions['feats'] = None
        self.gold_labels['feats'] = None
        self.predictions['upostag'][0][0] = 100
        self.predictions['upostag'][2][0] = 100
        # should be ignored due to masking
        self.predictions['upostag'][1][3] = 100

        # when
        semantic_metric(self.predictions, self.gold_labels, self.mask)

        # then
        self.assertAlmostEqual(0.8, semantic_metric.em_score, delta=self.eps)
+
+
class SequenceBoolAccuracyTest(unittest.TestCase):
    """Tests for the correct/total counters of ``metrics.SequenceBoolAccuracy``."""

    def setUp(self) -> None:
        self.mask: torch.BoolTensor = torch.tensor([
            [True, True, True, True],
            [True, True, True, False],
            [True, True, True, False],
        ])

    def test_regular_classification_accuracy(self):
        # given
        accuracy = metrics.SequenceBoolAccuracy()
        predictions = torch.tensor([
            [1, 1, 0, 8],
            [1, 2, 3, 4],
            [9, 4, 3, 9],
        ])
        gold_labels = torch.tensor([
            [11, 1, 0, 8],
            [14, 2, 3, 14],
            [9, 4, 13, 9],
        ])

        # when
        accuracy(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(7, accuracy._correct_count.item())
        self.assertEqual(10, accuracy._total_count.item())

    def test_feats_classification_accuracy(self):
        # given
        accuracy = metrics.SequenceBoolAccuracy(prod_last_dim=True)
        # shape: batch_size x sequence_length x classes
        predictions = torch.tensor([
            [[1, 4], [0, 2], [0, 2], [0, 3]],
            [[1, 4], [0, 2], [0, 2], [0, 3]],
            [[1, 4], [0, 2], [0, 2], [0, 3]],
        ])
        gold_labels = torch.tensor([
            [[1, 14], [0, 2], [0, 2], [0, 3]],
            [[11, 4], [0, 2], [0, 2], [10, 3]],
            [[1, 4], [0, 2], [10, 12], [0, 3]],
        ])

        # when
        accuracy(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(7, accuracy._correct_count.item())
        self.assertEqual(10, accuracy._total_count.item())