diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 2877a65c8f64af347c96807dab9687a39c73b81b..fb770f6aee2ba4ad93763ae49bb030852b65aaf1 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 from typing import Union, List, Dict, Iterable, Optional, Any, Tuple
 
@@ -115,29 +116,34 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                         fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field,
                                                                                label_namespace=target_name + "_labels")
                     elif target_name == "deps":
-                        heads = [0 if t["head"] == "_" else int(t["head"]) for t in tree_tokens]
-                        deprels = [t["deprel"] for t in tree_tokens]
+                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
+                        text_field_deps = copy.deepcopy(text_field)
+                        text_field_deps.tokens.insert(0, _Token("ROOT"))
                         enhanced_heads: List[Tuple[int, int]] = []
                         enhanced_deprels: List[str] = []
                         for idx, t in enumerate(tree_tokens):
-                            enhanced_heads.append((idx, heads[idx]))
-                            enhanced_deprels.append(deprels[idx])
                             t_deps = t["deps"]
                             if t_deps and t_deps != "_":
-                                t_deprels, t_heads = zip(*t_deps)
-                                enhanced_heads.extend([(idx, t) for t in t_heads])
-                                enhanced_deprels.extend(t_deprels)
+                                for rel, head in t_deps:
+                                    # EmoryNLP skips the first edge, if there are two edges between the same
+                                    # nodes. Thanks to that one is in a tree and another in a graph.
+                                    # This snippet follows that approach.
+                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
+                                        enhanced_heads.pop()
+                                        enhanced_deprels.pop()
+                                    enhanced_heads.append((idx, head))
+                                    enhanced_deprels.append(rel)
                         fields_["enhanced_heads"] = allen_fields.AdjacencyField(
                             indices=enhanced_heads,
-                            sequence_field=text_field,
+                            sequence_field=text_field_deps,
                             label_namespace="enhanced_heads_labels",
                             padding_value=0,
                         )
                         fields_["enhanced_deprels"] = allen_fields.AdjacencyField(
                             indices=enhanced_heads,
-                            sequence_field=text_field,
+                            sequence_field=text_field_deps,
                             labels=enhanced_deprels,
-                            # Label namespace should match regular tree parsing.
+                            # Label namespace matches regular tree parsing.
                             label_namespace="deprel_labels",
                             padding_value=0,
                         )
@@ -160,7 +166,9 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
                 token["feats"] = field
 
         # metadata
-        fields_["metadata"] = allen_fields.MetadataField({"input": tree, "field_names": self.fields})
+        fields_["metadata"] = allen_fields.MetadataField({"input": tree,
+                                                          "field_names": self.fields,
+                                                          "tokens": tokens})
 
         return allen_data.Instance(fields_)
 
