diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 82eefde3185ba776c861504d1554dc4fb7cda58d..6dffef52a943e4671aea7ac911b5c25981c3051f 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -1,9 +1,190 @@
+"""
+Adapted from COMBO.
+Author: Mateusz Klimaszewski
+"""
+
+from typing import List, Optional, Union, Tuple, Dict
+
+from combo import data
+from combo.models import base
 from combo.models.base import Predictor
 
+import torch
+import torch.nn.functional as F
+
 
 class GraphHeadPredictionModel(Predictor):
-    pass
+    """Head prediction model."""
+
+    def __init__(self,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 cycle_loss_n: int = 0,
+                 graph_weighting: float = 0.2):
+        super().__init__()
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.cycle_loss_n = cycle_loss_n
+        self.graph_weighting = graph_weighting  # stored for configuration parity; unused in this module
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                mask: Optional[torch.BoolTensor] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        if mask is None:
+            # Default mask covers every position: shape (batch, sentence_length)
+            mask = x.new_ones(x.size()[:-1])
+        heads_labels = None
+        if labels is not None and labels[0] is not None:
+            heads_labels = labels
+
+        head_arc_emb = self.head_projection_layer(x)
+        dep_arc_emb = self.dependency_projection_layer(x)
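+        # Dot-product arc scoring: x[b, d, h] scores token h as a head of token d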
+        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
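+        # Keep every arc whose sigmoid probability exceeds 0.5; unlike tree
+        # parsing, a token may end up with zero, one, or several heads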
+        pred = x.sigmoid() > 0.5
+
+        output = {
+            "prediction": pred,
+            "probability": x  # raw arc scores (logits), not sigmoid probabilities
+        }
+
+        if heads_labels is not None:
+            if sample_weights is None:
+                sample_weights = heads_labels.new_ones([mask.size(0)])
+            output["loss"], output["cycle_loss"] = self._loss(x, heads_labels, mask, sample_weights)
+
+        return output
+
+    def _cycle_loss(self, pred: torch.Tensor):
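+        # trace(A^n) accumulates the weight of all length-n closed paths in the
+        # soft adjacency matrix A, so summing the traces of the first
+        # `cycle_loss_n` powers of A penalises predicted cycles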
+        BATCH_SIZE, _, _ = pred.size()
+        loss = pred.new_zeros(BATCH_SIZE)
+        # Drop index 0 on both axes: the __ROOT__ token takes no part in cycles
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
+
+        return loss
+
+    @staticmethod
+    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
+        assert len(x.size()) == 3
+        BATCH_SIZE, N, M = x.size()
+        assert N == M
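+        # e.g. _batch_trace(torch.eye(3).expand(2, 3, 3)) == tensor([3., 3.])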
+        identity = torch.eye(N, device=x.device, dtype=x.dtype)
+        identity = identity.reshape((1, N, N))
+        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
+        return (x * batch_identity).sum((-1, -2))
+
+    def _loss(self, pred: torch.Tensor, labels: torch.Tensor, mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        BATCH_SIZE, N, M = pred.size()
+        assert N == M
+        SENTENCE_LENGTH = N
+
+        valid_positions = mask.sum()
+
+        result = []
+        # BCE-with-logits expects float targets
+        true = labels.float()
+        # Skip row 0 of `pred`: it holds the incoming-arc scores for the ROOT token
+        for i in range(SENTENCE_LENGTH - 1):
+            pred_i = pred[:, i + 1, 1:]
+            true_i = true[:, i + 1, 1:]
+            mask_i = mask[:, i]
+            # Average over candidate heads, keeping the batch dimension so the
+            # per-sample weights can be applied below
+            bce_loss = F.binary_cross_entropy_with_logits(pred_i, true_i, reduction="none").mean(-1) * mask_i
+            result.append(bce_loss)
+        cycle_loss = self._cycle_loss(pred)
+        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
+        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
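+
+# A minimal usage sketch (illustrative only: the encoder dimensions below are
+# assumed, and `base.Linear` is instantiated the same way as in
+# `GraphDependencyRelationModel.from_vocab`):
+#
+#   head_proj = base.Linear(in_features=768, out_features=512)
+#   dep_proj = base.Linear(in_features=768, out_features=512)
+#   model = GraphHeadPredictionModel(head_proj, dep_proj, cycle_loss_n=3)
+#   x = torch.randn(2, 12, 768)  # (batch, sentence_length, encoder_dim)
+#   out = model(x)               # out["prediction"]: (2, 12, 12) boolean adjacency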
 
 
 class GraphDependencyRelationModel(Predictor):
-    pass
+    """Dependency relation parsing model."""
+
+    def __init__(self,
+                 head_predictor: GraphHeadPredictionModel,
+                 head_projection_layer: base.Linear,
+                 dependency_projection_layer: base.Linear,
+                 relation_prediction_layer: base.Linear):
+        super().__init__()
+        self.head_predictor = head_predictor
+        self.head_projection_layer = head_projection_layer
+        self.dependency_projection_layer = dependency_projection_layer
+        self.relation_prediction_layer = relation_prediction_layer
+
+    def forward(self,
+                x: Union[torch.Tensor, List[torch.Tensor]],
+                mask: Optional[torch.BoolTensor] = None,
+                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        relations_labels, head_labels, enhanced_heads_labels = None, None, None
+        if labels is not None and labels[0] is not None:
+            relations_labels, head_labels, enhanced_heads_labels = labels
+
+        head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
+        head_pred = head_output["probability"]
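+        # head_pred is used only to read the batch size and sentence length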
+        BATCH_SIZE, LENGTH, _ = head_pred.size()
+
+        head_rel_emb = self.head_projection_layer(x)
+
+        dep_rel_emb = self.dependency_projection_layer(x)
+
+        # Build all LENGTH x LENGTH candidate edges for each batch element.
+        # repeat_interleave yields [emb1, emb1, ... (LENGTH times), emb2, emb2, ...]
+        head_rel_pred = head_rel_emb.repeat_interleave(LENGTH, -2)
+        # Plain repeat yields all combinations [dep1, dep2, ..., depL, dep1, dep2, ...]
+        dep_rel_pred = dep_rel_emb.repeat(1, LENGTH, 1)
+
+        # Concatenate head and dependent embeddings for every candidate edge
+        dep_rel_pred = torch.cat((head_rel_pred, dep_rel_pred), dim=-1)
+
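+        # relation_prediction[b, h, d, r]: score of relation label r for the
+        # arc whose head is position h and whose dependent is position d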
+        relation_prediction = self.relation_prediction_layer(dep_rel_pred).reshape(BATCH_SIZE, LENGTH, LENGTH, -1)
+        output = head_output
+
+        output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+        output["rel_probability"] = relation_prediction
+
+        if labels is not None and labels[0] is not None:
+            if sample_weights is None:
+                # labels is a list here, so build the weights from a tensor
+                sample_weights = relations_labels.new_ones([mask.size(0)])
+            loss = self._loss(relation_prediction, relations_labels, enhanced_heads_labels, mask, sample_weights)
+            output["loss"] = (loss, head_output["loss"])
+
+        return output
+
+    @staticmethod
+    def _loss(pred: torch.Tensor,
+              true: torch.Tensor,
+              heads_true: torch.Tensor,
+              mask: torch.BoolTensor,
+              sample_weights: torch.Tensor) -> torch.Tensor:
+        # Relation loss is computed only over gold arcs (heads_true == 1);
+        # mask and sample_weights are accepted for interface consistency
+        correct_heads_mask = heads_true.long() == 1
+        true = true[correct_heads_mask]
+        pred = pred[correct_heads_mask]
+        # Sum, then divide by the number of gold arcs: an explicit mean
+        loss = F.cross_entropy(pred, true.long(), reduction="sum")
+        return loss / pred.size(0)
+
+    @classmethod
+    def from_vocab(cls,
+                   vocab: data.Vocabulary,
+                   vocab_namespace: str,
+                   head_predictor: GraphHeadPredictionModel,
+                   head_projection_layer: base.Linear,
+                   dependency_projection_layer: base.Linear
+                   ):
+        """Creates parser combining model configuration and vocabulary data."""
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+        return cls(
+            head_predictor=head_predictor,
+            head_projection_layer=head_projection_layer,
+            dependency_projection_layer=dependency_projection_layer,
+            relation_prediction_layer=relation_prediction_layer
+        )
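+
+# Illustrative wiring (the namespace name and dimensions are assumed, not
+# fixed by this module):
+#
+#   model = GraphDependencyRelationModel.from_vocab(
+#       vocab=vocab,                       # combo data.Vocabulary
+#       vocab_namespace="deprel_labels",   # assumed namespace with relation labels
+#       head_predictor=head_predictor,     # a GraphHeadPredictionModel
+#       head_projection_layer=base.Linear(in_features=768, out_features=128),
+#       dependency_projection_layer=base.Linear(in_features=768, out_features=128),
+#   )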