diff --git a/README.md b/README.md
index 732262ef48f2420dd2a9dde66de5ffab71ef37b0..76e436b91c7a402e1bb15f140c9ca9ba3a63a200 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@
 Clone this repository and install COMBO (we suggest creating a virtualenv/conda environment with Python 3.6+, as a bundle of required packages will be installed):
 ```bash
 pip install -U pip setuptools wheel
-pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3
+pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4
 ```
 Run the following commands in your Python console to make predictions with a pre-trained model:
 ```python
@@ -39,3 +39,37 @@ We encourage you to use the [beginner's tutorial](https://colab.research.google.
 - [**Training**](docs/training.md)
 - [**Prediction**](docs/prediction.md)
 - [**Model performance**](docs/performance.md)
+
+## Citing
+
+### Accepted at EMNLP'21 demo session :tada: :fire:
+
+If you use COMBO in your research, please cite [COMBO: State-of-the-Art Morphosyntactic Analysis](https://arxiv.org/abs/2109.05361)
+```bibtex
+@misc{klimaszewski2021combo,
+      title={COMBO: State-of-the-Art Morphosyntactic Analysis}, 
+      author={Mateusz Klimaszewski and Alina Wróblewska},
+      year={2021},
+      eprint={2109.05361},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```
+
+If you use an EUD module in your research, please cite [COMBO: A New Module for EUD Parsing](https://aclanthology.org/2021.iwpt-1.16/)
+```bibtex
+@inproceedings{klimaszewski-wroblewska-2021-combo,
+    title = "{COMBO}: A New Module for {EUD} Parsing",
+    author = "Klimaszewski, Mateusz  and
+      Wr{\'o}blewska, Alina",
+    booktitle = "Proceedings of the 17th International Conference on Parsing Technologies and the IWPT 2021 Shared Task on Parsing into Enhanced Universal Dependencies (IWPT 2021)",
+    month = aug,
+    year = "2021",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2021.iwpt-1.16",
+    doi = "10.18653/v1/2021.iwpt-1.16",
+    pages = "158--166",
+    abstract = "We introduce the COMBO-based approach for EUD parsing and its implementation, which took part in the IWPT 2021 EUD shared task. The goal of this task is to parse raw texts in 17 languages into Enhanced Universal Dependencies (EUD). The proposed approach uses COMBO to predict UD trees and EUD graphs. These structures are then merged into the final EUD graphs. Some EUD edge labels are extended with case information using a single language-independent expansion rule. In the official evaluation, the solution ranked fourth, achieving an average ELAS of 83.79{\%}. The source code is available at https://gitlab.clarin-pl.eu/syntactic-tools/combo.",
+}
+```
diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index c0c469674f4a74a6d46953e65d9715e9eabf0a2f..a4725606e3cbc1917c46966ee3bf38833de969bc 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -49,7 +49,7 @@ local lemma_char_dim = 64;
 # Character embedding dim, int
 local char_dim = 64;
 # Word embedding projection dim, int
-local projected_embedding_dim = 100;
+local projected_embedding_dim = 768;
 # Loss weights, dict[str, int]
 local loss_weights = {
     xpostag: 0.05,
@@ -112,10 +112,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         use_sem: if in_targets("semrel") then true else false,
         token_indexers: {
             token: if use_transformer then {
-                type: "pretrained_transformer_mismatched_fixed",
-                model_name: pretrained_transformer_name,
-                tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
-                                  then {use_fast: false} else {},
+                type: "pretrained_transformer_mismatched",
+                model_name: pretrained_transformer_name
             } else {
                 # SingleIdTokenIndexer, token as single int
                 type: "single_id",
@@ -202,10 +200,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 },
                 token: if use_transformer then {
                     type: "transformers_word_embeddings",
+                    last_layer_only: false,
                     model_name: pretrained_transformer_name,
-                    projection_dim: projected_embedding_dim,
-                    tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
-                                      then {use_fast: false} else {},
+                    projection_dim: projected_embedding_dim
                 } else {
                     type: "embeddings_projected",
                     embedding_dim: embedding_dim,
@@ -303,7 +300,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         enhanced_dependency_relation: if in_targets("deps") then {
             type: "combo_graph_dependency_parsing_from_vocab",
-            vocab_namespace: 'deprel_labels',
+            vocab_namespace: 'enhanced_deprel_labels',
             head_predictor: {
                 local projection_dim = 512,
                 cycle_loss_n: cycle_loss_n,
diff --git a/combo/data/api.py b/combo/data/api.py
index 4ab7f1a33de77ccf4b17c284777200e39c668081..308e9e4a6e5daabe056f2e6a84ae17ed1dd8fa16 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -21,12 +21,13 @@ class Token:
     deps: Optional[str] = None
     misc: Optional[str] = None
     semrel: Optional[str] = None
+    embeddings: Dict[str, List[float]] = field(default_factory=list, repr=False)
 
 
 @dataclass
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
-    sentence_embedding: List[float] = field(default_factory=list)
+    sentence_embedding: List[float] = field(default_factory=list, repr=False)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
     def to_json(self):
@@ -54,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
         # Remove semrel to have default conllu format.
         if not keep_semrel:
             del token_dict["semrel"]
+        del token_dict["embeddings"]
         tokens.append(token_dict)
     # Range tokens must be tuple not list, this is conllu library requirement
     for t in tokens:
@@ -77,14 +79,16 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
 
 
 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding=None) -> Sentence:
+                    sentence_embedding=None, embeddings=None) -> Sentence:
+    if embeddings is None:
+        embeddings = {}
     if sentence_embedding is None:
         sentence_embedding = []
     tokens = []
     for token in conllu_sentence.tokens:
         tokens.append(
             Token(
-                **token
+                **token, embeddings=embeddings[token["id"]]
             )
         )
     return Sentence(
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 870659fe27857d157e5061f77690f61f166330a7..bdc8b20ea42ef9cd25d757cde2829d5e3327efab 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                             sequence_field=text_field_deps,
                             labels=enhanced_deprels,
                             # Label namespace matches regular tree parsing.
-                            label_namespace="deprel_labels",
+                            label_namespace="enhanced_deprel_labels",
                             padding_value=0,
                         )
                     else:
@@ -204,7 +204,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
             default_value = [0.0] * classes_count
             padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
                                                       lambda: default_value)
-            tensor = torch.LongTensor(padded_tags)
+            tensor = torch.tensor(padded_tags, dtype=torch.long)
             return tensor
 
         return as_tensor
diff --git a/combo/models/base.py b/combo/models/base.py
index a5cb5fe61f85a98f78d143a54695d01948aa8dda..234fbcaaf739b84fe33bb8633411a8aeb0276b5a 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -1,11 +1,10 @@
-from typing import Dict, Optional, List, Union
+from typing import Dict, Optional, List, Union, Tuple
 
 import torch
 import torch.nn as nn
 from allennlp import common, data
 from allennlp import nn as allen_nn
 from allennlp.common import checks
-from allennlp.modules import feedforward
 from allennlp.nn import Activation
 
 from combo.models import utils
@@ -51,7 +50,7 @@ class Linear(nn.Linear, common.FromParams):
 class FeedForwardPredictor(Predictor):
     """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
 
-    def __init__(self, feedforward_network: feedforward.FeedForward):
+    def __init__(self, feedforward_network: "FeedForward"):
         super().__init__()
         self.feedforward_network = feedforward_network
 
@@ -63,10 +62,11 @@ class FeedForwardPredictor(Predictor):
         if mask is None:
             mask = x.new_ones(x.size()[:-1])
 
-        x = self.feedforward_network(x)
+        x, feature_maps = self.feedforward_network(x)
         output = {
             "prediction": x.argmax(-1),
-            "probability": x
+            "probability": x,
+            "embedding": feature_maps[-1],
         }
 
         if labels is not None:
@@ -109,9 +109,112 @@ class FeedForwardPredictor(Predictor):
             f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
         hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
 
-        return cls(feedforward.FeedForward(
+        return cls(FeedForward(
             input_dim=input_dim,
             num_layers=num_layers,
             hidden_dims=hidden_dims,
             activations=activations,
             dropout=dropout))
+
+
+class FeedForward(torch.nn.Module, common.FromParams):
+    """
+    Modified copy of allennlp.modules.feedforward.FeedForward
+
+    This `Module` is a feed-forward neural network, just a sequence of `Linear` layers with
+    activation functions in between.
+
+    # Parameters
+
+    input_dim : `int`, required
+        The dimensionality of the input.  We assume the input has shape `(batch_size, input_dim)`.
+    num_layers : `int`, required
+        The number of `Linear` layers to apply to the input.
+    hidden_dims : `Union[int, List[int]]`, required
+        The output dimension of each of the `Linear` layers.  If this is a single `int`, we use
+        it for all `Linear` layers.  If it is a `List[int]`, `len(hidden_dims)` must be
+        `num_layers`.
+    activations : `Union[Activation, List[Activation]]`, required
+        The activation function to use after each `Linear` layer.  If this is a single function,
+        we use it after all `Linear` layers.  If it is a `List[Activation]`,
+        `len(activations)` must be `num_layers`. Activation must have torch.nn.Module type.
+    dropout : `Union[float, List[float]]`, optional (default = `0.0`)
+        If given, we will apply this amount of dropout after each layer.  Semantics of `float`
+        versus `List[float]` is the same as with other parameters.
+
+    # Examples
+
+    ```python
+    FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
+    #> FeedForward(
+    #>   (_activations): ModuleList(
+    #>     (0): ReLU()
+    #>     (1): ReLU()
+    #>   )
+    #>   (_linear_layers): ModuleList(
+    #>     (0): Linear(in_features=124, out_features=64, bias=True)
+    #>     (1): Linear(in_features=64, out_features=32, bias=True)
+    #>   )
+    #>   (_dropout): ModuleList(
+    #>     (0): Dropout(p=0.2, inplace=False)
+    #>     (1): Dropout(p=0.2, inplace=False)
+    #>   )
+    #> )
+    ```
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_layers: int,
+        hidden_dims: Union[int, List[int]],
+        activations: Union[Activation, List[Activation]],
+        dropout: Union[float, List[float]] = 0.0,
+    ) -> None:
+
+        super().__init__()
+        if not isinstance(hidden_dims, list):
+            hidden_dims = [hidden_dims] * num_layers  # type: ignore
+        if not isinstance(activations, list):
+            activations = [activations] * num_layers  # type: ignore
+        if not isinstance(dropout, list):
+            dropout = [dropout] * num_layers  # type: ignore
+        if len(hidden_dims) != num_layers:
+            raise checks.ConfigurationError(
+                "len(hidden_dims) (%d) != num_layers (%d)" % (len(hidden_dims), num_layers)
+            )
+        if len(activations) != num_layers:
+            raise checks.ConfigurationError(
+                "len(activations) (%d) != num_layers (%d)" % (len(activations), num_layers)
+            )
+        if len(dropout) != num_layers:
+            raise checks.ConfigurationError(
+                "len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers)
+            )
+        self._activations = torch.nn.ModuleList(activations)
+        input_dims = [input_dim] + hidden_dims[:-1]
+        linear_layers = []
+        for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
+            linear_layers.append(torch.nn.Linear(layer_input_dim, layer_output_dim))
+        self._linear_layers = torch.nn.ModuleList(linear_layers)
+        dropout_layers = [torch.nn.Dropout(p=value) for value in dropout]
+        self._dropout = torch.nn.ModuleList(dropout_layers)
+        self._output_dim = hidden_dims[-1]
+        self.input_dim = input_dim
+
+    def get_output_dim(self):
+        return self._output_dim
+
+    def get_input_dim(self):
+        return self.input_dim
+
+    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+
+        output = inputs
+        feature_maps = []
+        for layer, activation, dropout in zip(
+            self._linear_layers, self._activations, self._dropout
+        ):
+            feature_maps.append(output)
+            output = dropout(activation(layer(output)))
+        return output, feature_maps
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index d8e9d7a28a7fa36b60d108a30a3286026a327e51..c8499bafd6da01d063210431967d2599c079da3f 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -105,16 +105,20 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     Tested with Bert (but should work for other models as well).
     """
 
+    authorized_missing_keys = [r"position_ids$"]
+
     def __init__(self,
                  model_name: str,
                  projection_dim: int = 0,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
+                 last_layer_only: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
         super().__init__(model_name,
                          train_parameters=not freeze_transformer,
+                         last_layer_only=last_layer_only,
                          tokenizer_kwargs=tokenizer_kwargs,
                          transformer_kwargs=transformer_kwargs)
         if projection_dim:
diff --git a/combo/models/model.py b/combo/models/model.py
index 9866bcb4fba41ed2506b2d33290e6cd0fe237d29..c648453db5ab951ed68573cf79094f1bf21286d2 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -129,6 +129,11 @@ class ComboModel(allen_models.Model):
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
             "sentence_embedding": torch.max(encoder_emb, dim=1)[0],
+            "upostag_token_embedding": upos_output["embedding"],
+            "xpostag_token_embedding": xpos_output["embedding"],
+            "semrel_token_embedding": semrel_output["embedding"],
+            "feats_token_embedding": morpho_output["embedding"],
+            "deprel_token_embedding": parser_output["embedding"],
         }
 
         if "rel_probability" in enhanced_parser_output:
@@ -196,8 +201,8 @@ class ComboModel(allen_models.Model):
         if callable_model:
             return callable_model(*args, **kwargs)
         if returns_tuple:
-            return {"prediction": (None, None), "loss": (None, None)}
-        return {"prediction": None, "loss": None}
+            return {"prediction": (None, None), "loss": (None, None), "embedding": (None, None)}
+        return {"prediction": None, "loss": None, "embedding": None}
 
     @staticmethod
     def _clean(output):
diff --git a/combo/models/morpho.py b/combo/models/morpho.py
index ea3451dcd0a1e5656dff718e9988ef7ed4406500..b0d307932937da398e8e265b622190422b00a15a 100644
--- a/combo/models/morpho.py
+++ b/combo/models/morpho.py
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Union
 import torch
 from allennlp import data
 from allennlp.common import checks
-from allennlp.modules import feedforward
 from allennlp.nn import Activation
 
 from combo.data import dataset
@@ -15,7 +14,7 @@ from combo.models import base, utils
 class MorphologicalFeatures(base.Predictor):
     """Morphological features predicting model."""
 
-    def __init__(self, feedforward_network: feedforward.FeedForward, slices: Dict[str, List[int]]):
+    def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
         super().__init__()
         self.feedforward_network = feedforward_network
         self.slices = slices
@@ -28,7 +27,7 @@ class MorphologicalFeatures(base.Predictor):
         if mask is None:
             mask = x.new_ones(x.size()[:-1])
 
-        x = self.feedforward_network(x)
+        x, feature_maps = self.feedforward_network(x)
 
         prediction = []
         for _, cat_indices in self.slices.items():
@@ -36,7 +35,8 @@ class MorphologicalFeatures(base.Predictor):
 
         output = {
             "prediction": torch.stack(prediction, dim=-1),
-            "probability": x
+            "probability": x,
+            "embedding": feature_maps[-1],
         }
 
         if labels is not None:
@@ -92,7 +92,7 @@ class MorphologicalFeatures(base.Predictor):
         slices = dataset.get_slices_if_not_provided(vocab)
 
         return cls(
-            feedforward_network=feedforward.FeedForward(
+            feedforward_network=base.FeedForward(
                 input_dim=input_dim,
                 num_layers=num_layers,
                 hidden_dims=hidden_dims,
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 511edffc2f8d17edbc3fd0702e6425a4ec645e4e..b16f0adcff066c39558cb8709122780d69ee8702 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -153,6 +153,7 @@ class DependencyRelationModel(base.Predictor):
         dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1)
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
+        output["embedding"] = dep_rel_pred
 
         if self.training:
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
diff --git a/combo/predict.py b/combo/predict.py
index e528a186a287c5b9d0f971d6f091240a40231959..01a083727e7768953ba952c53038cc5156adf612 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -82,8 +82,8 @@ class COMBO(predictor.Predictor):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding = self._predictions_as_tree(prediction, instance)
-            sentence = conllu2sentence(tree, sentence_embedding)
+            tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance)
+            sentence = conllu2sentence(tree, sentence_embedding, embeddings)
             sentences.append(sentence)
         return sentences
 
@@ -96,8 +96,8 @@ class COMBO(predictor.Predictor):
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence:
         predictions = super().predict_instance(instance)
-        tree, sentence_embedding = self._predictions_as_tree(predictions, instance)
-        return conllu2sentence(tree, sentence_embedding)
+        tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance, )
+        return conllu2sentence(tree, sentence_embedding, embeddings)
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> data.Sentence:
@@ -141,6 +141,7 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        embeddings = {t["id"]: {} for t in tree}
         for field_name in field_names:
             if field_name not in predictions:
                 continue
@@ -149,6 +150,7 @@ class COMBO(predictor.Predictor):
                 if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
                     value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
                     token[field_name] = value
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
                 elif field_name == "head":
                     token[field_name] = int(field_predictions[idx])
                 elif field_name == "deps":
@@ -174,6 +176,7 @@ class COMBO(predictor.Predictor):
                         field_value = "|".join(np.array(features)[arg_indices].tolist())
 
                     token[field_name] = field_value
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
                 elif field_name == "lemma":
                     prediction = field_predictions[idx]
                     word_chars = []
@@ -194,19 +197,28 @@ class COMBO(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
-            h = np.array(predictions["enhanced_head"])
+            sentence_length = len(tree_tokens)
+            h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
             h = np.concatenate((h[-1:], h[:-1]))
-            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
             r = np.concatenate((r[-1:], r[:-1]))
-            graph.sdp_to_dag_deps(arc_scores=h,
-                                  rel_scores=r,
-                                  tree_tokens=tree_tokens,
-                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
-                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+
+            graph.graph_and_tree_merge(
+                tree_arc_scores=predictions["head"][:sentence_length],
+                tree_rel_scores=predictions["deprel"][:sentence_length],
+                graph_arc_scores=h,
+                graph_rel_scores=r,
+                idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
+                label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
+                graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"),
+                graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"),
+                tokens=tree_tokens
+            )
+
             empty_tokens = graph.restore_collapse_edges(tree_tokens)
             tree.tokens.extend(empty_tokens)
 
-        return tree, predictions["sentence_embedding"]
+        return tree, predictions["sentence_embedding"], embeddings
 
     @classmethod
     def with_spacy_tokenizer(cls, model: models.Model,
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 3352625e6665ca1cd3196506ed5e50183fedfbb0..f61a68e5b835da0c2ce3dac438425c602b084240 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -1,77 +1,91 @@
 """Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
-from typing import List
 
 import numpy as np
 
+_ACL_REL_CL = "acl:relcl"
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
-    # adding ROOT
-    tree_heads = [0] + [t["head"] for t in tree_tokens]
-    graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
-                                                      root_idx)
-    for i, (t, g) in enumerate(zip(tree_heads, graph)):
-        if not i:
-            continue
-        rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g]
-        heads = [x[0] for x in g]
-        head = tree_tokens[i - 1]["head"]
-        index = heads.index(head)
-        deprel = tree_tokens[i - 1]["deprel"]
-        deprel = deprel.split('>')[-1]
-        # TODO - Consider if there should be a condition,
-        # It doesn't seem to make any sense as DEPS should contain DEPREL
-        # (although sometimes with different/more detailed label)
-        # if len(heads) >= 2:
-        #     heads.pop(index)
-        #     rels.pop(index)
-        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
-        tree_tokens[i - 1]["deps"] = deps
-        tree_tokens[i - 1]["deprel"] = deprel
-    return
-
-
-def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
-    if len(arc_scores) != tree_heads:
-        arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
-        rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
-    # Self-loops aren't allowed, mask with 0. This is an in-place operation.
-    np.fill_diagonal(arc_scores, 0)
-    parse_preds = np.array(arc_scores) > 0
-    parse_preds[:, 0] = False  # set heads to False
-    rel_scores[:, :, root_idx] = -float('inf')
-    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
 
+def graph_and_tree_merge(tree_arc_scores,
+                         tree_rel_scores,
+                         graph_arc_scores,
+                         graph_rel_scores,
+                         label2idx,
+                         idx2label,
+                         graph_label2idx,
+                         graph_idx2label,
+                         tokens):
+    graph_arc_scores = np.copy(graph_arc_scores)
+    # Exclude self-loops, in-place operation.
+    np.fill_diagonal(graph_arc_scores, 0)
+    # Connection to root will be handled by tree.
+    graph_arc_scores[:, 0] = False
+    # The same with labels.
+    root_idx = graph_label2idx["root"]
+    graph_rel_scores[:, :, root_idx] = -float('inf')
+    graph_rel_pred = graph_rel_scores.argmax(-1)
 
-def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
-    if not isinstance(tree_heads, np.ndarray):
-        tree_heads = np.array(tree_heads)
-    dh = np.argwhere(parse_preds)
-    sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
+    # Add tree edges to graph
+    tree_heads = [0] + tree_arc_scores
     graph = [[] for _ in range(len(tree_heads))]
-    rel_pred = np.argmax(rel_scores, axis=-1)
+    labeled_graph = [[] for _ in range(len(tree_heads))]
     for d, h in enumerate(tree_heads):
-        if d:
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        # graph_label = graph_idx2label[graph_rel_pred[d - 1][h - 1]]
+        # if ">" in graph_label and label in graph_label:
+        #     print("Using graph label instead of tree.")
+        #     label = graph_label
+        if label != _ACL_REL_CL:
             graph[h].append(d)
-    for s, (d, h) in sdh:
+            labeled_graph[h].append((d, label))
+
+    # Debug only
+    # Extract graph edges
+    graph_edges = np.argwhere(graph_arc_scores)
+
+    # Add graph edges which aren't creating a cycle
+    for (d, h) in graph_edges:
         if not d or not h or d in graph[h]:
             continue
         try:
             path = next(_dfs(graph, d, h))
         except StopIteration:
-            # no path from d to h
+            # There is not path from d to h
+            label = graph_idx2label[graph_rel_pred[d][h]]
+            if label != _ACL_REL_CL:
+                graph[h].append(d)
+                labeled_graph[h].append((d, label))
+
+    # Add 'acl:relcl' without checking for cycles.
+    for d, h in enumerate(tree_heads):
+        if not d:
+            continue
+        label = idx2label[tree_rel_scores[d - 1]]
+        if label == _ACL_REL_CL:
             graph[h].append(d)
+            labeled_graph[h].append((d, label))
+
+    assert len(labeled_graph[0]) == 1
+    d = graph[0][0]
+    graph[d].append(0)
+    labeled_graph[d].append((0, "root"))
+
     parse_graph = [[] for _ in range(len(tree_heads))]
-    num_root = 0
     for h in range(len(tree_heads)):
-        for d in graph[h]:
-            rel = rel_pred[d][h]
-            if h == 0:
-                rel = root_idx
-                assert num_root == 0
-                num_root += 1
-            parse_graph[d].append((h, rel))
+        for d, label in labeled_graph[h]:
+            parse_graph[d].append((h, label))
         parse_graph[d] = sorted(parse_graph[d])
-    return parse_graph
+
+    for i, g in enumerate(parse_graph):
+        heads = np.array([x[0] for x in g])
+        rels = np.array([x[1] for x in g])
+        indices = rels.argsort()
+        heads = heads[indices].tolist()
+        rels = rels[indices].tolist()
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tokens[i - 1]["deps"] = deps
+    return
 
 
 def _dfs(graph, start, end):
@@ -104,13 +118,32 @@ def restore_collapse_edges(tree_tokens):
                 head, relation = d.split(':', 1)
                 ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}"
                 empty_node_relation, current_node_relation = relation.split(">", 1)
-                deps[i] = f"{ehead}:{current_node_relation}"
-                empty_tokens.append(
-                    {
-                        "id": ehead,
-                        "deps": f"{head}:{empty_node_relation}"
-                    }
-                )
+                # Edge case, double >
+                if ">" in current_node_relation:
+                    second_empty_node_relation, current_node_relation = current_node_relation.split(">")
+                    deps[i] = f"{ehead}:{current_node_relation}"
+                    second_ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 2}"
+                    empty_tokens.append(
+                        {
+                            "id": ehead,
+                            "deps": f"{second_ehead}:{empty_node_relation}"
+                        }
+                    )
+                    empty_tokens.append(
+                        {
+                            "id": second_ehead,
+                            "deps": f"{head}:{second_empty_node_relation}"
+                        }
+                    )
+
+                else:
+                    deps[i] = f"{ehead}:{current_node_relation}"
+                    empty_tokens.append(
+                        {
+                            "id": ehead,
+                            "deps": f"{head}:{empty_node_relation}"
+                        }
+                    )
         deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
         token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
     return empty_tokens
diff --git a/docs/installation.md b/docs/installation.md
index 6aba7f700f7c4647fb77c935a3e763d2d7aaa6ab..422bed22423873a334c3894e48b02ceed5e4e7b9 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -2,7 +2,7 @@
 Clone this repository and install COMBO (we suggest using virtualenv/conda with Python 3.6+):
 ```bash
 pip install -U pip setuptools wheel
-pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3
+pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4
 combo --helpfull
 ```
 
@@ -11,7 +11,7 @@ combo --helpfull
 python -m venv venv
 source venv/bin/activate
 pip install -U pip setuptools wheel
-pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3
+pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4
 ```
 
 ### Conda example:
diff --git a/docs/training.md b/docs/training.md
index 7d7b0a8256e0c7896bf1372f817ab4e263b4f45e..2cdf75d0803e6ea12e5ce6343d8471579e90bf8c 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o
 ### Data pre-processing
 The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets, before training the COMBO EUD model.
 
+The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/). 
+
 ```bash
 perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
 ``` 
diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfd24eb908178dbb32c1027054b514b4055e327e
--- /dev/null
+++ b/scripts/evaluate_iwpt21.py
@@ -0,0 +1,87 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+CODE2LANG = {
+    "ar": "Arabic",
+    "bg": "Bulgarian",
+    "cs": "Czech",
+    "nl": "Dutch",
+    "en": "English",
+    "et": "Estonian",
+    "fi": "Finnish",
+    "fr": "French",
+    "it": "Italian",
+    "lv": "Latvian",
+    "lt": "Lithuanian",
+    "pl": "Polish",
+    "ru": "Russian",
+    "sk": "Slovak",
+    "sv": "Swedish",
+    "ta": "Tamil",
+    "uk": "Ukrainian",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to IWPT'21 data directory.")
+flags.DEFINE_string(name="models_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+flags.DEFINE_string(name="evaluate_script_path", default="iwpt21_xud_eval.py",
+                    help="Path to 'iwpt21_xud_eval.py' eval script.")
+flags.DEFINE_boolean(name="expect_prefix", default=True,
+                     help="Whether to expect allennlp prefix.")
+
+
+def run(_):
+    models_dir = pathlib.Path(FLAGS.models_dir)
+    for model_dir in models_dir.iterdir():
+        if model_dir.name not in CODE2LANG:
+            print("Skipping unknown directory: ", model_dir.name)
+            continue
+
+        treebank_name = f"UD_{CODE2LANG[model_dir.name]}-IWPT"
+
+        if FLAGS.expect_prefix:
+            model_dir = list(model_dir.iterdir())
+            assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
+            model_dir = model_dir[0]
+
+        treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name
+        files = list(treebank_dir.iterdir())
+
+        test_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
+        assert len(test_file) == 1, f"Couldn't find test file."
+        test_file = test_file[0]
+
+        if not (model_dir / "results.txt").exists():
+            output_pred = model_dir / 'predictions.conllu'
+            command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
+            --input_file {test_file}
+            --output_file {output_pred}
+            --cuda_device {FLAGS.cuda_device}
+            --silent
+            """
+            utils.execute_command(command)
+
+            output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu')
+            utils.collapse_nodes(pathlib.Path(FLAGS.data_dir) / 'tools', output_pred, output_collapsed)
+
+            command = f"""python {FLAGS.evaluate_script_path} -v 
+            {test_file}
+            {output_collapsed} 
+            """
+            utils.execute_command(command, output_file=model_dir / "results.txt")
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/postprocessing.py b/scripts/postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f4da16ef6124172493488322c0fc33cc348841d
--- /dev/null
+++ b/scripts/postprocessing.py
@@ -0,0 +1,454 @@
+# TODO lemma remove punctuation - ukrainian
+# TODO lemma remove punctuation - russian
+# TODO consider handling multiple 'case'
+import sys
+
+import conllu
+
+from re import *
+
+rus = compile(u'^из-за$')
+expand = compile('^\d+\.\d+$')
+
+'''
+A script correcting automatically predicted enhanced dependency graphs.
+Running the script: python postprocessing.py cs
+
+You have to modified the paths to the input CoNLL-U file and the output file.
+
+The last argument (e.g. cs) corresponds to the language symbol.
+All language symbols:
+ar (Arabic), bg (Bulgarian), cs (Czech), nl (Dutch), en (English), et (Estonian), fi (Finnish)
+fr (French), it (Italian), lv (Latvian), lt (Lithuanian), pl (Polish), ru (Russian)
+sk (Slovak), sv (Swedish), ta (Tamil), uk (Ukrainian)
+
+There are two main rules:
+1) the first one add case information to the following labels: nmod, obl, acl, advcl. 
+The case information comes from case/mark dependent of the current token and from the morphological feature Case.
+Depending on the language, not all information is added.
+In some languages ('en', 'it', 'nl', 'sv') the lemma of coordinating conjunction (cc) is appendend to the conjunct label (conj). 
+Functions: fix_mod_deps, fix_obj_deps, fix_acl_deps, fix_advcl_deps and fix_conj_deps
+
+2) the second rule correct enhanced edges comming into function words labelled ref, mark, punct, root, case, det, cc, cop, aux
+They should not be assinged other functions. For example, if a token, e.g. "and" is labelled cc (coordinating conjunction), 
+it cannot be simultaneously a subject (nsubj) and if this wrong enhanced edge exists, it should be removed from the graph.
+
+There is one additional rule for Estonian: 
+if the label is nsubj:cop or csubj:cop, the cop sublabel is removed and we have nsubj and csubj, respectively. 
+'''
+
+
+def fix_nmod_deps(dep, token, sentence, relation):
+    """
+    This function modifies enhanced edges labelled 'nmod'
+    """
+    label: str
+    label, head = dep
+
+    # All labels starting with 'relation' are checked
+    if not label.startswith(relation):
+        return dep
+
+    # case_lemma is a (complex) preposition labelled 'case' e.g. 'po' in nmod:po:loc
+    # or a (complex) subordinating conjunction labelled 'mark'
+    case_lemma = None
+    case_tokens = []
+    for t in sentence:
+        if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]:
+            case_tokens.append(t)
+            break
+
+    if case_tokens:
+        fixed_tokens = []
+        for t in sentence:
+            for c in case_tokens:
+                if t["deprel"] == "fixed" and t["head"] == c["id"]:
+                    fixed_tokens.append(t)
+
+        if fixed_tokens:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens))
+        else:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens))
+
+    # case_val is a value of Case, e.g. 'gen' in nmod:gen and 'loc' in nmod:po:loc
+    case_val = None
+    if token['feats'] is not None:
+        if 'Case' in token["feats"]:
+            case_val = token["feats"]['Case'].lower()
+
+    #TODO: check for other languages
+    if language in ['fi'] and label not in ['nmod', 'nmod:poss']:
+        return dep
+    elif language not in ['fi'] and label not in ['nmod']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        if case_val:
+            #TODO: check for other languages
+            if language not in ['bg', 'en', 'nl', 'sv']:
+                label_lst.append(case_val)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_obl_deps(dep, token, sentence, relation):
+    """
+    This function modifies enhanced edges labelled 'obl', 'obl:arg', 'obl:rel'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(relation):
+        return dep
+
+    # case_lemma is a (complex) preposition labelled 'case' e.g. 'pod' in obl:pod:loc
+    # or a (complex) subordinating conjunction labelled 'mark'
+    case_lemma = None
+    case_tokens = []
+    for t in sentence:
+        if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]:
+            case_tokens.append(t)
+            break
+
+    if case_tokens:
+        # fixed_token is the lemma of a complex preposition, e.g. 'przypadek' in obl:w_przypadku:gen
+        fixed_tokens = []
+        for t in sentence:
+            for c in case_tokens:
+                if t["deprel"] == "fixed" and t["head"] == c["id"]:
+                    fixed_tokens.append(t)
+
+        if fixed_tokens:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens))
+        else:
+            case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens))
+
+    # case_val is a value of Case feature, e.g. 'loc' in obl:pod:loc
+    case_val = None
+    if token['feats'] is not None:
+        if 'Case' in token["feats"]:
+            case_val = token["feats"]['Case'].lower()
+
+    if label not in ['obl', 'obl:arg', 'obl:agent']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+            if case_val:
+                # TODO: check for other languages
+                if language not in ['bg', 'en', 'lv', 'nl', 'sv']:
+                    label_lst.append(case_val)
+        # TODO: check it for other languages
+        if language not in ['pl', 'sv']:
+            if case_val and not case_lemma:
+                if label == token['deprel']:
+                    label_lst.append(case_val)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_acl_deps(dep, acl_token, sentence, acl, lang):
+    """
+    This function modifies enhanced edges labelled 'acl'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(acl):
+        return dep
+
+    if label.startswith("acl:relcl"):
+        if lang not in ['uk']:
+            return dep
+
+    case_lemma = None
+    case_tokens = []
+    for token in sentence:
+        if token["deprel"] == "mark" and token["head"] == acl_token["id"]:
+            case_tokens.append(token)
+            break
+
+    if case_tokens:
+        fixed_tokens = []
+        for token in sentence:
+            if token["deprel"] == "fixed" and token["head"] == quicksort(case_tokens)[0]["id"]:
+                fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = quicksort(case_tokens)[0]["lemma"]
+
+    if lang in ['uk']:
+        if label not in ['acl', 'acl:relcl']:
+            return dep
+        else:
+            label_lst = [label]
+            if case_lemma:
+                label_lst.append(case_lemma)
+            label = ":".join(label_lst)
+    else:
+        if label not in ['acl']:
+            return dep
+        else:
+            label_lst = [label]
+            if case_lemma:
+                label_lst.append(case_lemma)
+            label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+def fix_advcl_deps(dep, advcl_token, sentence, advcl):
+    """
+    This function modifies enhanced edges labelled 'advcl'
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(advcl):
+        return dep
+
+    case_lemma = None
+    case_tokens = []
+    # TODO: check for other languages
+    if language in ['bg', 'lt']:
+        for token in sentence:
+            if token["deprel"] in ["mark", "case"] and token["head"] == advcl_token["id"]:
+                case_tokens.append(token)
+    else:
+        for token in sentence:
+            if token["deprel"] == "mark" and token["head"] == advcl_token["id"]:
+                case_tokens.append(token)
+
+    if case_tokens:
+        fixed_tokens = []
+        # TODO: check for other languages
+        if language not in ['bg', 'nl']:
+            for token in sentence:
+                for case in quicksort(case_tokens):
+                    if token["deprel"] == "fixed" and token["head"] == case["id"]:
+                        fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)])
+
+    if label not in ['advcl']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+def fix_conj_deps(dep, conj_token, sentence, conj):
+    """
+    This function modifies enhanced edges labelled 'conj' which should be assined the lemma of cc as sublabel
+    """
+    label: str
+    label, head = dep
+
+    if not label.startswith(conj):
+        return dep
+
+    case_lemma = None
+    case_tokens = []
+    for token in sentence:
+        if token["deprel"] == "cc" and token["head"] == conj_token["id"]:
+            case_tokens.append(token)
+
+    if case_tokens:
+        fixed_tokens = []
+        for token in sentence:
+            for case in quicksort(case_tokens):
+                if token["deprel"] == "fixed" and token["head"] == case["id"]:
+                    fixed_tokens.append(token)
+
+        if fixed_tokens:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)])
+        else:
+            case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)])
+
+    if label not in ['conj']:
+        return dep
+    else:
+        label_lst = [label]
+        if case_lemma:
+            label_lst.append(case_lemma)
+        label = ":".join(label_lst)
+
+    # print(label, sentence.metadata["sent_id"])
+    return label, head
+
+
+
+def quicksort(tokens):
+    if len(tokens) <= 1:
+        return tokens
+    else:
+        return quicksort([x for x in tokens[1:] if int(x["id"]) < int(tokens[0]["id"])]) \
+               + [tokens[0]] \
+               + quicksort([y for y in tokens[1:] if int(y["id"]) >= int(tokens[0]["id"])])
+
+
+language = sys.argv[1]
+errors = 0
+
+input_file = f"./token_test/{language}_pred.fixed.conllu"
+output_file = f"./token_test/{language}.nofixed.conllu"
+with open(input_file) as fh:
+    with open(output_file, "w") as oh:
+        for sentence in conllu.parse_incr(fh):
+            for token in sentence:
+                deps = token["deps"]
+                if deps:
+                    if language not in ['fr']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_obl_deps(dep, token, sentence, "obl")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    if language not in ['fr']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_nmod_deps(dep, token, sentence, "nmod")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language not in ['fr', 'lv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_acl_deps(dep, token, sentence, "acl", language)
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+
+                    # TODO: check for other languages
+                    if language not in ['fr', 'lv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_advcl_deps(dep, token, sentence, "advcl")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language in ['en', 'it', 'nl', 'sv']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            new_dep = fix_conj_deps(dep, token, sentence, "conj")
+                            token["deps"][idx] = new_dep
+                            if new_dep[0] != dep[0]:
+                                errors += 1
+                    # TODO: check for other languages
+                    if language in ['et']:
+                        for idx, dep in enumerate(deps):
+                            assert len(dep) == 2, dep
+                            if token['deprel'] == 'nsubj:cop' and dep[0] == 'nsubj:cop':
+                                new_dep = ('nsubj', dep[1])
+                                token["deps"][idx] = new_dep
+                                if new_dep[0] != dep[0]:
+                                    errors += 1
+                            if token['deprel'] == 'csubj:cop' and dep[0] == 'csubj:cop':
+                                new_dep = ('csubj', dep[1])
+                                token["deps"][idx] = new_dep
+                                if new_dep[0] != dep[0]:
+                                    errors += 1
+                    # BELOW ARE THE RULES FOR CORRECTION OF THE FUNCTION WORDS
+                    # labelled ref, mark, punct, root, case, det, cc, cop, aux
+                    # They should not be assinged other functions
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ru']:
+                        refs = [s for s in deps if s[0] == 'ref']
+                        if refs:
+                            token["deps"] = refs
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'en', 'et', 'fi', 'it', 'lt', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
+                        marks = [s for s in deps if s[0] == 'mark']
+                        if marks and token['deprel'] == 'mark':
+                            token["deps"] = marks
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
+                        puncts = [s for s in deps if s[0] == 'punct' and s[1] == token['head']]
+                        if puncts and token['deprel'] == 'punct':
+                            token["deps"] = puncts
+                    #TODO: to check for other languages
+                    if language in ['ar', 'lt', 'pl']:
+                        roots = [s for s in deps if s[0] == 'root']
+                        if roots and token['deprel'] == 'root':
+                            token["deps"] = roots
+                    #TODO: to check for other languages
+                    if language in ['en', 'ar', 'bg', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']:
+                        cases = [s for s in deps if s[0] == 'case']
+                        if cases and token['deprel'] == 'case':
+                            token["deps"] = cases
+                    #TODO: to check for other languages
+                    if language in ['en', 'ar', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']:
+                        dets = [s for s in deps if s[0] == 'det']
+                        if dets and token['deprel'] == 'det':
+                            token["deps"] = dets
+                    #TODO: to check for other languages
+                    if language in ['et', 'fi', 'it', 'lv', 'nl', 'pl', 'sk', 'sv', 'uk', 'fr', 'ar', 'ru', 'ta']:
+                        ccs = [s for s in deps if s[0] == 'cc']
+                        if ccs and token['deprel'] == 'cc':
+                            token["deps"] = ccs
+                    #TODO: to check for other languages
+                    if language in ['bg', 'fi','et', 'it', 'sk', 'sv', 'uk', 'nl', 'fr', 'ru']:
+                        cops = [s for s in deps if s[0] == 'cop']
+                        if cops and token['deprel'] == 'cop':
+                            token["deps"] = cops
+                    #TODO: to check for other languages
+                    if language in ['bg', 'et', 'fi', 'it', 'lv', 'pl', 'sv']:
+                        auxs = [s for s in deps if s[0] == 'aux']
+                        if auxs and token['deprel'] == 'aux':
+                            token["deps"] = auxs
+
+                    #TODO: to check for other languages
+                    if language in ['ar', 'bg', 'cs', 'et', 'fi', 'fr', 'lt', 'lv', 'pl', 'sk', 'sv', 'uk', 'ru', 'ta']:
+                        conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
+                        other = [s for s in deps if s[0] != 'conj']
+                        if conjs and token['deprel'] == 'conj':
+                            token["deps"] = conjs+other
+
+                    #TODO: to check for other languages
+                    # EXTRA rule 1
+                    if language in ['cs', 'et', 'fi', 'lv', 'pl', 'uk']: #ar nl ru
+                        # not use for: lt, bg, fr, sk, ta, sv, en
+                        deprel = [s for s in deps if s[0] == token['deprel'] and s[1] == token['head']]
+                        other_exp = [s for s in deps if type(s[1]) == tuple]
+                        other_noexp = [s for s in deps if s[1] != token['head'] and type(s[1]) != tuple]
+                        if other_exp:
+                            token["deps"] = other_exp+other_noexp
+
+                    # EXTRA rule 2
+                    if language in ['cs', 'lt', 'pl', 'sk', 'uk']: #ar nl ru
+                        conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']]
+                        if conjs and len(deps) == 1 and len(conjs) == 1:
+                            for t in sentence:
+                                if t['id'] == conjs[0][1] and t['deprel'] == 'root':
+                                    conjs.append((t['deprel'], t['head']))
+                            token["deps"] = conjs
+
+                    if language in ['ta']:
+                        if token['deprel'] != 'conj':
+                            conjs = [s for s in deps if s[0] == 'conj']
+                            if conjs:
+                                new_dep = [s for s in deps if s[1] == token['head']]
+                                token["deps"] = new_dep
+
+            oh.write(sentence.serialize())
+print(errors)
diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b9cf72a183c728a612507f21e459589b0005b6
--- /dev/null
+++ b/scripts/predict_iwpt21.py
@@ -0,0 +1,90 @@
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+CODE2LANG = {
+    "ar": "Arabic",
+    "bg": "Bulgarian",
+    "cs": "Czech",
+    "nl": "Dutch",
+    "en": "English",
+    "et": "Estonian",
+    "fi": "Finnish",
+    "fr": "French",
+    "it": "Italian",
+    "lv": "Latvian",
+    "lt": "Lithuanian",
+    "pl": "Polish",
+    "ru": "Russian",
+    "sk": "Slovak",
+    "sv": "Swedish",
+    "ta": "Tamil",
+    "uk": "Ukrainian",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to data directory.")
+flags.DEFINE_string(name="models_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_string(name="tools", default="",
+                    help="UD tools path.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+flags.DEFINE_boolean(name="expect_prefix", default=True,
+                     help="Whether to expect allennlp prefix.")
+flags.DEFINE_integer(name="batch_size", default=32,
+                     help="Batch size.")
+
+
+def run(_):
+    models_dir = pathlib.Path(FLAGS.models_dir)
+    for model_dir in models_dir.iterdir():
+        lang = model_dir.name
+        if lang not in CODE2LANG:
+            print("Skipping unknown directory: ", lang)
+            continue
+
+        if FLAGS.expect_prefix:
+            model_dir = list(model_dir.iterdir())
+            assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
+            model_dir = model_dir[0]
+
+        data_dir = pathlib.Path(FLAGS.data_dir)
+        files = list(data_dir.iterdir())
+        test_file = [f for f in files if f"{lang}.mwt.conllu" == f.name]
+        # Try to use mwt file if it exists
+        if test_file:
+            assert len(test_file) == 1, f"Should be exactly one {lang}.mwt.conllu file."
+            test_file = test_file[0]
+        else:
+            test_file = [f for f in files if f"{lang}.conllu" == f.name]
+            assert len(test_file) == 1, f"Couldn't find test file."
+            test_file = test_file[0]
+
+        output_pred = data_dir / f'{lang}_pred.conllu'
+        command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
+        --input_file {test_file}
+        --output_file {output_pred}
+        --cuda_device {FLAGS.cuda_device}
+        --batch_size {FLAGS.batch_size}
+        --silent
+        """
+        utils.execute_command(command)
+
+        output_fixed = utils.path_to_str(output_pred).replace('.conllu', '.fixed.conllu')
+        utils.quick_fix(pathlib.Path(FLAGS.tools), output_pred, output_fixed)
+
+        output_collapsed = output_fixed.replace('.fixed.conllu', '.collapsed.conllu')
+        utils.collapse_nodes(pathlib.Path(FLAGS.tools), pathlib.Path(output_fixed), output_collapsed)
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train.py b/scripts/train.py
index 950ee82184c4dbb4b3a4d9ce31b551ab81439375..b75bbedb258928c42dae898dce883bece4de286a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -41,7 +41,6 @@ TREEBANKS = [
     "UD_Czech-PUD",
     "UD_Danish-DDT",
     "UD_Dutch-Alpino",
-    #END OF FIRST RUN
     "UD_English-EWT",
     # "UD_Erzya-JR", No training data
     "UD_Estonian-EWT",
@@ -104,7 +103,6 @@ TREEBANKS = [
     "UD_Latvian-LVTB",
     "UD_Lithuanian-ALKSNIS",
     "UD_Lithuanian-HSE",
-    # end batch 2
     "UD_Maltese-MUDT",
     # "UD_Manx-Cadhan", No training data
     "UD_Marathi-UFAL",
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e427fffcc2eaa34ba11989d365894c42c1f191
--- /dev/null
+++ b/scripts/train_iwpt21.py
@@ -0,0 +1,144 @@
+"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
+
+For possible requirements, see train_eud.py comments.
+"""
+
+import os
+import pathlib
+from typing import List
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+LANG2TREEBANK = {
+    "ar": ["Arabic-PADT"],
+    "bg": ["Bulgarian-BTB"],
+    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
+    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
+    "en": ["English-EWT", "English-PUD", "English-GUM"],
+    "et": ["Estonian-EDT", "Estonian-EWT"],
+    "fi": ["Finnish-TDT", "Finnish-PUD"],
+    "fr": ["French-Sequoia", "French-FQB"],
+    "it": ["Italian-ISDT"],
+    "lv": ["Latvian-LVTB"],
+    "lt": ["Lithuanian-ALKSNIS"],
+    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
+    "ru": ["Russian-SynTagRus"],
+    "sk": ["Slovak-SNK"],
+    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
+    "ta": ["Tamil-TTB"],
+    "uk": ["Ukrainian-IU"],
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
+                  help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to IWPT'21 data directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def merge_files(files: List[str], output: pathlib.Path):
+    if not output.exists():
+        os.system(f"cat {' '.join(files)} > {output}")
+
+
+def run(_):
+    languages = FLAGS.lang
+    for lang in languages:
+        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
+        data_dir = pathlib.Path(FLAGS.data_dir)
+        assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
+
+        treebanks = LANG2TREEBANK[lang]
+        full_language = treebanks[0].split("-")[0]
+        train_paths = []
+        dev_paths = []
+        train_raw_paths = []
+        dev_raw_paths = []
+        # TODO Uncomment when IWPT'21 Shared Task ends.
+        # During shared task duration test data is not available.
+        test_paths = []
+        for treebank in treebanks:
+            treebank_dir = data_dir / f"UD_{treebank}"
+            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists."
+            for treebank_file in treebank_dir.iterdir():
+                name = treebank_file.name
+                if "conllu" in name and "fixed" not in name:
+                    output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    if "train" in name:
+                        utils.collapse_nodes(data_dir / 'tools', treebank_file, output)
+                        train_paths.append(output)
+                    elif "dev" in name:
+                        utils.collapse_nodes(data_dir / 'tools', treebank_file, output)
+                        dev_paths.append(output)
+                    # elif "test" in name:
+                    #     collapse_nodes(data_dir, treebank_file, output)
+                    #     test_paths.append(output)
+                if ".txt" in name:
+                    if "train" in name:
+                        train_raw_paths.append(utils.path_to_str(treebank_file))
+                    elif "dev" in name:
+                        dev_raw_paths.append(utils.path_to_str(treebank_file))
+
+        merged_dataset_name = "IWPT"
+        lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
+        lang_data_dir.mkdir(exist_ok=True)
+
+        suffix = f"{lang}_{merged_dataset_name}-ud".lower()
+        train_path = lang_data_dir / f"{suffix}-train.conllu"
+        dev_path = lang_data_dir / f"{suffix}-dev.conllu"
+        test_path = lang_data_dir / f"{suffix}-test.conllu"
+        train_raw_path = lang_data_dir / f"{suffix}-train.txt"
+        dev_raw_path = lang_data_dir / f"{suffix}-dev.txt"
+        test_raw_path = lang_data_dir / f"{suffix}-test.txt"
+
+        merge_files(train_paths, output=train_path)
+        merge_files(dev_paths, output=dev_path)
+        # TODO Change to test_paths instead of dev_paths after IWPT'21
+        merge_files(dev_paths, output=test_path)
+
+        merge_files(train_raw_paths, output=train_raw_path)
+        merge_files(dev_raw_paths, output=dev_raw_path)
+        # TODO Change to test_raw_paths instead of dev_paths after IWPT'21
+        merge_files(dev_raw_paths, output=test_raw_path)
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+
+        command = f"""combo --mode train
+        --training_data {train_path}
+        --validation_data {dev_path}
+        --targets feats,upostag,xpostag,head,deprel,lemma,deps
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
+        --serialization_dir {serialization_dir}
+        --cuda_device {FLAGS.cuda_device}
+        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
+        --notensorboard
+        """
+
+        # Datasets without XPOS
+        if lang in {"fr", "ru"}:
+            command = command + " --targets deprel,head,upostag,lemma,feats,deps"
+
+        # Smaller dataset
+        if lang in {"ta"}:
+            command = command + " --word_batch_size 500"
+        else:
+            command = command + " --word_batch_size 2500"
+
+        utils.execute_command("".join(command.splitlines()))
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/utils.py b/scripts/utils.py
index 6ce5a8a5bbd2bff04219e6aa40b2c0c915d4a1c7..09f7591441327c25be5bf2c34931b4e59dbfeedb 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -1,16 +1,30 @@
 """Utils for scripts."""
+import pathlib
 import subprocess
 
 LANG2TRANSFORMER = {
     "en": "bert-base-cased",
-    "pl": "allegro/herbert-base-cased",
+    "pl": "allegro/herbert-large-cased",
     "zh": "bert-base-chinese",
     "fi": "TurkuNLP/bert-base-finnish-cased-v1",
     "ko": "kykim/bert-kor-base",
     "de": "dbmdz/bert-base-german-cased",
     "ar": "aubmindlab/bert-base-arabertv2",
     "eu": "ixa-ehu/berteus-base-cased",
-    "tr": "dbmdz/bert-base-turkish-cased"
+    "tr": "dbmdz/bert-base-turkish-cased",
+    "bg": "xlm-roberta-large",
+    "nl": "xlm-roberta-large",
+    "fr": "camembert-base",
+    "it": "xlm-roberta-large",
+    "ru": "xlm-roberta-large",
+    "sv": "xlm-roberta-large",
+    "uk": "xlm-roberta-large",
+    "ta": "xlm-roberta-large",
+    "sk": "xlm-roberta-large",
+    "lt": "xlm-roberta-large",
+    "lv": "xlm-roberta-large",
+    "cs": "xlm-roberta-large",
+    "et": "xlm-roberta-large",
 }
 
 
@@ -21,3 +35,21 @@ def execute_command(command, output_file=None):
             subprocess.run(command, check=True, stdout=f)
     else:
         subprocess.run(command, check=True)
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'enhanced_collapse_empty_nodes.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
+
+
+def quick_fix(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        execute_command(f"perl {path_to_str(data_dir / 'conllu-quick-fix.pl')} "
+                        f"{path_to_str(treebank_file)}", output)
diff --git a/setup.py b/setup.py
index e1354b76cf47145c09f9da140068b9a22f68a4d2..876909dc196e631e2df7c83aaeae6ae53e2cfe17 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@ REQUIREMENTS = [
     'scipy<1.6.0;python_version<"3.7"',  # SciPy 1.6.0 works for 3.7+
     'spacy==2.3.2',
     'scikit-learn<=0.23.2',
-    'torch==1.6.0',
+    'torch==1.7.0',
     'tqdm==4.43.0',
     'transformers==4.0.1',
     'urllib3==1.25.11',
@@ -23,7 +23,7 @@ REQUIREMENTS = [
 
 setup(
     name='combo',
-    version='1.0.3',
+    version='1.0.4',
     author='Mateusz Klimaszewski',
     author_email='M.Klimaszewski@ii.pw.edu.pl',
     install_requires=REQUIREMENTS,
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
deleted file mode 100644
index 74e37446684f68c6d4ea4abe77c69ba9d3ae4c2b..0000000000000000000000000000000000000000
--- a/tests/utils/test_graph.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import unittest
-import combo.utils.graph as graph
-
-import conllu
-import numpy as np
-
-
-class GraphTest(unittest.TestCase):
-
-    def test_adding_empty_graph_with_the_same_labels(self):
-        tree = conllu.TokenList(
-            tokens=[
-                {"head": 0, "deprel": "root", "form": "word1"},
-                {"head": 3, "deprel": "yes", "form": "word2"},
-                {"head": 1, "deprel": "yes", "form": "word3"},
-            ]
-        )
-        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
-        empty_graph = np.zeros((4, 4))
-        graph_labels = np.zeros((4, 4, 4))
-        expected_deps = ["0:root", "3:yes", "1:yes"]
-
-        # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
-        actual_deps = [t["deps"] for t in tree.tokens]
-
-        # then
-        self.assertEqual(expected_deps, actual_deps)
-
-    def test_adding_empty_graph_with_different_labels(self):
-        tree = conllu.TokenList(
-            tokens=[
-                {"head": 0, "deprel": "root", "form": "word1"},
-                {"head": 3, "deprel": "tree_label", "form": "word2"},
-                {"head": 1, "deprel": "tree_label", "form": "word3"},
-            ]
-        )
-        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
-        empty_graph = np.zeros((4, 4))
-        graph_labels = np.zeros((4, 4, 3))
-        graph_labels[2][3][2] = 10e10
-        graph_labels[3][1][2] = 10e10
-        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]
-
-        # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
-        actual_deps = [t["deps"] for t in tree.tokens]
-
-        # then
-        self.assertEqual(actual_deps, expected_deps)
-
-    def test_extending_tree_with_graph(self):
-        # given
-        tree = conllu.TokenList(
-            tokens=[
-                {"head": 0, "deprel": "root", "form": "word1"},
-                {"head": 1, "deprel": "tree_label", "form": "word2"},
-                {"head": 2, "deprel": "tree_label", "form": "word3"},
-            ]
-        )
-        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
-        arc_scores = np.array([
-            [0, 0, 0, 0],
-            [1, 0, 0, 0],
-            [0, 1, 0, 0],
-            [0, 1, 1, 0],
-        ])
-        graph_labels = np.zeros((4, 4, 3))
-        graph_labels[3][1][2] = 10e10
-        expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"]
-
-        # when
-        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
-        actual_deps = [t["deps"] for t in tree.tokens]
-
-        # then
-        self.assertEqual(actual_deps, expected_deps)
-
-    def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self):
-        # given
-        tree = conllu.TokenList(
-            tokens=[
-                {"head": 0, "deprel": "root", "form": "word1"},
-                {"head": 1, "deprel": "tree_label", "form": "word2"},
-                {"head": 2, "deprel": "tree_label", "form": "word3"},
-            ]
-        )
-        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
-        arc_scores = np.array([
-            [0, 0, 0, 0],
-            [1, 0, 0, 0],
-            [0, 1, 0, 0],
-            [0, 0, 1, 1],
-        ])
-        graph_labels = np.zeros((4, 4, 3))
-        graph_labels[3][3][2] = 10e10
-        expected_deps = ["0:root", "1:tree_label", "2:tree_label"]
-        # TODO current actual, adds self-loop
-        # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"]
-
-        # when
-        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens,  root_idx=0, vocab_index=vocab_index)
-        actual_deps = [t["deps"] for t in tree.tokens]
-
-        # then
-        self.assertEqual(expected_deps, actual_deps)