diff --git a/combo/main.py b/combo/main.py
index 3f7fad44a829b8d514c5c64acd486b99b644a739..374af69bdbc4584d3425c2814ce4873e74233004 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -33,8 +33,10 @@ flags.DEFINE_string(name="output_file", default="output.log",
 # Training flags
 flags.DEFINE_list(name="training_data_path", default="./tests/fixtures/example.conllu",
                   help="Training data path(s)")
+flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_list(name="validation_data_path", default="",
                   help="Validation data path(s)")
+flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
 flags.DEFINE_string(name="pretrained_tokens", default="",
                     help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="embedding_dim", default=300,
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 5aa7b283c70af76a27d97f63783368bdbe4ffa3f..ba8d617510440fd6f0195e4f8ba6606ea8d68f5a 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1,5 +1,6 @@
 """Models module."""
 from .base import FeedForwardPredictor
+from .graph_parser import GraphDependencyRelationModel
 from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31e6d052dc8a7969bb07bf92a5eb779b7aa24d4
--- /dev/null
+++ b/combo/models/graph_parser.py
@@ -0,0 +1,191 @@
+"""Enhanced dependency parsing models."""
+from typing import Tuple, Dict, Optional, Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from allennlp import data
+from allennlp.nn import chu_liu_edmonds
+
+from combo.models import base, utils
+
+
+class GraphHeadPredictionModel(base.Predictor):
+    """Head prediction model."""
+
+    def __init__(self,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 cycle_loss_n: int = 0,
+                 graph_weighting: float = 0.2):
+        super().__init__()
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.cycle_loss_n = cycle_loss_n
+        self.graph_weighting = graph_weighting
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                mask: Optional[torch.BoolTensor] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        if mask is None:
+            mask = x.new_ones(x.size()[-1])
+        heads_labels = None
+        if labels is not None and labels[0] is not None:
+            heads_labels = labels
+
+        head_arc_emb = self.head_projection_layer(x)
+        dep_arc_emb = self.dependency_projection_layer(x)
+        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
+        pred = x.sigmoid() > 0.5
+
+        output = {
+            "prediction": pred,
+            "probability": x
+        }
+
+        if heads_labels is not None:
+            if sample_weights is None:
+                sample_weights = heads_labels.new_ones([mask.size(0)])
+            output["loss"], output["cycle_loss"] = self._loss(x, heads_labels, mask, sample_weights)
+
+        return output
+
+    def _cycle_loss(self, pred: torch.Tensor):
+        BATCH_SIZE, _, _ = pred.size()
+        loss = pred.new_zeros(BATCH_SIZE)
+        # Index from 1: as using non __ROOT__ tokens
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiple on last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
+
+        return loss
+
+    @staticmethod
+    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
+        assert len(x.size()) == 3
+        BATCH_SIZE, N, M = x.size()
+        assert N == M
+        identity = x.new_tensor(torch.eye(N))
+        identity = identity.reshape((1, N, N))
+        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
+        return (x * batch_identity).sum((-1, -2))
+
+    def _loss(self, pred: torch.Tensor, labels: torch.Tensor,  mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        BATCH_SIZE, N, M = pred.size()
+        assert N == M
+        SENTENCE_LENGTH = N
+
+        valid_positions = mask.sum()
+
+        result = []
+        true = labels
+        # Ignore first pred dimension as it is ROOT token prediction
+        for i in range(SENTENCE_LENGTH - 1):
+            pred_i = pred[:, i + 1, 1:].reshape(-1)
+            true_i = true[:, i + 1, 1:].reshape(-1)
+            mask_i = mask[:, i]
+            bce_loss = F.binary_cross_entropy_with_logits(pred_i, true_i, reduction="none").mean(-1) * mask_i
+            result.append(bce_loss)
+        cycle_loss = self._cycle_loss(pred)
+        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
+
+
+@base.Predictor.register("combo_graph_dependency_parsing_from_vocab", constructor="from_vocab")
+class GraphDependencyRelationModel(base.Predictor):
+    """Dependency relation parsing model."""
+
+    def __init__(self,
+                 head_predictor: GraphHeadPredictionModel,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 relation_prediction_layer: base.Linear):
+        super().__init__()
+        self.head_predictor = head_predictor
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.relation_prediction_layer = relation_prediction_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        # if mask is not None:
+        #     mask = mask[:, 1:]
+        relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
+        if labels is not None and labels[0] is not None:
+            relations_labels, head_labels, enhanced_heads_labels = labels
+            # if mask is None:
+            #     mask = head_labels.new_ones(head_labels.size())
+
+        head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
+        head_pred = head_output["probability"]
+        BATCH_SIZE, LENGTH, _ = head_pred.size()
+
+        head_rel_emb = self.head_projection_layer(x)
+
+        dep_rel_emb = self.dependency_projection_layer(x)
+
+        # All possible edges combinations for each batch
+        # Repeat interleave to have [emb1, emb1 ... (length times) ... emb1, emb2 ... ]
+        head_rel_pred = head_rel_emb.repeat_interleave(LENGTH, -2)
+        # Regular repeat to have all combinations [deprel1, deprel2, ... deprelL, deprel1 ...]
+        dep_rel_pred = dep_rel_emb.repeat(1, LENGTH, 1)
+
+        # All possible edges combinations for each batch
+        dep_rel_pred = torch.cat((head_rel_pred, dep_rel_pred), dim=-1)
+
+        relation_prediction = self.relation_prediction_layer(dep_rel_pred).reshape(BATCH_SIZE, LENGTH, LENGTH, -1)
+        output = head_output
+
+        output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+
+        if labels is not None and labels[0] is not None:
+            if sample_weights is None:
+                sample_weights = labels.new_ones([mask.size(0)])
+            loss = self._loss(relation_prediction, relations_labels, enhanced_heads_labels, mask, sample_weights)
+            output["loss"] = (loss, head_output["loss"])
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor,
+              true: torch.Tensor,
+              heads_true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+
+        true = true[true.long() > 0]
+        pred = pred[heads_true.long() == 1]
+        loss = F.cross_entropy(pred, true.long())
+        return loss.sum() / pred.size(0)
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   head_predictor: GraphHeadPredictionModel,
+                   head_projection_layer: base.Linear,
+                   dependency_projection_layer: base.Linear
+                   ):
+        """Creates parser combining model configuration and vocabulary data."""
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+        return cls(
+            head_predictor=head_predictor,
+            head_projection_layer=head_projection_layer,
+            dependency_projection_layer=dependency_projection_layer,
+            relation_prediction_layer=relation_prediction_layer
+        )
diff --git a/combo/models/model.py b/combo/models/model.py
index ec0f113be9a3c00868f5c9cce3d8f75f789bcb81..124f49a19145a32647b1698637115e238ee7bdd9 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -27,6 +27,7 @@ class SemanticMultitaskModel(allen_models.Model):
                  semantic_relation: Optional[base.Predictor] = None,
                  morphological_feat: Optional[base.Predictor] = None,
                  dependency_relation: Optional[base.Predictor] = None,
+                 enhanced_dependency_relation: Optional[base.Predictor] = None,
                  regularizer: allen_nn.RegularizerApplicator = None) -> None:
         super().__init__(vocab, regularizer)
         self.text_field_embedder = text_field_embedder
@@ -39,6 +40,7 @@ class SemanticMultitaskModel(allen_models.Model):
         self.semantic_relation = semantic_relation
         self.morphological_feat = morphological_feat
         self.dependency_relation = dependency_relation
+        self.enhanced_dependency_relation = enhanced_dependency_relation
         self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, self.seq_encoder.get_output_dim()]))
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
@@ -96,7 +98,7 @@ class SemanticMultitaskModel(allen_models.Model):
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
                                       (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
-                                       if sentence.get("char") else None),
+                                      if sentence.get("char") else None),
                                       mask=word_mask[:, 1:],
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
@@ -106,7 +108,14 @@ class SemanticMultitaskModel(allen_models.Model):
                                        mask=word_mask,
                                        labels=(deprel, head),
                                        sample_weights=sample_weights)
