diff --git a/README.md b/README.md
index 19847b8e5df935feb36bb9e5ec1e89cc3f1d35ed..a9c21135005d9abf50f2234bedca817b7d180327 100644
--- a/README.md
+++ b/README.md
@@ -10,19 +10,24 @@
 </p>
 
 ## Quick start
-Clone this repository and install COMBO (we suggest using virtualenv/conda with Python 3.6+):
+Clone this repository and install COMBO (we suggest creating a virtualenv/conda environment with Python 3.6+, as a number of required packages will be installed):
 ```bash
 git clone https://gitlab.clarin-pl.eu/syntactic-tools/clarinbiz/combo.git
 cd combo
 python setup.py develop
 ```
-Run the following lines in your Python console to make predictions with a pre-trained model:
+Run the following commands in your Python console to make predictions with a pre-trained model:
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
-sentence = nlp("Moje zdanie.")
-print(sentence.tokens)
+nlp = COMBO.from_pretrained("polish-herbert-base")
+sentence = nlp("COVID-19 to ostra choroba zakaźna układu oddechowego wywołana zakażeniem wirusem SARS-CoV-2.")
+```
+Predictions are accessible as a list of tokens, each carrying its predicted attributes:
+```python
+print("{:5} {:15} {:15} {:10} {:10} {:10}".format('ID', 'TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))
+for token in sentence.tokens:
+    print("{:5} {:15} {:15} {:10} {:10} {:10}".format(str(token.id), token.token, token.lemma, token.upostag, str(token.head), token.deprel))
 ```
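+
+Predictions can also be exported in the CoNLL-U format. A minimal sketch using the `sentence2conllu` helper from `combo.data` (defined in `combo/data/api.py`):
+```python
+from combo.data import sentence2conllu
+
+print(sentence2conllu(sentence).serialize())
+```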
 
 ## Details
@@ -31,4 +36,3 @@ print(sentence.tokens)
 - [**Pre-trained models**](docs/models.md)
 - [**Training**](docs/training.md)
 - [**Prediction**](docs/prediction.md)
-
diff --git a/combo/data/api.py b/combo/data/api.py
index 10a3a727c9220601ebf243752a7e605e127a1774..7d44917ecc42a555be3c20e8500f595b8ee1edf1 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -1,13 +1,13 @@
 import collections
+import dataclasses
+import json
 from dataclasses import dataclass, field
 from typing import Optional, List, Dict, Any, Union, Tuple
 
 import conllu
-from dataclasses_json import dataclass_json
 from overrides import overrides
 
 
-@dataclass_json
 @dataclass
 class Token:
     id: Optional[Union[int, Tuple]] = None
@@ -23,13 +23,19 @@ class Token:
     semrel: Optional[str] = None
 
 
-@dataclass_json
 @dataclass
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
+    def to_json(self):
+        return json.dumps({
+            "tokens": [dataclasses.asdict(t) for t in self.tokens],
+            "sentence_embedding": self.sentence_embedding,
+            "metadata": self.metadata,
+        })
+
 
 class _TokenList(conllu.TokenList):
 
@@ -41,7 +47,7 @@ class _TokenList(conllu.TokenList):
 def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(token.to_dict())
+        token_dict = collections.OrderedDict(dataclasses.asdict(token))
         # Remove semrel to have default conllu format.
         if not keep_semrel:
             del token_dict["semrel"]
@@ -50,6 +56,10 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
     for t in tokens:
         if type(t["id"]) == list:
             t["id"] = tuple(t["id"])
+        if t["deps"]:
+            for dep in t["deps"]:
+                if len(dep) > 1 and type(dep[1]) == list:
+                    dep[1] = tuple(dep[1])
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata)
 
@@ -64,9 +74,18 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
 
 
 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding: List[float]) -> Sentence:
+                    sentence_embedding=None) -> Sentence:
+    if sentence_embedding is None:
+        sentence_embedding = []
+    tokens = []
+    for token in conllu_sentence.tokens:
+        tokens.append(
+            Token(
+                **token
+            )
+        )
     return Sentence(
-        tokens=[Token.from_dict(t) for t in conllu_sentence.tokens],
+        tokens=tokens,
         sentence_embedding=sentence_embedding,
         metadata=conllu_sentence.metadata
     )
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 459a755c7f71c40e449d0542bf9af21d05e1f2c9..48b68b14e592dc6f98e3f62e8b5c3bd23899cb4c 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,9 +1,11 @@
+import copy
 import logging
-from typing import Union, List, Dict, Iterable, Optional, Any
+from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 
 import conllu
+import torch
 from allennlp import data as allen_data
-from allennlp.common import checks
+from allennlp.common import checks, util
 from allennlp.data import fields as allen_fields, vocabulary
 from conllu import parser
 from dataclasses import dataclass
@@ -35,6 +37,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         if "token" not in features and "char" not in features:
             raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
 
+        if "deps" in targets and not ("head" in targets and "deprel" in targets):
+            raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
+
         intersection = set(features).intersection(set(targets))
         if len(intersection) != 0:
             raise checks.ConfigurationError(
@@ -49,6 +54,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         field_parsers = parser.DEFAULT_FIELD_PARSERS
         # Do not make it nullable
         field_parsers.pop("xpostag", None)
+        # Ignore parsing misc
+        field_parsers.pop("misc", None)
         if self.use_sem:
             fields = list(fields)
             fields.append("semrel")
@@ -102,13 +109,46 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                     elif target_name == "feats":
                         target_values = self._feat_values(tree_tokens)
                         fields_[target_name] = fields.SequenceMultiLabelField(target_values,
-                                                                              self._feats_to_index_multi_label,
+                                                                              self._feats_indexer,
+                                                                              self._feats_as_tensor_wrapper,
                                                                               text_field,
                                                                               label_namespace="feats_labels")
                     elif target_name == "head":
                         target_values = [0 if v == "_" else int(v) for v in target_values]
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
+                    elif target_name == "deps":
+                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
+                        text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens),
+                                                                 self._token_indexers)
+                        enhanced_heads: List[Tuple[int, int]] = []
+                        enhanced_deprels: List[str] = []
+                        for idx, t in enumerate(tree_tokens):
+                            t_deps = t["deps"]
+                            if t_deps and t_deps != "_":
+                                for rel, head in t_deps:
+                                    # EmoryNLP skips the first edge if there are two edges between the same
+                                    # nodes, so that one ends up in the tree and the other in the graph.
+                                    # This snippet follows that approach.
+                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
+                                        enhanced_heads.pop()
+                                        enhanced_deprels.pop()
+                                    enhanced_heads.append((idx, head))
+                                    enhanced_deprels.append(rel)
+                        fields_["enhanced_heads"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            label_namespace="enhanced_heads_labels",
+                            padding_value=0,
+                        )
+                        fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            labels=enhanced_deprels,
+                            # Label namespace matches regular tree parsing.
+                            label_namespace="deprel_labels",
+                            padding_value=0,
+                        )
                     else:
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
@@ -128,7 +168,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 token["feats"] = field
 
         # metadata
-        fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
+        fields_["metadata"] = allen_fields.MetadataField({"input": tree,
+                                                          "field_names": self.fields,
+                                                          "tokens": tokens})
 
         return allen_data.Instance(fields_)
 
@@ -151,12 +193,26 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
         return features
 
     @staticmethod
-    def _feats_to_index_multi_label(vocab: allen_data.Vocabulary):
+    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
+        def as_tensor(padding_lengths):
+            desired_num_tokens = padding_lengths["num_tokens"]
+            assert len(field._indexed_multi_labels) > 0
+            classes_count = len(field._indexed_multi_labels[0])
+            default_value = [0.0] * classes_count
+            padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                      lambda: default_value)
+            tensor = torch.LongTensor(padded_tags)
+            return tensor
+
+        return as_tensor
+
+    @staticmethod
+    def _feats_indexer(vocab: allen_data.Vocabulary):
         label_namespace = "feats_labels"
         vocab_size = vocab.get_vocab_size(label_namespace)
         slices = get_slices_if_not_provided(vocab)
 
-        def _m_from_n_ones_encoding(multi_label: List[str]) -> List[int]:
+        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
             one_hot_encoding = [0] * vocab_size
             for cat, cat_indices in slices.items():
                 if cat not in ["__PAD__", "_"]:
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
index 4e98a148aee35e42af0b4828a031368fe0eafc12..b200580cf1edfba1e710b42bc19c4c4efdb0db4f 100644
--- a/combo/data/fields/sequence_multilabel_field.py
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -5,7 +5,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict
 
 import torch
 from allennlp import data
-from allennlp.common import checks, util
+from allennlp.common import checks
 from allennlp.data import fields
 from overrides import overrides
 
@@ -17,15 +17,16 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
     A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels
     while keeping sequence dimension.
 
-    This field will get converted into a sequence of vectors of length equal to the vocabulary size with
-    M from N encoding for the labels (all zeros, and ones for the labels).
+    To allow adapting to different circumstances, the class takes a few delegate functions.
 
     # Parameters
 
     multi_labels : `List[List[str]]`
     multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
-        Nested callable which based on vocab creates mapper for multilabel field in the sequence from strings
-        to indexed, int values.
+        Nested callable which, given the vocab and the sequence length, maps the values of the field in the
+        sequence from strings to indexed int values.
+    as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
+        Nested callable which, given the field itself, maps the indexed data to a tensor.
     sequence_field : `SequenceField`
         A field containing the sequence that this `SequenceMultiLabelField` is labeling.  Most often, this is a
         `TextField`, for tagging individual tokens in a sentence.
@@ -43,7 +44,8 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
     def __init__(
             self,
             multi_labels: List[List[str]],
-            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str]], List[int]]],
+            multi_label_indexer: Callable[[data.Vocabulary], Callable[[List[str], int], List[int]]],
+            as_tensor: Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]],
             sequence_field: fields.SequenceField,
             label_namespace: str = "labels",
     ) -> None:
@@ -53,6 +55,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
         self._label_namespace = label_namespace
         self._indexed_multi_labels = None
         self._maybe_warn_for_namespace(label_namespace)
+        self.as_tensor_wrapper = as_tensor(self)
         if len(multi_labels) != sequence_field.sequence_length():
             raise checks.ConfigurationError(
                 "Label length and sequence length "
@@ -101,7 +104,7 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
 
         indexed = []
         for multi_label in self.multi_labels:
-            indexed.append(indexer(multi_label))
+            indexed.append(indexer(multi_label, len(self.multi_labels)))
         self._indexed_multi_labels = indexed
 
     @overrides
@@ -110,19 +113,13 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]):
 
     @overrides
     def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
-        desired_num_tokens = padding_lengths["num_tokens"]
-        assert len(self._indexed_multi_labels) > 0
-        classes_count = len(self._indexed_multi_labels[0])
-        default_value = [0.0] * classes_count
-        padded_tags = util.pad_sequence_to_length(self._indexed_multi_labels, desired_num_tokens, lambda: default_value)
-        tensor = torch.LongTensor(padded_tags)
-        return tensor
+        return self.as_tensor_wrapper(padding_lengths)
 
     @overrides
     def empty_field(self) -> "SequenceMultiLabelField":
-        # The empty_list here is needed for mypy
         empty_list: List[List[str]] = [[]]
         sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y,
+                                                       lambda x: lambda y: y,
                                                        self.sequence_field.empty_field())
         sequence_label_field._indexed_labels = empty_list
         return sequence_label_field
diff --git a/combo/main.py b/combo/main.py
index 44ad091f8c5004bf9f4e70323ec828bf44288447..17c960ac7caa84513692841abb989955c7925721 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -18,7 +18,7 @@ from combo.utils import checks
 
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
-_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent"]
+_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
 
 FLAGS = flags.FLAGS
 flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
@@ -33,8 +33,10 @@ flags.DEFINE_string(name="output_file", default="output.log",
 # Training flags
 flags.DEFINE_list(name="training_data_path", default="./tests/fixtures/example.conllu",
                   help="Training data path(s)")
+flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_list(name="validation_data_path", default="",
                   help="Validation data path(s)")
+flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
 flags.DEFINE_string(name="pretrained_tokens", default="",
                     help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="embedding_dim", default=300,
@@ -134,7 +136,7 @@ def run(_):
             params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())["dataset_reader"]
             params.pop("type")
             dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
-            predictor = predict.SemanticMultitaskPredictor(
+            predictor = predict.COMBO(
                 model=model,
                 dataset_reader=dataset_reader
             )
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 5aa7b283c70af76a27d97f63783368bdbe4ffa3f..ec7a1380e1cfc80b0302806e46cca4e5fc2d3568 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1,8 +1,9 @@
 """Models module."""
 from .base import FeedForwardPredictor
+from .graph_parser import GraphDependencyRelationModel
 from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
-from .model import SemanticMultitaskModel
+from .model import ComboModel
 from .morpho import MorphologicalFeatures
diff --git a/combo/models/base.py b/combo/models/base.py
index 10e9d371a1cb1665358819898817e4b454b9244c..a5cb5fe61f85a98f78d143a54695d01948aa8dda 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -27,11 +27,11 @@ class Linear(nn.Linear, common.FromParams):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Optional[allen_nn.Activation] = lambda x: x,
+                 activation: Optional[allen_nn.Activation] = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
-        self.activation = activation
-        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x
+        self.activation = activation if activation else self.identity
+        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity
 
     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         x = super().forward(x)
@@ -41,6 +41,10 @@ class Linear(nn.Linear, common.FromParams):
     def get_output_dim(self) -> int:
         return self.out_features
 
+    @staticmethod
+    def identity(x):
+        return x
+
 
 @Predictor.register("feedforward_predictor")
 @Predictor.register("feedforward_predictor_from_vocab", constructor="from_vocab")
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 5cad95928dab03d8e5046eb1a281c07e9ffe33ff..6ad25590e3f29bcde42266b8ee9cc720787b4388 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -196,10 +196,10 @@ class FeatsTokenEmbedder(token_embedders.Embedding):
 
     def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # (batch_size, sentence_length, features_vocab_length)
-        mask = (tokens > 0).float()
+        mask = tokens.gt(0)
         # (batch_size, sentence_length, features_vocab_length, embedding_dim)
         x = super().forward(tokens)
         # (batch_size, sentence_length, embedding_dim)
         return x.sum(dim=-2) / (
-            (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
+            (mask.sum(dim=-1) + util.tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1)
         )
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..edcdc2d0785dd73d91b3e79249d196ba55ec148c
--- /dev/null
+++ b/combo/models/graph_parser.py
@@ -0,0 +1,188 @@
+"""Enhanced dependency parsing models."""
+from typing import Tuple, Dict, Optional, Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from allennlp import data
+from allennlp.nn import chu_liu_edmonds
+
+from combo.models import base, utils
+
+
+class GraphHeadPredictionModel(base.Predictor):
+    """Head prediction model."""
+
+    def __init__(self,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 cycle_loss_n: int = 0,
+                 graph_weighting: float = 0.2):
+        super().__init__()
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.cycle_loss_n = cycle_loss_n
+        self.graph_weighting = graph_weighting
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                mask: Optional[torch.BoolTensor] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        if mask is None:
+            mask = x.new_ones(x.size()[-1])
+        heads_labels = None
+        if labels is not None and labels[0] is not None:
+            heads_labels = labels
+
+        head_arc_emb = self.head_projection_layer(x)
+        dep_arc_emb = self.dependency_projection_layer(x)
+        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
+        pred = x.sigmoid() > 0.5
+
+        output = {
+            "prediction": pred,
+            "probability": x
+        }
+
+        if heads_labels is not None:
+            if sample_weights is None:
+                sample_weights = heads_labels.new_ones([mask.size(0)])
+            output["loss"], output["cycle_loss"] = self._loss(x, heads_labels, mask, sample_weights)
+
+        return output
+
+    def _cycle_loss(self, pred: torch.Tensor):
+        BATCH_SIZE, _, _ = pred.size()
+        loss = pred.new_zeros(BATCH_SIZE)
+        # Index from 1 to skip the __ROOT__ token
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
+
+        return loss
+
+    @staticmethod
+    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
+        assert len(x.size()) == 3
+        BATCH_SIZE, N, M = x.size()
+        assert N == M
+        identity = x.new_tensor(torch.eye(N))
+        identity = identity.reshape((1, N, N))
+        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
+        return (x * batch_identity).sum((-1, -2))
+
+    def _loss(self, pred: torch.Tensor, labels: torch.Tensor,  mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        BATCH_SIZE, N, M = pred.size()
+        assert N == M
+        SENTENCE_LENGTH = N
+
+        valid_positions = mask.sum()
+
+        result = []
+        true = labels
+        # Ignore first pred dimension as it is ROOT token prediction
+        for i in range(SENTENCE_LENGTH - 1):
+            pred_i = pred[:, i + 1, 1:].reshape(-1)
+            true_i = true[:, i + 1, 1:].reshape(-1)
+            mask_i = mask[:, i]
+            bce_loss = F.binary_cross_entropy_with_logits(pred_i, true_i, reduction="none").mean(-1) * mask_i
+            result.append(bce_loss)
+        cycle_loss = self._cycle_loss(pred)
+        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
+
+
+@base.Predictor.register("combo_graph_dependency_parsing_from_vocab", constructor="from_vocab")
+class GraphDependencyRelationModel(base.Predictor):
+    """Dependency relation parsing model."""
+
+    def __init__(self,
+                 head_predictor: GraphHeadPredictionModel,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 relation_prediction_layer: base.Linear):
+        super().__init__()
+        self.head_predictor = head_predictor
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.relation_prediction_layer = relation_prediction_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
+        if labels is not None and labels[0] is not None:
+            relations_labels, head_labels, enhanced_heads_labels = labels
+
+        head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
+        head_pred = head_output["probability"]
+        BATCH_SIZE, LENGTH, _ = head_pred.size()
+
+        head_rel_emb = self.head_projection_layer(x)
+
+        dep_rel_emb = self.dependency_projection_layer(x)
+
+        # All possible edge combinations for each batch
+        # Repeat interleave to have [emb1, emb1 ... (length times) ... emb1, emb2 ... ]
+        head_rel_pred = head_rel_emb.repeat_interleave(LENGTH, -2)
+        # Regular repeat to have all combinations [deprel1, deprel2, ... deprelL, deprel1 ...]
+        dep_rel_pred = dep_rel_emb.repeat(1, LENGTH, 1)
+
+        # Concatenate head and dependent representations for every edge combination
+        dep_rel_pred = torch.cat((head_rel_pred, dep_rel_pred), dim=-1)
+
+        relation_prediction = self.relation_prediction_layer(dep_rel_pred).reshape(BATCH_SIZE, LENGTH, LENGTH, -1)
+        output = head_output
+
+        output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+        output["rel_probability"] = relation_prediction
+
+        if labels is not None and labels[0] is not None:
+            if sample_weights is None:
+                sample_weights = labels.new_ones([mask.size(0)])
+            loss = self._loss(relation_prediction, relations_labels, enhanced_heads_labels, mask, sample_weights)
+            output["loss"] = (loss, head_output["loss"])
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor,
+              true: torch.Tensor,
+              heads_true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        correct_heads_mask = heads_true.long() == 1
+        true = true[correct_heads_mask]
+        pred = pred[correct_heads_mask]
+        loss = F.cross_entropy(pred, true.long())
+        return loss.sum() / pred.size(0)
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   head_predictor: GraphHeadPredictionModel,
+                   head_projection_layer: base.Linear,
+                   dependency_projection_layer: base.Linear
+                   ):
+        """Creates parser combining model configuration and vocabulary data."""
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+        return cls(
+            head_predictor=head_predictor,
+            head_projection_layer=head_projection_layer,
+            dependency_projection_layer=dependency_projection_layer,
+            relation_prediction_layer=relation_prediction_layer
+        )
diff --git a/combo/models/model.py b/combo/models/model.py
index 77b43e3c1a95e09b15c310409af0090f097d47fa..9866bcb4fba41ed2506b2d33290e6cd0fe237d29 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -12,7 +12,7 @@ from combo.utils import metrics
 
 
 @allen_models.Model.register("semantic_multitask")
-class SemanticMultitaskModel(allen_models.Model):
+class ComboModel(allen_models.Model):
     """Main COMBO model."""
 
     def __init__(self,
@@ -27,6 +27,7 @@ class SemanticMultitaskModel(allen_models.Model):
                  semantic_relation: Optional[base.Predictor] = None,
                  morphological_feat: Optional[base.Predictor] = None,
                  dependency_relation: Optional[base.Predictor] = None,
+                 enhanced_dependency_relation: Optional[base.Predictor] = None,
                  regularizer: allen_nn.RegularizerApplicator = None) -> None:
         super().__init__(vocab, regularizer)
         self.text_field_embedder = text_field_embedder
@@ -39,6 +40,7 @@ class SemanticMultitaskModel(allen_models.Model):
         self.semantic_relation = semantic_relation
         self.morphological_feat = morphological_feat
         self.dependency_relation = dependency_relation
+        self.enhanced_dependency_relation = enhanced_dependency_relation
         self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
@@ -53,12 +55,16 @@ class SemanticMultitaskModel(allen_models.Model):
                 feats: torch.Tensor = None,
                 head: torch.Tensor = None,
                 deprel: torch.Tensor = None,
-                semrel: torch.Tensor = None, ) -> Dict[str, torch.Tensor]:
+                semrel: torch.Tensor = None,
+                enhanced_heads: torch.Tensor = None,
+                enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
 
         # Prepare masks
-        char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
+        char_mask = sentence["char"]["token_characters"].gt(0)
         word_mask = util.get_text_field_mask(sentence)
 
+        device = word_mask.device
+
         # If enabled weight samples loss by log(sentence_length)
         sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
 
@@ -69,42 +75,49 @@ class SemanticMultitaskModel(allen_models.Model):
 
         # Concatenate the head sentinel (ROOT) onto the sentence representation.
         head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
-        encoder_emb = torch.cat([head_sentinel, encoder_emb], 1)
-        word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1)
+        encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)
+        word_mask_with_root = torch.cat([torch.ones((batch_size, 1), device=device), word_mask], 1)
 
         upos_output = self._optional(self.upos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=upostag,
                                      sample_weights=sample_weights)
         xpos_output = self._optional(self.xpos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=xpostag,
                                      sample_weights=sample_weights)
         semrel_output = self._optional(self.semantic_relation,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=semrel,
                                        sample_weights=sample_weights)
         morpho_output = self._optional(self.morphological_feat,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=feats,
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
-                                      (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
-                                       if sentence.get("char") else None),
-                                      mask=word_mask[:, 1:],
+                                      (encoder_emb, sentence.get("char").get("token_characters")
+                                      if sentence.get("char") else None),
+                                      mask=word_mask,
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
         parser_output = self._optional(self.dependency_relation,
-                                       encoder_emb,
+                                       encoder_emb_with_root,
                                        returns_tuple=True,
-                                       mask=word_mask,
+                                       mask=word_mask_with_root,
                                        labels=(deprel, head),
                                        sample_weights=sample_weights)
+        enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
+                                                encoder_emb_with_root,
+                                                returns_tuple=True,
+                                                mask=word_mask_with_root,
+                                                labels=(enhanced_deprels, head, enhanced_heads),
+                                                sample_weights=sample_weights)
         relations_pred, head_pred = parser_output["prediction"]
+        enhanced_relations_pred, enhanced_head_pred = enhanced_parser_output["prediction"]
         output = {
             "upostag": upos_output["prediction"],
             "xpostag": xpos_output["prediction"],
@@ -113,9 +126,14 @@ class SemanticMultitaskModel(allen_models.Model):
             "lemma": lemma_output["prediction"],
             "head": head_pred,
             "deprel": relations_pred,
-            "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
+            "enhanced_head": enhanced_head_pred,
+            "enhanced_deprel": enhanced_relations_pred,
+            "sentence_embedding": torch.max(encoder_emb, dim=1)[0],
         }
 
+        if "rel_probability" in enhanced_parser_output:
+            output["enhanced_deprel_prob"] = enhanced_parser_output["rel_probability"]
+
         if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]):
 
             # Feats mapping
@@ -134,9 +152,12 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma": lemma.get("char").get("token_characters") if lemma else None,
                 "head": head,
                 "deprel": deprel,
+                "enhanced_head": enhanced_heads,
+                "enhanced_deprel": enhanced_deprels,
             }
-            self.scores(output, labels, word_mask[:, 1:])
+            self.scores(output, labels, word_mask)
             relations_loss, head_loss = parser_output["loss"]
+            enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
             losses = {
                 "upostag_loss": upos_output["loss"],
                 "xpostag_loss": xpos_output["loss"],
@@ -145,6 +166,8 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma_loss": lemma_output["loss"],
                 "head_loss": head_loss,
                 "deprel_loss": relations_loss,
+                "enhanced_head_loss": enhanced_head_loss,
+                "enhanced_deprel_loss": enhanced_relations_loss,
                 # Cycle loss is only for the metrics purposes.
                 "cycle_loss": parser_output.get("cycle_loss")
             }
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 486b2481b96bf17bb19fd8557916f21dcb6c4584..dfb53ab8ded369b01eae4851dd1d7a9936c05bbe 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -115,11 +115,13 @@ class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 root_idx: int,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
                  dependency_projection_layer: base.Linear,
                  relation_prediction_layer: base.Linear):
         super().__init__()
+        self.root_idx = root_idx
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +132,7 @@ class DependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        device = x.device
         if mask is not None:
             mask = mask[:, 1:]
         relations_labels, head_labels = None, None
@@ -151,7 +154,23 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
 
-        output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        if self.training:
+            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        else:
+            # Force the 'root' label when the predicted head is 0 and mask it out otherwise.
+            relation_prediction_output = relation_prediction[:, 1:]
+            mask = (head_output["prediction"] == 0)
+            vocab_size = relation_prediction_output.size(-1)
+            root_idx = torch.tensor([self.root_idx], device=device)
+            relation_prediction_output[mask] = (relation_prediction_output
+                                                .masked_select(mask.unsqueeze(-1))
+                                                .reshape(-1, vocab_size)
+                                                .index_fill(-1, root_idx, 10e10))
+            relation_prediction_output[~mask] = (relation_prediction_output
+                                                 .masked_select(~(mask.unsqueeze(-1)))
+                                                 .reshape(-1, vocab_size)
+                                                 .index_fill(-1, root_idx, -10e10))
+            output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
@@ -195,5 +214,6 @@ class DependencyRelationModel(base.Predictor):
             head_predictor=head_predictor,
             head_projection_layer=head_projection_layer,
             dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
+            relation_prediction_layer=relation_prediction_layer,
+            root_idx=vocab.get_token_index("root", vocab_namespace)
         )
diff --git a/combo/predict.py b/combo/predict.py
index b6c7172c2477efdea4c5e11e0ea1575450602db0..21941d91d56170e7c552af2a3ac1af229816f76d 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -3,6 +3,7 @@ import os
 from typing import List, Union, Tuple
 
 import conllu
+import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -11,20 +12,20 @@ from overrides import overrides
 
 from combo import data
 from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
-from combo.utils import download
+from combo.utils import download, graph
 
 logger = logging.getLogger(__name__)
 
 
 @predictor.Predictor.register("semantic-multitask-predictor")
 @predictor.Predictor.register("semantic-multitask-predictor-spacy", constructor="with_spacy_tokenizer")
-class SemanticMultitaskPredictor(predictor.Predictor):
+class COMBO(predictor.Predictor):
 
     def __init__(self,
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 500,
+                 batch_size: int = 32,
                  line_to_conllu: bool = False) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -51,7 +52,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-            return data.Sentence.from_dict(self.predict_json({"sentence": sentence}))
+            return self.predict_json({"sentence": sentence})
         elif isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
@@ -154,6 +155,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         token[field_name] = value
                     elif field_name in ["head"]:
                         token[field_name] = int(predictions[field_name][idx])
+                    elif field_name == "deps":
+                        # Handled after every other decoding
+                        continue
+
                     elif field_name in ["feats"]:
                         slices = self._model.morphological_feat.slices
                         features = []
@@ -168,11 +173,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         if len(features) == 0:
                             field_value = "_"
                         else:
-                            field_value = "|".join(sorted(features))
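+                            # Sort feats case-insensitively while keeping the original casing in the output.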
+                            lowercase_features = [f.lower() for f in features]
+                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                            field_value = "|".join(np.array(features)[arg_indices].tolist())
 
                         token[field_name] = field_value
-                    elif field_name == "head":
-                        pass
                     elif field_name == "lemma":
                         prediction = predictions[field_name][idx]
                         word_chars = []
@@ -191,6 +196,20 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                     else:
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
+        if "enhanced_head" in predictions and predictions["enhanced_head"]:
+            # TODO off-by-one hotfix, refactor
+            h = np.array(predictions["enhanced_head"])
+            h = np.concatenate((h[-1:], h[:-1]))
+            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.concatenate((r[-1:], r[:-1]))
+            graph.sdp_to_dag_deps(arc_scores=h,
+                                  rel_scores=r,
+                                  tree_tokens=tree_tokens,
+                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
+                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            empty_tokens = graph.restore_collapse_edges(tree_tokens)
+            tree.tokens.extend(empty_tokens)
+
         return tree, predictions["sentence_embedding"]
 
     @classmethod
@@ -200,7 +219,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 500,
+                        batch_size: int = 32,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..651c14a7d79b7ea3c277b9466f5e050435a7a01b
--- /dev/null
+++ b/combo/utils/graph.py
@@ -0,0 +1,115 @@
+"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+from typing import List
+
+import numpy as np
+
+
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
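+    """Combine the predicted tree with graph arc/rel scores and write the resulting
+    DEPS (and adjusted DEPREL) values into tree_tokens in place."""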
+    # adding ROOT
+    tree_heads = [0] + [t["head"] for t in tree_tokens]
+    graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
+                                                      root_idx)
+    for i, (t, g) in enumerate(zip(tree_heads, graph)):
+        if not i:
+            continue
+        rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g]
+        heads = [x[0] for x in g]
+        head = tree_tokens[i - 1]["head"]
+        index = heads.index(head)
+        deprel = tree_tokens[i - 1]["deprel"]
+        deprel = deprel.split('>')[-1]
+        # TODO - Consider whether the condition below should be applied.
+        # It doesn't seem to make sense, as DEPS should contain DEPREL
+        # (although sometimes with a different/more detailed label).
+        # if len(heads) >= 2:
+        #     heads.pop(index)
+        #     rels.pop(index)
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tree_tokens[i - 1]["deps"] = deps
+        tree_tokens[i - 1]["deprel"] = deprel
+    return
+
+
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
+    if len(arc_scores) != len(tree_heads):
+        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
+        rel_scores = rel_scores[:len(tree_heads), :len(tree_heads)]
+    # Self-loops aren't allowed, mask with 0. This is an in-place operation.
+    np.fill_diagonal(arc_scores, 0)
+    parse_preds = np.array(arc_scores) > 0
+    parse_preds[:, 0] = False  # set heads to False
+    rel_scores[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
+
+
+def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
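+    """Add secondary (graph) arcs on top of the predicted tree.
+
+    Returns an adjacency list: for each token, a list of (head, relation id) pairs.
+    """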
+    if not isinstance(tree_heads, np.ndarray):
+        tree_heads = np.array(tree_heads)
+    dh = np.argwhere(parse_preds)
+    sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
+    graph = [[] for _ in range(len(tree_heads))]
+    rel_pred = np.argmax(rel_scores, axis=-1)
+    for d, h in enumerate(tree_heads):
+        if d:
+            graph[h].append(d)
+    for s, (d, h) in sdh:
+        if not d or not h or d in graph[h]:
+            continue
+        try:
+            path = next(_dfs(graph, d, h))
+        except StopIteration:
+            # no path from d to h
+            graph[h].append(d)
+    parse_graph = [[] for _ in range(len(tree_heads))]
+    num_root = 0
+    for h in range(len(tree_heads)):
+        for d in graph[h]:
+            rel = rel_pred[d][h]
+            if h == 0:
+                rel = root_idx
+                assert num_root == 0
+                num_root += 1
+            parse_graph[d].append((h, rel))
+        parse_graph[d] = sorted(parse_graph[d])
+    return parse_graph
+
+
+def _dfs(graph, start, end):
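+    """Yield paths from `start` to `end`, following the dependent lists in `graph` (DFS)."""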
+    fringe = [(start, [])]
+    while fringe:
+        state, path = fringe.pop()
+        if path and state == end:
+            yield path
+            continue
+        for next_state in graph[state]:
+            if next_state in path:
+                continue
+            fringe.append((next_state, path + [next_state]))
+
+
+def restore_collapse_edges(tree_tokens):
+    empty_tokens = []
+    for token in tree_tokens:
+        deps = token["deps"].split("|")
+        for i, d in enumerate(deps):
+            if ">" in d:
+                # {head}:{empty_node_relation}>{current_node_relation}
+                # should map to
+                # For new, empty node:
+                # {head}:{empty_node_relation}
+                # For current node:
+                # {new_empty_node_id}:{current_node_relation}
+                # TODO consider where to put new_empty_node_id (currently at the end)
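+                # Example (hypothetical values): with 5 tree tokens and a deps entry
+                # "2:obl>nmod", the new empty node gets id "5.1" and deps "2:obl",
+                # while the current token's deps entry becomes "5.1:nmod".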
+                head, relation = d.split(':', 1)
+                ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}"
+                empty_node_relation, current_node_relation = relation.split(">", 1)
+                deps[i] = f"{ehead}:{current_node_relation}"
+                empty_tokens.append(
+                    {
+                        "id": ehead,
+                        "deps": f"{head}:{empty_node_relation}"
+                    }
+                )
+        deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
+        token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
+    return empty_tokens
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 28f8efa022c22cff7e6d7b2b0d2977ba50b39dc1..682e8859264a3414bf86d3a1e408ce5b3588a6f3 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -117,6 +117,8 @@ class AttachmentScores(metrics.Metric):
         mask : `torch.BoolTensor`, optional (default = None).
             A tensor of the same shape as `predicted_indices`.
         """
+        if gold_labels is None or gold_indices is None:
+            return
         detached = self.detach_tensors(
             predicted_indices, predicted_labels, gold_indices, gold_labels, mask
         )
@@ -138,10 +140,14 @@ class AttachmentScores(metrics.Metric):
 
         correct_indices = predicted_indices.eq(gold_indices).long() * mask
         unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
         correct_labels = predicted_labels.eq(gold_labels).long() * mask
         correct_labels_and_indices = correct_indices * correct_labels
         self.correct_indices = correct_labels_and_indices.flatten()
         labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
+        if len(correct_indices.size()) > 2:
+            labeled_exact_match = labeled_exact_match.prod(dim=-1)
 
         self._unlabeled_correct += correct_indices.sum()
         self._exact_unlabeled_correct += unlabeled_exact_match.sum()
@@ -198,6 +204,8 @@ class SemanticMetrics(metrics.Metric):
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.attachment_scores = AttachmentScores()
+        # Ignore PADDING and OOV
+        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
         self.em_score = 0.0
 
     def __call__(  # type: ignore
@@ -215,14 +223,25 @@ class SemanticMetrics(metrics.Metric):
                                gold_labels["head"],
                                gold_labels["deprel"],
                                mask)
+        self.enhanced_attachment_scores(predictions["enhanced_head"],
+                                        predictions["enhanced_deprel"],
+                                        gold_labels["enhanced_head"],
+                                        gold_labels["enhanced_deprel"],
+                                        mask=None)
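+        # Reduce per-edge correctness to per-token values (dropping the ROOT row and column)
+        # so it can be combined with the token-level scores below.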
+        enhanced_indices = (
+            self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(
+                -1).reshape(-1).bool()
+            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
+            else self.enhanced_attachment_scores.correct_indices
+        )
         total = mask.sum()
         correct_indices = (self.upos_score.correct_indices *
                            self.xpos_score.correct_indices *
                            self.semrel_score.correct_indices *
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
-                           self.attachment_scores.correct_indices
-                           )
+                           self.attachment_scores.correct_indices *
+                           enhanced_indices)
 
         total, correct_indices = self.detach_tensors(total, correct_indices)
         self.em_score = (correct_indices.float().sum() / total).item()
@@ -237,6 +256,8 @@ class SemanticMetrics(metrics.Metric):
             "EM": self.em_score
         }
         metrics_dict.update(self.attachment_scores.get_metric(reset))
+        enhanced_metrics = {f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
+        metrics_dict.update(enhanced_metrics)
         return metrics_dict
 
     def reset(self) -> None:
@@ -246,4 +267,5 @@ class SemanticMetrics(metrics.Metric):
         self.lemma_score.reset()
         self.feats_score.reset()
         self.attachment_scores.reset()
+        self.enhanced_attachment_scores.reset()
         self.em_score = 0.0
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..bc8c46580f17f22924a9f68628d64ce7f1060d55
--- /dev/null
+++ b/config.graph.template.jsonnet
@@ -0,0 +1,422 @@
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CoNLL-U format (or its extended version with the semantic relation field).
+# Can accept multiple paths concatenated with ',', e.g. "path1,path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Can accept multiple paths concatenated with ',', e.g. "path1,path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
+# Cuda device id, -1 for cpu, int
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
+# Minimum number of words in batch, int
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Word embedding dimension, int
+# If pretrained_tokens is not null, it must match the provided dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the models on top of the encoder use this dropout
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 32;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 32;
+# Feats embedding dimension, int
+# (discarded if feats not in features)
+local feats_dim = 32;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, int]
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+    semrel: 0.05,
+    enhanced_head: 0.2,
+    enhanced_deprel: 0.8,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer ones are truncated
+local word_length = 30;
+# Whether to use tensorboard, bool
+local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
+
+# Helper functions
+local in_features(name) = !(std.length(std.find(name, features)) == 0);
+local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
+
+########################################################################################
+#                              ADVANCED configuration                                  #
+########################################################################################
+
+# Detailed dataset, training, vocabulary and model configuration.
+{
+    # Configuration type (default or finetuning), str
+    type: std.extVar('type'),
+    # Datasets used for vocab creation, list of str
+    # Choice "train", "valid"
+    datasets_for_vocab_creation: ['train'],
+    # Path to training data, str
+    train_data_path: training_data_path,
+    # Path to validation data, str
+    validation_data_path: validation_data_path,
+    # Dataset reader configuration (conllu format)
+    dataset_reader: {
+        type: "conllu",
+        features: features,
+        targets: targets,
+        # Whether data contains semantic relation field, bool
+        use_sem: if in_targets("semrel") then true else false,
+        token_indexers: {
+            token: if use_transformer then {
+                type: "pretrained_transformer_mismatched_fixed",
+                model_name: pretrained_transformer_name,
+                tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                  then {use_fast: false} else {},
+            } else {
+                # SingleIdTokenIndexer, token as single int
+                type: "single_id",
+            },
+            upostag: {
+                type: "single_id",
+                namespace: "upostag",
+                feature_name: "pos_",
+            },
+            xpostag: {
+                type: "single_id",
+                namespace: "xpostag",
+                feature_name: "tag_",
+            },
+            lemma: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            char: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            feats: {
+                type: "feats_indexer",
+            },
+        },
+        lemma_indexers: {
+            char: {
+                type: "characters_const_padding",
+                namespace: "lemma_characters",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+        },
+    },
+    # Data loader configuration
+    data_loader: {
+        batch_sampler: {
+            type: "token_count",
+            word_batch_size: word_batch_size,
+        },
+    },
+    # Vocabulary configuration
+    vocabulary: std.prune({
+        type: "from_instances_extended",
+        only_include_pretrained_words: true,
+        pretrained_files: {
+            tokens: pretrained_tokens,
+        },
+        oov_token: "_",
+        padding_token: "__PAD__",
+        non_padded_namespaces: ["head_labels"],
+    }),
+    model: std.prune({
+        type: "semantic_multitask",
+        text_field_embedder: {
+            type: "basic",
+            token_embedders: {
+                xpostag: if in_features("xpostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: xpostag_dim,
+                    vocab_namespace: "xpostag",
+                },
+                upostag: if in_features("upostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: upostag_dim,
+                    vocab_namespace: "upostag",
+                },
+                token: if use_transformer then {
+                    type: "transformers_word_embeddings",
+                    model_name: pretrained_transformer_name,
+                    projection_dim: projected_embedding_dim,
+                    tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
+                                      then {use_fast: false} else {},
+                } else {
+                    type: "embeddings_projected",
+                    embedding_dim: embedding_dim,
+                    projection_layer: {
+                        in_features: embedding_dim,
+                        out_features: projected_embedding_dim,
+                        dropout_rate: 0.25,
+                        activation: "tanh"
+                    },
+                    vocab_namespace: "tokens",
+                    pretrained_file: pretrained_tokens,
+                    trainable: if pretrained_tokens == null then true else false,
+                },
+                char: {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: char_dim,
+                        filters: [512, 256, char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                lemma: if in_features("lemma") then {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: lemma_char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: lemma_char_dim,
+                        filters: [512, 256, lemma_char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                feats: if in_features("feats") then {
+                    type: "feats_embedding",
+                    padding_index: 0,
+                    embedding_dim: feats_dim,
+                    vocab_namespace: "feats",
+                },
+            },
+        },
+        loss_weights: loss_weights,
+        seq_encoder: {
+            type: "combo_encoder",
+            layer_dropout_probability: 0.33,
+            stacked_bilstm: {
+                input_size:
+                (char_dim + projected_embedding_dim +
+                (if in_features('xpostag') then xpostag_dim else 0) +
+                (if in_features('lemma') then lemma_char_dim else 0) +
+                (if in_features('upostag') then upostag_dim else 0) +
+                (if in_features('feats') then feats_dim else 0)),
+                hidden_size: hidden_size,
+                num_layers: num_layers,
+                recurrent_dropout_probability: 0.33,
+                layer_dropout_probability: 0.33
+            },
+        },
+        dependency_relation: {
+            type: "combo_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        enhanced_dependency_relation: if in_targets("deps") then {
+            type: "combo_graph_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        morphological_feat: if in_targets("feats") then {
+            type: "combo_morpho_from_vocab",
+            vocab_namespace: "feats_labels",
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+        },
+        lemmatizer: if in_targets("lemma") then {
+            type: "combo_lemma_predictor_from_vocab",
+            char_vocab_namespace: "token_characters",
+            lemma_vocab_namespace: "lemma_characters",
+            embedding_dim: 256,
+            input_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: 32,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            filters: [256, 256, 256],
+            kernel_size: [3, 3, 3, 1],
+            stride: [1, 1, 1, 1],
+            padding: [1, 2, 4, 0],
+            dilation: [1, 2, 4, 1],
+            activations: ["relu", "relu", "relu", "linear"],
+        },
+        upos_tagger: if in_targets("upostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "upostag_labels"
+        },
+        xpos_tagger: if in_targets("xpostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "xpostag_labels"
+        },
+        semantic_relation: if in_targets("semrel") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "semrel_labels"
+        },
+        regularizer: {
+            regexes: [
+                [".*conv1d.*", {type: "l2", alpha: 1e-6}],
+                [".*forward.*", {type: "l2", alpha: 1e-6}],
+                [".*backward.*", {type: "l2", alpha: 1e-6}],
+                [".*char_embed.*", {type: "l2", alpha: 1e-5}],
+            ],
+        },
+    }),
+    trainer: std.prune({
+        checkpointer: {
+            type: "finishing_only_checkpointer",
+        },
+        type: "gradient_descent_validate_n",
+        cuda_device: cuda_device,
+        grad_clipping: 5.0,
+        num_epochs: num_epochs,
+        optimizer: {
+            type: "adam",
+            lr: learning_rate,
+            betas: [0.9, 0.9],
+        },
+        patience: 1, # it will be overwritten by the callback
+        epoch_callbacks: [
+            { type: "transfer_patience" },
+        ],
+        learning_rate_scheduler: {
+            type: "combo_scheduler",
+        },
+        tensorboard_writer: if use_tensorboard then {
+            should_log_learning_rate: false,
+            should_log_parameter_statistics: false,
+            summary_interval: 100,
+        },
+        validation_metric: "+EM",
+    }),
+}
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 8e5ddc9f3d120156a4d00b1d231a54d90a66b631..f41ba62672eb4f93e130261ac85a5abc00e1efee 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -71,8 +71,6 @@ local cycle_loss_n = 0;
 local word_length = 30;
 # Whether to use tensorboard, bool
 local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -382,7 +380,6 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             type: "combo_scheduler",
         },
         tensorboard_writer: if use_tensorboard then {
-            serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
diff --git a/docs/models.md b/docs/models.md
index 485f7614cc0b65c6413f88f0139a7dd7cd8a1711..25a7f7092ef295a0cf1ff2b3ba13e0adb05a5bc6 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -1,19 +1,26 @@
 # Models
 
-Pre-trained models are available [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
+COMBO provides pre-trained models for:
+- morphosyntactic prediction (i.e. part-of-speech tagging, morphosyntactic analysis, lemmatisation and dependency parsing), trained on treebanks from the [Universal Dependencies repository](https://universaldependencies.org),
+- enhanced dependency parsing trained on IWPT 2020 shared task [data](https://universaldependencies.org/iwpt20/data.html).
 
-## Automatic download
-Python `from_pretrained` method will download the pre-trained model if the provided name (without the extension .tar.gz) matches one of the names in [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
-```python
-import combo.predict as predict
+## Manual download
 
-nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base")
-```
-Otherwise it looks for a model in local env.
+The pre-trained models can be downloaded from [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/).
 
-## Console prediction/Local model
-If you want to use the console version of COMBO, you need to download a pre-trained model manually
+
+If you want to use the console version of COMBO, you need to download a pre-trained model manually:
 ```bash
 wget http://mozart.ipipan.waw.pl/~mklimaszewski/models/polish-herbert-base.tar.gz
 ```
-and pass it as a parameter (see [prediction doc](prediction.md)).
+
+The downloaded model should be passed as a parameter to COMBO (see the [prediction doc](prediction.md)).
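+
+For example, the console invocation might look like the following sketch (file names are illustrative; see the [prediction doc](prediction.md) for the exact flags):
+```bash
+combo --mode predict \
+      --model_path polish-herbert-base.tar.gz \
+      --input_file input.conllu \
+      --output_file output.conllu
+```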
+
+## Automatic download
+The pre-trained models can be downloaded automatically with the Python `from_pretrained` method. Select a model name (without the .tar.gz extension) from the list of [pre-trained models](http://mozart.ipipan.waw.pl/~mklimaszewski/models/) and pass the name as the argument to the `from_pretrained` method:
+```python
+from combo.predict import COMBO
+
+nlp = COMBO.from_pretrained("polish-herbert-base")
+```
+If the model name doesn't match any model on the list of [pre-trained models](http://mozart.ipipan.waw.pl/~mklimaszewski/models/), COMBO looks for a model in the local environment.
diff --git a/docs/prediction.md b/docs/prediction.md
index 89cc74c27e8de8e4fafb44c60aea8ed260b67a3d..6de5d0e1892389ba5cd18c25b88947db3f717074 100644
--- a/docs/prediction.md
+++ b/docs/prediction.md
@@ -32,9 +32,19 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name
 
 ## Python
 ```python
-import combo.predict as predict
+from combo.predict import COMBO
 
 model_path = "your_model.tar.gz"
-nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
+nlp = COMBO.from_pretrained(model_path)
 sentence = nlp("Sentence to parse.")
 ```
+
+Using your own tokenization:
+```python
+from combo.predict import COMBO
+
+model_path = "your_model.tar.gz"
+nlp = COMBO.from_pretrained(model_path)
+tokenized_sentence = ["Sentence", "to", "parse", "."]
+sentence = nlp([tokenized_sentence])
+```
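+
+The returned `Sentence` object can also be serialized, e.g. with the `Sentence.to_json` helper (a minimal sketch; note that passing a list of pre-tokenized sentences returns a list of `Sentence` objects):
+```python
+from combo.predict import COMBO
+
+model_path = "your_model.tar.gz"
+nlp = COMBO.from_pretrained(model_path)
+sentence = nlp("Sentence to parse.")
+print(sentence.to_json())
+```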
diff --git a/docs/training.md b/docs/training.md
index 9dc430a782baacb7344e95d29a7bf9066b1df613..d3f69e0913c59681279b1fd966be0f4901ade11e 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -1,6 +1,6 @@
 # Training
 
-Command:
+Basic command:
 ```bash
 combo --mode train \
       --training_data_path your_training_path \
@@ -32,18 +32,38 @@ Examples (for clarity without training/validation data paths):
     combo --mode train --pretrained_transformer_name your_choosen_pretrained_transformer
     ```
 
-* predict only dependency tree:
+* train only a dependency parser:
 
     ```bash
     combo --mode train --targets head,deprel
     ```
 
-* use part-of-speech tags for predicting only dependency tree
+* use additional features (e.g. part-of-speech tags) for training a dependency parser (`token` and `char` are default features)
 
     ```bash
     combo --mode train --targets head,deprel --features token,char,upostag
     ```
-  
+
+## Enhanced UD
+
+Training a model with Enhanced UD prediction **requires** data pre-processing.
+
+```bash
+combo --mode train \
+      --training_data_path your_preprocessed_training_path \
+      --validation_data_path your_preprocessed_validation_path \
+      --targets feats,upostag,xpostag,head,deprel,lemma,deps \
+      --config_path config.graph.template.jsonnet
+```
+### Data pre-processing
+Download the data from the [IWPT 2020 Shared Task](https://universaldependencies.org/iwpt20/data.html).
+It contains the `enhanced_collapse_empty_nodes.pl` script, which is required as a pre-processing step.
+Apply this script to both the training and validation data, as shown below.
+
+```bash
+perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
+``` 
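+
+The validation file is processed in the same way, for example (file names are illustrative):
+```bash
+perl enhanced_collapse_empty_nodes.pl validation.conllu > validation.fixed.conllu
+```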
+
 ## Configuration
 
 ### Advanced
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..939088800f772c113693eb6d0858304ed82f766d
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,172 @@
+"""Script to train Dependency Parsing models based on UD 2.x data."""
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+TREEBANKS = [
+    "UD_Afrikaans-AfriBooms",
+    "UD_Arabic-NYUAD",
+    "UD_Arabic-PADT",
+    "UD_Armenian-ArmTDP",
+    "UD_Basque-BDT",
+    "UD_Belarusian-HSE",
+    "UD_Breton-KEB",
+    "UD_Bulgarian-BTB",
+    "UD_Catalan-AnCora",
+    "UD_Croatian-SET",
+    "UD_Czech-CAC",
+    "UD_Czech-CLTT",
+    "UD_Czech-FicTree",
+    "UD_Czech-PDT",
+    "UD_Danish-DDT",
+    "UD_Dutch-Alpino",
+    "UD_Dutch-LassySmall",
+    "UD_English-ESL",
+    "UD_English-EWT",
+    "UD_English-GUM",
+    "UD_English-LinES",
+    "UD_English-ParTUT",
+    "UD_English-Pronouns",
+    "UD_Estonian-EDT",
+    "UD_Estonian-EWT",
+    "UD_Finnish-FTB",
+    "UD_Finnish-TDT",
+    "UD_French-FQB",
+    "UD_French-FTB",
+    "UD_French-GSD",
+    "UD_French-ParTUT",
+    "UD_French-Sequoia",
+    "UD_French-Spoken",
+    "UD_Galician-CTG",
+    "UD_Galician-TreeGal",
+    "UD_German-GSD",
+    "UD_German-HDT",
+    "UD_German-LIT",
+    "UD_Greek-GDT",
+    "UD_Hebrew-HTB",
+    "UD_Hindi_English-HIENCS",
+    "UD_Hindi-HDTB",
+    "UD_Hungarian-Szeged",
+    "UD_Indonesian-GSD",
+    "UD_Irish-IDT",
+    "UD_Italian-ISDT",
+    "UD_Italian-ParTUT",
+    "UD_Italian-PoSTWITA",
+    "UD_Italian-TWITTIRO",
+    "UD_Italian-VIT",
+    "UD_Japanese-BCCWJ",
+    "UD_Japanese-GSD",
+    "UD_Japanese-Modern",
+    "UD_Kazakh-KTB",
+    "UD_Korean-GSD",
+    "UD_Korean-Kaist",
+    "UD_Latin-ITTB",
+    "UD_Latin-Perseus",
+    "UD_Latin-PROIEL",
+    "UD_Latvian-LVTB",
+    "UD_Lithuanian-ALKSNIS",
+    "UD_Lithuanian-HSE",
+    "UD_Maltese-MUDT",
+    "UD_Marathi-UFAL",
+    "UD_Persian-Seraji",
+    "UD_Polish-LFG",
+    "UD_Polish-PDB",
+    "UD_Portuguese-Bosque",
+    "UD_Portuguese-GSD",
+    "UD_Romanian-Nonstandard",
+    "UD_Romanian-RRT",
+    "UD_Romanian-SiMoNERo",
+    "UD_Russian-GSD",
+    "UD_Russian-SynTagRus",
+    "UD_Russian-Taiga",
+    "UD_Serbian-SET",
+    "UD_Slovak-SNK",
+    "UD_Slovenian-SSJ",
+    "UD_Slovenian-SST",
+    "UD_Spanish-AnCora",
+    "UD_Spanish-GSD",
+    "UD_Swedish-LinES",
+    "UD_Swedish_Sign_Language-SSLC",
+    "UD_Swedish-Talbanken",
+    "UD_Tamil-TTB",
+    "UD_Telugu-MTG",
+    "UD_Turkish-GB",
+    "UD_Turkish-IMST",
+    "UD_Ukrainian-IU",
+    "UD_Urdu-UDTB",
+    "UD_Uyghur-UDT",
+    "UD_Vietnamese-VTB",
+]
+
+FLAGS = flags.FLAGS
+flags.DEFINE_list(name="treebanks", default=TREEBANKS,
+                  help=f"Treebanks to train. Possible values: {TREEBANKS}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to UD data directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization directory.")
+flags.DEFINE_string(name="embeddings_dir", default="",
+                    help="Path to embeddings directory (with languages as subdirectories).")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def run(_):
+    treebanks_dir = pathlib.Path(FLAGS.data_dir)
+    for treebank in FLAGS.treebanks:
+        assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
+        treebank_dir = treebanks_dir / treebank
+        treebank_parts = treebank.split("_")[1].split("-")
+        language = treebank_parts[0]
+
+        files = list(treebank_dir.iterdir())
+
+        training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
+        assert len(training_file) == 1, f"Couldn't find the training file in {treebank_dir}."
+        training_file_path = training_file[0]
+
+        valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
+        assert len(valid_file) == 1, f"Couldn't find the validation file in {treebank_dir}."
+        valid_file_path = valid_file[0]
+
+        embeddings_dir = FLAGS.embeddings_dir
+        embeddings_file = None
+        if embeddings_dir:
+            embeddings_dir = pathlib.Path(embeddings_dir) / language
+            embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name]
+            assert len(embeddings_file) == 1, f"Couldn't find the embeddings file in {embeddings_dir}."
+            embeddings_file = embeddings_file[0]
+
+        language = training_file_path.name.split("_")[0]
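+        # Override the treebank-derived language with the ISO code from the CoNLL-U file name
+        # (e.g. "pl_pdb-ud-train.conllu" -> "pl"); utils.LANG2TRANSFORMER is keyed by such codes.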
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+
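+        # Use static pretrained word embeddings when an embeddings directory is given,
+        # otherwise fall back to the pretrained transformer configured for the language.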
+        command = f"""time combo --mode train
+        --cuda_device {FLAGS.cuda_device}
+        --training_data_path {training_file_path}
+        --validation_data_path {valid_file_path}
+        {f"--pretrained_tokens {embeddings_file}" if embeddings_dir
+        else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
+        --serialization_dir {serialization_dir}
+        --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
+        --word_batch_size 2500
+        --notensorboard
+        """
+
+        # no XPOS datasets
+        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
+            command = command + " --targets deprel,head,upostag,lemma,feats"
+
+        utils.execute_command(command)
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_eud.py b/scripts/train_eud.py
new file mode 100644
index 0000000000000000000000000000000000000000..4904e0bff6a9d78a7d0a56bff2ed0357992b615b
--- /dev/null
+++ b/scripts/train_eud.py
@@ -0,0 +1,126 @@
+"""Script to train Enhanced Dependency Parsing models based on IWPT'20 Shared Task data.
+
+Might require:
+conda install -c bioconda perl-list-moreutils
+conda install -c bioconda perl-namespace-autoclean
+conda install -c bioconda perl-moose
+conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
+"""
+
+import os
+import pathlib
+from typing import List
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+LANG2TREEBANK = {
+    "ar": ["Arabic-PADT"],
+    "bg": ["Bulgarian-BTB"],
+    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
+    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
+    "en": ["English-EWT", "English-PUD"],
+    "et": ["Estonian-EDT", "Estonian-EWT"],
+    "fi": ["Finnish-TDT", "Finnish-PUD"],
+    "fr": ["French-Sequoia", "French-FQB"],
+    "it": ["Italian-ISDT"],
+    "lv": ["Latvian-LVTB"],
+    "lt": ["Lithuanian-ALKSNIS"],
+    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
+    "ru": ["Russian-SynTagRus"],
+    "sk": ["Slovak-SNK"],
+    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
+    "ta": ["Tamil-TTB"],
+    "uk": ["Ukrainian-IU"],
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
+                  help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to 'iwpt2020stdata' directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def merge_files(files: List[str], output: pathlib.Path):
+    if not output.exists():
+        os.system(f"cat {' '.join(files)} > {output}")
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                              f"{path_to_str(treebank_file)}", output)
+
+
+def run(_):
+    languages = FLAGS.lang
+    for lang in languages:
+        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
+        data_dir = pathlib.Path(FLAGS.data_dir)
+        assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
+
+        treebanks = LANG2TREEBANK[lang]
+        train_paths = []
+        dev_paths = []
+        test_paths = []
+        for treebank in treebanks:
+            treebank_dir = data_dir / f"UD_{treebank}"
+            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exist."
+            for treebank_file in treebank_dir.iterdir():
+                name = treebank_file.name
+                if "conllu" in name and "fixed" not in name:
+                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    if "train" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        train_paths.append(output)
+                    elif "dev" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        dev_paths.append(output)
+                    elif "test" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        test_paths.append(output)
+
+        lang_data_dir = pathlib.Path(data_dir / lang)
+        lang_data_dir.mkdir(exist_ok=True)
+
+        train_path = lang_data_dir / "train.conllu"
+        dev_path = lang_data_dir / "dev.conllu"
+        test_path = lang_data_dir / "test.conllu"
+
+        merge_files(train_paths, output=train_path)
+        merge_files(dev_paths, output=dev_path)
+        merge_files(test_paths, output=test_path)
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
+        serialization_dir.mkdir(exist_ok=True, parents=True)
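+        # The multi-line command template is flattened with splitlines()/join; execute_command
+        # later splits it back into arguments on whitespace.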
+        utils.execute_command("".join(f"""combo --mode train
+        --training_data_path {train_path}
+        --validation_data_path {dev_path}
+        --targets feats,upostag,xpostag,head,deprel,lemma,deps
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
+        --serialization_dir {serialization_dir}
+        --cuda_device {FLAGS.cuda_device}
+        --word_batch_size 2500
+        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --notensorboard
+        """.splitlines()))
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dda2b89693fc1431b810709d9e7002d9f5f8071
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,16 @@
+"""Utils for scripts."""
+import subprocess
+
+LANG2TRANSFORMER = {
+    "en": "bert-base-cased",
+    "pl": "allegro/herbert-base-cased",
+}
+
+
+def execute_command(command, output_file=None):
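+    """Run a command given as a whitespace-separated string, optionally redirecting stdout to output_file."""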
+    command = [c for c in command.split() if c.strip()]
+    if output_file:
+        with open(output_file, "w") as f:
+            subprocess.run(command, check=True, stdout=f)
+    else:
+        subprocess.run(command, check=True)
diff --git a/setup.cfg b/setup.cfg
index b7e478982ccf9ab1963c74e1084dfccb6e42c583..6876d0d7447015400e616dbd7479de01d19c2948 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
 [aliases]
 test=pytest
+
+[metadata]
+description-file = README.md
diff --git a/setup.py b/setup.py
index 9529a0c4f5f2138cf913d94d8665a0810552cdd9..fdaa2be51e5098e9aa2d2ddc3b691188f11c7b10 100644
--- a/setup.py
+++ b/setup.py
@@ -3,10 +3,9 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.2.0',
+    'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
-    'dataclasses-json==0.5.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
     'requests==2.23.0',
@@ -20,11 +19,23 @@ REQUIREMENTS = [
 
 setup(
     name='COMBO',
-    version='0.0.1',
+    version='1.0.0b1',
     install_requires=REQUIREMENTS,
     packages=find_packages(exclude=['tests']),
+    license="GPL-3.0",
+    url='https://gitlab.clarin-pl.eu/syntactic-tools/combo',
+    keywords="nlp natural-language-processing dependency-parsing",
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
     python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+    ]
 )
diff --git a/tests/data/fields/test_sequence_multilabel_field.py b/tests/data/fields/test_sequence_multilabel_field.py
index d2a1f8bc6b4d853e427d6a7214592f1b881db52c..fff8ff4ee285215aa4015d1c55609f3c2d28a3a6 100644
--- a/tests/data/fields/test_sequence_multilabel_field.py
+++ b/tests/data/fields/test_sequence_multilabel_field.py
@@ -4,6 +4,7 @@ from typing import List
 
 import torch
 from allennlp import data as allen_data
+from allennlp.common import util
 from allennlp.data import fields as allen_fields
 
 from combo.data import fields
@@ -22,7 +23,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         def _indexer(vocab: allen_data.Vocabulary):
             vocab_size = vocab.get_vocab_size(self.namespace)
 
-            def _mapper(multi_label: List[str]) -> List[int]:
+            def _mapper(multi_label: List[str], _: int) -> List[int]:
                 one_hot = [0] * vocab_size
                 for label in multi_label:
                     index = vocab.get_token_index(label, self.namespace)
@@ -31,7 +32,21 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
 
             return _mapper
 
+        def _as_tensor(field: fields.SequenceMultiLabelField):
+
+            def _wrapped(padding_lengths):
+                desired_num_tokens = padding_lengths["num_tokens"]
+                classes_count = len(field._indexed_multi_labels[0])
+                default_value = [0.0] * classes_count
+                padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                          lambda: default_value)
+                tensor = torch.LongTensor(padded_tags)
+                return tensor
+
+            return _wrapped
+
         self.indexer = _indexer
+        self.as_tensor = _as_tensor
         self.sequence_field = _SequenceFieldTestWrapper(self.vocab.get_vocab_size(self.namespace))
 
     def test_indexing(self):
@@ -39,6 +54,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
@@ -55,6 +71,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
@@ -72,6 +89,7 @@ class IndexingSequenceMultiLabelFieldTest(unittest.TestCase):
         field = fields.SequenceMultiLabelField(
             multi_labels=[["t1", "t2"], [], ["t0"]],
             multi_label_indexer=self.indexer,
+            as_tensor=self.as_tensor,
             sequence_field=self.sequence_field,
             label_namespace=self.namespace
         )
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 1125392e17d71f9db09c5236e1fb27ac0968d410..32e0653525e1160135fbe6f05dc2fa07b6a7fd9d 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -4,3 +4,10 @@
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
 
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	2:mod	_
+4	.	.	PUNCT	.	_	1	punct	2:xmod	_
+
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 2a56bd9baff30a6d34d3ad6bb16ce4ccdea71792..332ced3cfa010723fa51b77ce1742b8d976e1025 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -22,7 +22,7 @@ class PredictionTest(unittest.TestCase):
             data.Token(id=2, token=".")
         ])]
         api_wrapped_tokenized_sentence = [data.conllu2sentence(data.tokens2conllu(["Test", "."]), [])]
-        nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
+        nlp = predict.COMBO.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
 
         # when
         results = [
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e37446684f68c6d4ea4abe77c69ba9d3ae4c2b
--- /dev/null
+++ b/tests/utils/test_graph.py
@@ -0,0 +1,106 @@
+import unittest
+import combo.utils.graph as graph
+
+import conllu
+import numpy as np
+
+
+class GraphTest(unittest.TestCase):
+
+    def test_adding_empty_graph_with_the_same_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "root", "form": "word1"},
+                {"head": 3, "deprel": "yes", "form": "word2"},
+                {"head": 1, "deprel": "yes", "form": "word3"},
+            ]
+        )
+        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.zeros((4, 4, 4))
+        expected_deps = ["0:root", "3:yes", "1:yes"]
+
+        # when
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(expected_deps, actual_deps)
+
+    def test_adding_empty_graph_with_different_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "root", "form": "word1"},
+                {"head": 3, "deprel": "tree_label", "form": "word2"},
+                {"head": 1, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[2][3][2] = 10e10
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]
+
+        # when
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_extending_tree_with_graph(self):
+        # given
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "root", "form": "word1"},
+                {"head": 1, "deprel": "tree_label", "form": "word2"},
+                {"head": 2, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
+        arc_scores = np.array([
+            [0, 0, 0, 0],
+            [1, 0, 0, 0],
+            [0, 1, 0, 0],
+            [0, 1, 1, 0],
+        ])
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
+
+        # when
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
+        # given
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "root", "form": "word1"},
+                {"head": 1, "deprel": "tree_label", "form": "word2"},
+                {"head": 2, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
+        arc_scores = np.array([
+            [0, 0, 0, 0],
+            [1, 0, 0, 0],
+            [0, 1, 0, 0],
+            [0, 0, 1, 1],
+        ])
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][3][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
+        # TODO current actual, adds self-loop
+        # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
+
+        # when
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(expected_deps, actual_deps)
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 5b8411bb44c4df7baf7ee1c8d751ac06daa320ec..242eaa3dcffaf452c19a52e2f56625de00cd0433 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -27,12 +27,16 @@ class SemanticMetricsTest(unittest.TestCase):
         self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
         self.head, self.head_l = (("head", x) for x in [pred, gold])
         self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
+        # TODO(mklimasz) Set up an example with size 3x5x5
+        self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
+        self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
         self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
         self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq])
         self.predictions = dict(
-            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel])
+            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel,
+             self.enhanced_head, self.enhanced_deprel])
         self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l,
-                                 self.deprel_l])
+                                 self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l])
         self.eps = 1e-6
 
     def test_every_prediction_correct(self):