diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index a31e6d052dc8a7969bb07bf92a5eb779b7aa24d4..6799d4994222802af3cd1c6b869bd8a760523ffe 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -148,6 +148,7 @@ class GraphDependencyRelationModel(base.Predictor):
         output = head_output
 
         output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
+        output["rel_probability"] = relation_prediction
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
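The graph_parser.py hunk keeps the raw relation score tensor in the parser output next to the hard argmax prediction, so the enhanced-dependency decoder can later re-pick labels per arc instead of working from flattened label ids. A minimal sketch of the intent, with assumed tensor names and shapes (the real `relation_prediction` and `head_output` come from earlier in this class):

```python
import torch

# Assumed shape: (batch, dependent, head, num_labels) relation scores.
relation_prediction = torch.randn(2, 10, 10, 40)
head_prediction = torch.randint(0, 10, (2, 10))   # stand-in for head_output["prediction"]

output = {
    # Hard decisions, as before: best label per (dependent, head) pair plus predicted heads.
    "prediction": (relation_prediction.argmax(-1), head_prediction),
    # New: the full score tensor travels along so downstream code can label secondary arcs.
    "rel_probability": relation_prediction,
}
```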
diff --git a/combo/models/model.py b/combo/models/model.py
index 124f49a19145a32647b1698637115e238ee7bdd9..710f72cf8dae932fc8f2b5c92abbaf7639a52ec2 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -126,6 +126,7 @@ class SemanticMultitaskModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
+            "enhanced_deprel_prob": enhanced_parser_output["rel_probability"],
             "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
         }
 
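model.py only forwards that tensor: the prediction dictionary gains `enhanced_deprel_prob`, holding the scores whose argmax is already exported as `enhanced_deprel`. Illustrative shapes only:

```python
import torch

# Illustrative tensor; the real one comes from the enhanced-graph parser head.
enhanced_deprel_prob = torch.randn(1, 6, 6, 40)    # (batch, dependent, head, label) scores
enhanced_deprel = enhanced_deprel_prob.argmax(-1)  # the hard labels already in the output dict
# predict.py reads "enhanced_deprel_prob" so it can re-label the secondary arcs
# chosen by the graph decoder instead of trusting one hard label per pair.
```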
diff --git a/combo/predict.py b/combo/predict.py
index e262b7061e8ac9dcc1f970bd456d3a3923055ebd..070975f3ed3c7e3e7a75be91d311683bf91ff5f4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -198,9 +198,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel_prob"]),
                                   tree_tokens=tree_tokens,
-                                  root_label="ROOT",
+                                  root_idx=self.vocab.get_token_index("root", "deprel_labels"),
                                   vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
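predict.py now hands the decoder the score tensor plus the integer index of the `root` label instead of the hard-coded `"ROOT"` string. Both lookups are standard AllenNLP `Vocabulary` calls; a hedged sketch with a toy `deprel_labels` namespace (the real one is built during training):

```python
from allennlp.data import Vocabulary

vocab = Vocabulary()
vocab.add_tokens_to_namespace(["root", "nsubj", "obj"], namespace="deprel_labels")

root_idx = vocab.get_token_index("root", "deprel_labels")            # an int; 0 for this toy namespace
vocab_index = vocab.get_index_to_token_vocabulary("deprel_labels")   # {0: "root", 1: "nsubj", 2: "obj"}
# root_idx lets graph.py mask/force the root label by position in the score tensor,
# while vocab_index maps argmax label ids back to deprel strings when writing deps.
```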
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 814341b5fb473d1ec33dc93ca8b4cef5cff33b3d..8a55cb9ae37a5c8985ece5dd2b11758859ce751e 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -4,15 +4,15 @@ from typing import List
 import numpy as np
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None:
     # adding ROOT
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
-                                                      root_label)
+                                                      root_idx)
     for i, (t, g) in enumerate(zip(tree_heads, graph)):
         if not i:
             continue
-        rels = [vocab_index.get(x[1], "ROOT") if vocab_index else x[1] for x in g]
+        rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g]
         heads = [x[0] for x in g]
         head = tree_tokens[i - 1]["head"]
         index = heads.index(head)
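With this signature change the per-arc entries `x[1]` are integer label ids, so the lookup goes id → string and the fallback becomes the lower-cased `"root"` used by the deprel vocabulary. A tiny sketch with an illustrative mapping:

```python
# x[1] is now an integer label id; vocab_index maps ids back to deprel strings.
vocab_index = {0: "root", 1: "nsubj", 2: "obj"}   # illustrative mapping
g = [(0, 0), (3, 2), (5, 9)]                      # (head, label_id) pairs; 9 is missing from the map
rels = [vocab_index.get(label_id, "root") for _, label_id in g]
# -> ["root", "obj", "root"]
```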
@@ -28,22 +28,23 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab
     return
 
 
-def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx):
     if len(arc_scores) != tree_heads:
         arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)]
-        rel_labels = rel_labels[:len(tree_heads)][:len(tree_heads)]
+        rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)]
     parse_preds = np.array(arc_scores) > 0
     parse_preds[:, 0] = False  # set heads to False
-    # rel_labels[:, :, root_idx] = -float('inf')
-    return add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx, parse_preds)
+    rel_scores[:, :, root_idx] = -float('inf')
+    return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds)
 
 
-def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_preds):
+def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds):
     if not isinstance(tree_heads, np.ndarray):
         tree_heads = np.array(tree_heads)
     dh = np.argwhere(parse_preds)
     sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True)
     graph = [[] for _ in range(len(tree_heads))]
+    rel_pred = np.argmax(rel_scores, axis=-1)
     for d, h in enumerate(tree_heads):
         if d:
             graph[h].append(d)
@@ -59,9 +60,9 @@ def add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_label, parse_pre
     num_root = 0
     for h in range(len(tree_heads)):
         for d in graph[h]:
-            rel = rel_labels[d][h]
+            rel = rel_pred[d][h]
             if h == 0:
-                rel = root_label
+                rel = root_idx
                 assert num_root == 0
                 num_root += 1
             parse_graph[d].append((h, rel))
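The core of the change: `rel_scores` is now a (dependent, head, num_labels) score tensor rather than a matrix of pre-decoded label strings. The previously commented-out masking is enabled so the root label can never win the argmax for ordinary arcs, and arcs attached to the artificial node 0 are forced to `root_idx`. A small NumPy illustration of that decode step (sizes and scores are made up):

```python
import numpy as np

num_tokens, num_labels, root_idx = 4, 3, 0
rel_scores = np.random.randn(num_tokens, num_tokens, num_labels)

# Never let the root label win for regular (dependent, head) pairs...
rel_scores[:, :, root_idx] = -float("inf")
rel_pred = np.argmax(rel_scores, axis=-1)   # best label id per (dependent, head) pair

# ...and force it for arcs whose head is the artificial root node 0.
dependent, head = 1, 0
rel = root_idx if head == 0 else rel_pred[dependent][head]
```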
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 4c5c7d33fbd2c11736f902892c72e83065866466..0a662122019f2f4ada380a685e207856d84a3646 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -10,48 +10,40 @@ class GraphTest(unittest.TestCase):
     def test_adding_empty_graph_with_the_same_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "yes", "form": "word2"},
                 {"head": 1, "deprel": "yes", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "yes"],
-            ["no", "yes", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:yes", "1:yes"]
+        graph_labels = np.zeros((4, 4, 4))
+        expected_deps = ["0:root", "3:yes", "1:yes"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
+        self.assertEqual(expected_deps, actual_deps)
 
     def test_adding_empty_graph_with_different_labels(self):
         tree = conllu.TokenList(
             tokens=[
-                {"head": 2, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 3, "deprel": "tree_label", "form": "word2"},
                 {"head": 1, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         empty_graph = np.zeros((4, 4))
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["no", "no", "ROOT", "no"],
-            ["no", "no", "no", "graph_label"],
-            ["no", "graph_label", "no", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[2][3][2] = 10e10
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "3:graph_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -61,28 +53,24 @@ class GraphTest(unittest.TestCase):
         # given
         tree = conllu.TokenList(
             tokens=[
-                {"head": 0, "deprel": "ROOT", "form": "word1"},
+                {"head": 0, "deprel": "root", "form": "word1"},
                 {"head": 1, "deprel": "tree_label", "form": "word2"},
                 {"head": 2, "deprel": "tree_label", "form": "word3"},
             ]
         )
+        vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"}
         arc_scores = np.array([
             [0, 0, 0, 0],
             [1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 1, 1, 0],
         ])
-        graph_labels = np.array([
-            ["no", "no", "no", "no"],
-            ["ROOT", "no", "no", "no"],
-            ["no", "tree_label", "no", "no"],
-            ["no", "graph_label", "tree_label", "no"],
-        ])
-        root_label = "ROOT"
-        expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
+        graph_labels = np.zeros((4, 4, 3))
+        graph_labels[3][1][2] = 10e10
+        expected_deps = ["0:root", "1:tree_label", "1:graph_label"]
 
         # when
-        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then