+        enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
+                                                encoder_emb,
+                                                returns_tuple=True,
+                                                mask=word_mask,
+                                                labels=(enhanced_deprels, head, enhanced_heads),
+                                                sample_weights=sample_weights)
         relations_pred, head_pred = parser_output["prediction"]
+        enhanced_relations_pred, enhanced_head_pred = enhanced_parser_output["prediction"]
         output = {
             "upostag": upos_output["prediction"],
             "xpostag": xpos_output["prediction"],
@@ -115,6 +124,8 @@ class SemanticMultitaskModel(allen_models.Model):
             "lemma": lemma_output["prediction"],
             "head": head_pred,
             "deprel": relations_pred,
+            "enhanced_head": enhanced_head_pred,
+            "enhanced_deprel": enhanced_relations_pred,
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
@@ -136,9 +147,12 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma": lemma.get("char").get("token_characters") if lemma else None,
                 "head": head,
                 "deprel": deprel,
+                "enhanced_head": enhanced_heads,
+                "enhanced_deprel": enhanced_deprels,
             }
             self.scores(output, labels, word_mask[:, 1:])
             relations_loss, head_loss = parser_output["loss"]
+            enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
             losses = {
                 "upostag_loss": upos_output["loss"],
                 "xpostag_loss": xpos_output["loss"],
@@ -147,6 +161,8 @@ class SemanticMultitaskModel(allen_models.Model):
                 "lemma_loss": lemma_output["loss"],
                 "head_loss": head_loss,
                 "deprel_loss": relations_loss,
+                "enhanced_head_loss": enhanced_head_loss,
+                "enhanced_deprel_loss": enhanced_relations_loss,
                 # Cycle loss is only for the metrics purposes.
                 "cycle_loss": parser_output.get("cycle_loss")
             }
diff --git a/combo/predict.py b/combo/predict.py
index b6c7172c2477efdea4c5e11e0ea1575450602db0..55f78c3beb4266018b222d3f635d43f806ad576a 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -154,6 +154,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         token[field_name] = value
                     elif field_name in ["head"]:
                         token[field_name] = int(predictions[field_name][idx])
+                    elif field_name == "deps":
+                        # Handled after every other decoding
+                        continue
+
                     elif field_name in ["feats"]:
                         slices = self._model.morphological_feat.slices
                         features = []
@@ -171,8 +175,6 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                             field_value = "|".join(sorted(features))
 
                         token[field_name] = field_value
-                    elif field_name == "head":
-                        pass
                     elif field_name == "lemma":
                         prediction = predictions[field_name][idx]
                         word_chars = []
@@ -191,6 +193,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                     else:
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
+        if "enhanced_head" in predictions and predictions["enhanced_head"]:
+            import combo.utils.graph as graph
+            tree = graph.sdp_to_dag_deps(arc_scores=predictions["enhanced_head"],
+                                         rel_scores=predictions["enhanced_deprel"],
+                                         tree=tree,
+                                         root_label="ROOT")
+
         return tree, predictions["sentence_embedding"]
 
     @classmethod
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..5970b19a5e458f03046c5d6068ffa7eff71d77f9
--- /dev/null
+++ b/combo/utils/graph.py
@@ -0,0 +1,82 @@
+"""Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+import numpy as np
+from conllu import TokenList
+
+
+def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label):
+    # adding ROOT
+    tree_tokens = tree.tokens
+    tree_heads = [0] + [t["head"] for t in tree_tokens]
+    graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
+                                                      root_label)
+    for i, (t, g) in enumerate(zip(tree_heads, graph)):
+        if not i:
+            continue
+        rels = [x[1] for x in g]
+        heads = [x[0] for x in g]
+        head = tree_tokens[i - 1]["head"]
+        index = heads.index(head)
+        deprel = tree_tokens[i - 1]["deprel"]
+        deprel = deprel.split('>')[-1]
+        # TODO is this necessary?
+        if len(heads) >= 2:
+            heads.pop(index)
+            rels.pop(index)
+        deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
+        tree_tokens[i - 1]["deps"] = deps
+        tree_tokens[i - 1]["deprel"] = deprel
+    return tree
+
+
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
+    if len(arc_scores) != tree_heads:
+        arc_scores = arc_scores[:len(tree_heads), :len(tree_heads)]
+        rel_labels = rel_labels[:len(tree_heads), :len(tree_heads)]
+    parse_preds = arc_scores > 0
+    parse_preds[:, 0] = False  # set heads to False
+    # rel_labels[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
+
+
+def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
+    if not isinstance(tree_heads, np.ndarray):
+        tree_heads = np.array(tree_heads)
+    dh = np.argwhere(parse_preds)
+    sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+    graph = [[] for _ in range(len(tree_heads))]
+    for d, h in enumerate(tree_heads):
+        if d:
+            graph[h].append(d)
+    for s, (d, h) in sdh:
+        if not d or not h or d in graph[h]:
+            continue
+        try:
+            path = next(_dfs(graph, d, h))
+        except StopIteration:
+            # no path from d to h
+            graph[h].append(d)
+    parse_graph = [[] for _ in range(len(tree_heads))]
+    num_root = 0
+    for h in range(len(tree_heads)):
+        for d in graph[h]:
+            rel = rel_labels[d, h]
+            if h == 0:
+                rel = root_label
+                assert num_root == 0
+                num_root += 1
+            parse_graph[d].append((h, rel))
+        parse_graph[d] = sorted(parse_graph[d])
+    return parse_graph
+
+
+def _dfs(graph, start, end):
+    fringe = [(start, [])]
+    while fringe:
+        state, path = fringe.pop()
+        if path and state == end:
+            yield path
+            continue
+        for next_state in graph[state]:
+            if next_state in path:
+                continue
+            fringe.append((next_state, path + [next_state]))
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 28f8efa022c22cff7e6d7b2b0d2977ba50b39dc1..ae73db8ff4979addaf845e405ab14af2055246cd 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -117,6 +117,8 @@ class AttachmentScores(metrics.Metric):
         mask : `torch.BoolTensor`, optional (default = None).
             A tensor of the same shape as `predicted_indices`.
         """
+        if gold_labels is None or gold_indices is None:
+            return
         detached = self.detach_tensors(
             predicted_indices, predicted_labels, gold_indices, gold_labels, mask
         )
@@ -198,6 +200,7 @@ class SemanticMetrics(metrics.Metric):
         self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.lemma_score = SequenceBoolAccuracy(prod_last_dim=True)
         self.attachment_scores = AttachmentScores()
+        self.enhanced_attachment_scores = AttachmentScores()
         self.em_score = 0.0
 
     def __call__(  # type: ignore
@@ -215,14 +218,25 @@ class SemanticMetrics(metrics.Metric):
                                gold_labels["head"],
                                gold_labels["deprel"],
                                mask)
+        self.enhanced_attachment_scores(predictions["enhanced_head"],
+                                        predictions["enhanced_deprel"],
+                                        gold_labels["enhanced_head"],
+                                        gold_labels["enhanced_deprel"],
+                                        mask=None)
+        enhanced_indices = (
+            self.enhanced_attachment_scores.correct_indices.reshape(mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(
+                -1).reshape(-1).bool()
+            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
+            else self.enhanced_attachment_scores.correct_indices
+        )
         total = mask.sum()
         correct_indices = (self.upos_score.correct_indices *
                            self.xpos_score.correct_indices *
                            self.semrel_score.correct_indices *
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
-                           self.attachment_scores.correct_indices
-                           )
+                           self.attachment_scores.correct_indices *
+                           enhanced_indices)
 
         total, correct_indices = self.detach_tensors(total, correct_indices)
         self.em_score = (correct_indices.float().sum() / total).item()
@@ -237,6 +251,8 @@ class SemanticMetrics(metrics.Metric):
             "EM": self.em_score
         }
         metrics_dict.update(self.attachment_scores.get_metric(reset))
+        enhanced_metrics = {f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
+        metrics_dict.update(enhanced_metrics)
         return metrics_dict
 
     def reset(self) -> None:
@@ -246,4 +262,5 @@ class SemanticMetrics(metrics.Metric):
         self.lemma_score.reset()
         self.feats_score.reset()
         self.attachment_scores.reset()
+        self.enhanced_attachment_scores.reset()
         self.em_score = 0.0
diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..bdb6f0bde07fe8ccffe568b5543259e94fe5142a
--- /dev/null
+++ b/config.graph.template.jsonnet
@@ -0,0 +1,421 @@
+########################################################################################
+#                                 BASIC configuration                                  #
+########################################################################################
+# Training data path, str
+# Must be in CONNLU format (or it's extended version with semantic relation field).
+# Can accepted multiple paths when concatenated with ',', "path1,path2"
+local training_data_path = std.extVar("training_data_path");
+# Validation data path, str
+# Can accepted multiple paths when concatenated with ',', "path1,path2"
+local validation_data_path = if std.length(std.extVar("validation_data_path")) > 0 then std.extVar("validation_data_path");
+# Path to pretrained tokens, str or null
+local pretrained_tokens = if std.length(std.extVar("pretrained_tokens")) > 0 then std.extVar("pretrained_tokens");
+# Name of pretrained transformer model, str or null
+local pretrained_transformer_name = if std.length(std.extVar("pretrained_transformer_name")) > 0 then std.extVar("pretrained_transformer_name");
+# Learning rate value, float
+local learning_rate = 0.002;
+# Number of epochs, int
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
+# Cuda device id, -1 for cpu, int
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
+# Minimum number of words in batch, int
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
+# Features used as input, list of str
+# Choice "upostag", "xpostag", "lemma"
+# Required "token", "char"
+local features = std.split(std.extVar("features"), " ");
+# Targets of the model, list of str
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
+# Required "deprel", "head"
+local targets = std.split(std.extVar("targets"), " ");
+# Word embedding dimension, int
+# If pretrained_tokens is not null must much provided dimensionality
+local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
+# Dropout rate on predictors, float
+# All of the models on top of the encoder uses this dropout
+local predictors_dropout = 0.25;
+# Xpostag embedding dimension, int
+# (discarded if xpostag not in features)
+local xpostag_dim = 32;
+# Upostag embedding dimension, int
+# (discarded if upostag not in features)
+local upostag_dim = 32;
+# Feats embedding dimension, int
+# (discarded if feats not in featres)
+local feats_dim = 32;
+# Lemma embedding dimension, int
+# (discarded if lemma not in features)
+local lemma_char_dim = 64;
+# Character embedding dim, int
+local char_dim = 64;
+# Word embedding projection dim, int
+local projected_embedding_dim = 100;
+# Loss weights, dict[str, int]
+local loss_weights = {
+    xpostag: 0.05,
+    upostag: 0.05,
+    lemma: 0.05,
+    feats: 0.2,
+    deprel: 0.8,
+    head: 0.2,
+    semrel: 0.05,
+    enhanced_head: 0.2,
+    enhanced_deprel: 0.8,
+};
+# Encoder hidden size, int
+local hidden_size = 512;
+# Number of layers in the encoder, int
+local num_layers = 2;
+# Cycle loss iterations, int
+local cycle_loss_n = 0;
+# Maximum length of the word, int
+# Shorter words are padded, longer - truncated
+local word_length = 30;
+# Whether to use tensorboard, bool
+local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
+
+# Helper functions
+local in_features(name) = !(std.length(std.find(name, features)) == 0);
+local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
+local use_transformer = pretrained_transformer_name != null;
+
+# Verify some configuration requirements
+assert in_features("token"): "Key 'token' must be in features!";
+assert in_features("char"): "Key 'char' must be in features!";
+
+assert in_targets("deprel"): "Key 'deprel' must be in targets!";
+assert in_targets("head"): "Key 'head' must be in targets!";
+
+assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use pretrained tokens and pretrained transformer at the same time!";
+
+########################################################################################
+#                              ADVANCED configuration                                  #
+########################################################################################
+
+# Detailed dataset, training, vocabulary and model configuration.
+{
+    # Configuration type (default or finetuning), str
+    type: std.extVar('type'),
+    # Datasets used for vocab creation, list of str
+    # Choice "train", "valid"
+    datasets_for_vocab_creation: ['train'],
+    # Path to training data, str
+    train_data_path: training_data_path,
+    # Path to validation data, str
+    validation_data_path: validation_data_path,
+    # Dataset reader configuration (conllu format)
+    dataset_reader: {
+        type: "conllu",
+        features: features,
+        targets: targets,
+        # Whether data contains semantic relation field, bool
+        use_sem: if in_targets("semrel") then true else false,
+        token_indexers: {
+            token: if use_transformer then {
+                type: "pretrained_transformer_mismatched",
+                model_name: pretrained_transformer_name,
+            } else {
+                # SingleIdTokenIndexer, token as single int
+                type: "single_id",
+            },
+            upostag: {
+                type: "single_id",
+                namespace: "upostag",
+                feature_name: "pos_",
+            },
+            xpostag: {
+                type: "single_id",
+                namespace: "xpostag",
+                feature_name: "tag_",
+            },
+            lemma: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            char: {
+                type: "characters_const_padding",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+            feats: {
+                type: "feats_indexer",
+            },
+        },
+        lemma_indexers: {
+            char: {
+                type: "characters_const_padding",
+                namespace: "lemma_characters",
+                character_tokenizer: {
+                    start_tokens: ["__START__"],
+                    end_tokens: ["__END__"],
+                },
+                # +2 for start and end token
+                min_padding_length: word_length + 2,
+            },
+        },
+    },
+    # Data loader configuration
+    data_loader: {
+        batch_sampler: {
+            type: "token_count",
+            word_batch_size: word_batch_size,
+        },
+    },
+    # Vocabulary configuration
+    vocabulary: std.prune({
+        type: "from_instances_extended",
+        only_include_pretrained_words: true,
+        pretrained_files: {
+            tokens: pretrained_tokens,
+        },
+        oov_token: "_",
+        padding_token: "__PAD__",
+        non_padded_namespaces: ["head_labels"],
+    }),
+    model: std.prune({
+        type: "semantic_multitask",
+        text_field_embedder: {
+            type: "basic",
+            token_embedders: {
+                xpostag: if in_features("xpostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: xpostag_dim,
+                    vocab_namespace: "xpostag",
+                },
+                upostag: if in_features("upostag") then {
+                    type: "embedding",
+                    padding_index: 0,
+                    embedding_dim: upostag_dim,
+                    vocab_namespace: "upostag",
+                },
+                token: if use_transformer then {
+                    type: "transformers_word_embeddings",
+                    model_name: pretrained_transformer_name,
+                    projection_dim: projected_embedding_dim,
+                } else {
+                    type: "embeddings_projected",
+                    embedding_dim: embedding_dim,
+                    projection_layer: {
+                        in_features: embedding_dim,
+                        out_features: projected_embedding_dim,
+                        dropout_rate: 0.25,
+                        activation: "tanh"
+                    },
+                    vocab_namespace: "tokens",
+                    pretrained_file: pretrained_tokens,
+                    trainable: if pretrained_tokens == null then true else false,
+                },
+                char: {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: char_dim,
+                        filters: [512, 256, char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                lemma: if in_features("lemma") then {
+                    type: "char_embeddings_from_config",
+                    embedding_dim: lemma_char_dim,
+                    dilated_cnn_encoder: {
+                        input_dim: lemma_char_dim,
+                        filters: [512, 256, lemma_char_dim],
+                        kernel_size: [3, 3, 3],
+                        stride: [1, 1, 1],
+                        padding: [1, 2, 4],
+                        dilation: [1, 2, 4],
+                        activations: ["relu", "relu", "linear"],
+                    },
+                },
+                feats: if in_features("feats") then {
+                    type: "feats_embedding",
+                    padding_index: 0,
+                    embedding_dim: feats_dim,
+                    vocab_namespace: "feats",
+                },
+            },
+        },
+        loss_weights: loss_weights,
+        seq_encoder: {
+            type: "combo_encoder",
+            layer_dropout_probability: 0.33,
+            stacked_bilstm: {
+                input_size:
+                (char_dim + projected_embedding_dim +
+                (if in_features('xpostag') then xpostag_dim else 0) +
+                (if in_features('lemma') then lemma_char_dim else 0) +
+                (if in_features('upostag') then upostag_dim else 0) +
+                (if in_features('feats') then feats_dim else 0)),
+                hidden_size: hidden_size,
+                num_layers: num_layers,
+                recurrent_dropout_probability: 0.33,
+                layer_dropout_probability: 0.33
+            },
+        },
+        dependency_relation: {
+            type: "combo_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        enhanced_dependency_relation: if in_targets("deps") then {
+            type: "combo_graph_dependency_parsing_from_vocab",
+            vocab_namespace: 'deprel_labels',
+            head_predictor: {
+                local projection_dim = 512,
+                cycle_loss_n: cycle_loss_n,
+                head_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+                dependency_projection_layer: {
+                    in_features: hidden_size * 2,
+                    out_features: projection_dim,
+                    activation: "tanh",
+                },
+            },
+            local projection_dim = 128,
+            head_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            dependency_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: projection_dim,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+        },
+        morphological_feat: if in_targets("feats") then {
+            type: "combo_morpho_from_vocab",
+            vocab_namespace: "feats_labels",
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+        },
+        lemmatizer: if in_targets("lemma") then {
+            type: "combo_lemma_predictor_from_vocab",
+            char_vocab_namespace: "token_characters",
+            lemma_vocab_namespace: "lemma_characters",
+            embedding_dim: 256,
+            input_projection_layer: {
+                in_features: hidden_size * 2,
+                out_features: 32,
+                dropout_rate: predictors_dropout,
+                activation: "tanh"
+            },
+            filters: [256, 256, 256],
+            kernel_size: [3, 3, 3, 1],
+            stride: [1, 1, 1, 1],
+            padding: [1, 2, 4, 0],
+            dilation: [1, 2, 4, 1],
+            activations: ["relu", "relu", "relu", "linear"],
+        },
+        upos_tagger: if in_targets("upostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "upostag_labels"
+        },
+        xpos_tagger: if in_targets("xpostag") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [128],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "xpostag_labels"
+        },
+        semantic_relation: if in_targets("semrel") then {
+            input_dim: hidden_size * 2,
+            hidden_dims: [64],
+            activations: ["tanh", "linear"],
+            dropout: [predictors_dropout, 0.0],
+            num_layers: 2,
+            vocab_namespace: "semrel_labels"
+        },
+        regularizer: {
+            regexes: [
+                [".*conv1d.*", {type: "l2", alpha: 1e-6}],
+                [".*forward.*", {type: "l2", alpha: 1e-6}],
+                [".*backward.*", {type: "l2", alpha: 1e-6}],
+                [".*char_embed.*", {type: "l2", alpha: 1e-5}],
+            ],
+        },
+    }),
+    trainer: std.prune({
+        checkpointer: {
+            type: "finishing_only_checkpointer",
+        },
+        type: "gradient_descent_validate_n",
+        cuda_device: cuda_device,
+        grad_clipping: 5.0,
+        num_epochs: num_epochs,
+        optimizer: {
+            type: "adam",
+            lr: learning_rate,
+            betas: [0.9, 0.9],
+        },
+        patience: 1, # it will  be overwriten by callback
+        epoch_callbacks: [
+            { type: "transfer_patience" },
+        ],
+        learning_rate_scheduler: {
+            type: "combo_scheduler",
+        },
+        tensorboard_writer: if use_tensorboard then {
+            serialization_dir: metrics_dir,
+            should_log_learning_rate: false,
+            should_log_parameter_statistics: false,
+            summary_interval: 100,
+        },
+        validation_metric: "+EM",
+    }),
+}
diff --git a/setup.py b/setup.py
index 9529a0c4f5f2138cf913d94d8665a0810552cdd9..74540e259212f32d5407fea598e0b9e3560e1f6f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.2.0',
+    'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
     'dataclasses-json==0.5.2',
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 1125392e17d71f9db09c5236e1fb27ac0968d410..32e0653525e1160135fbe6f05dc2fa07b6a7fd9d 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -4,3 +4,10 @@
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
 
+# sent_id = test-s1
+# text = Easy sentence.
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
+2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
+3	.	.	PUNCT	.	_	1	punct	2:mod	_
+4	.	.	PUNCT	.	_	1	punct	2:xmod	_
+
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a65182000d3595c71531d8835ff346fac41e451
--- /dev/null
+++ b/tests/utils/test_graph.py
@@ -0,0 +1,89 @@
+import unittest
+import combo.utils.graph as graph
+
+import conllu
+import numpy as np
+
+
+class GraphTest(unittest.TestCase):
+
+    def test_adding_empty_graph_with_the_same_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 3, "deprel": "yes", "form": "word2"},
+                {"head": 1, "deprel": "yes", "form": "word3"},
+            ]
+        )
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["no", "no", "ROOT", "no"],
+            ["no", "no", "no", "yes"],
+            ["no", "yes", "no", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["2:ROOT", "3:yes", "1:yes"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_adding_empty_graph_with_different_labels(self):
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 3, "deprel": "tree_label", "form": "word2"},
+                {"head": 1, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        empty_graph = np.zeros((4, 4))
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["no", "no", "ROOT", "no"],
+            ["no", "no", "no", "graph_label"],
+            ["no", "graph_label", "no", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
+
+    def test_extending_tree_with_graph(self):
+        # given
+        tree = conllu.TokenList(
+            tokens=[
+                {"head": 0, "deprel": "ROOT", "form": "word1"},
+                {"head": 1, "deprel": "tree_label", "form": "word2"},
+                {"head": 2, "deprel": "tree_label", "form": "word3"},
+            ]
+        )
+        arc_scores = np.array([
+            [0, 0, 0, 0],
+            [1, 0, 0, 0],
+            [0, 1, 0, 0],
+            [0, 1, 1, 0],
+        ])
+        graph_labels = np.array([
+            ["no", "no", "no", "no"],
+            ["ROOT", "no", "no", "no"],
+            ["no", "tree_label", "no", "no"],
+            ["no", "graph_label", "tree_label", "no"],
+        ])
+        root_label = "ROOT"
+        expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
+
+        # when
+        tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
+        actual_deps = [t["deps"] for t in tree.tokens]
+
+        # then
+        self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 5b8411bb44c4df7baf7ee1c8d751ac06daa320ec..1d1ad3bdfbbf9921891783d834420f029b3e3f07 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -27,12 +27,16 @@ class SemanticMetricsTest(unittest.TestCase):
         self.semrel, self.semrel_l = (("semrel", x) for x in [pred, gold])
         self.head, self.head_l = (("head", x) for x in [pred, gold])
         self.deprel, self.deprel_l = (("deprel", x) for x in [pred, gold])
+        # TODO(mklimasz) Add examples with correct dimension (with ROOT token)
+        self.enhanced_head, self.enhanced_head_l = (("enhanced_head", x) for x in [None, None])
+        self.enhanced_deprel, self.enhanced_deprel_l = (("enhanced_deprel", x) for x in [None, None])
         self.feats, self.feats_l = (("feats", x) for x in [pred_seq, gold_seq])
         self.lemma, self.lemma_l = (("lemma", x) for x in [pred_seq, gold_seq])
         self.predictions = dict(
-            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel])
+            [self.upostag, self.xpostag, self.semrel, self.feats, self.lemma, self.head, self.deprel,
+             self.enhanced_head, self.enhanced_deprel])
         self.gold_labels = dict([self.upostag_l, self.xpostag_l, self.semrel_l, self.feats_l, self.lemma_l, self.head_l,
-                                 self.deprel_l])
+                                 self.deprel_l, self.enhanced_head_l, self.enhanced_deprel_l])
         self.eps = 1e-6
 
     def test_every_prediction_correct(self):