diff --git a/combo/predict.py b/combo/predict.py
index 6d657f6ae38d2716fccdedf489f2a30de4f78de7..42e8bed149595e06bee8979610e95fc38686ffbf 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -12,7 +12,7 @@ from overrides import overrides
 
 from combo import data
 from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
-from combo.utils import download
+from combo.utils import download, graph
 
 logger = logging.getLogger(__name__)
 
@@ -195,12 +195,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
-            import combo.utils.graph as graph
-            tree = graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
-                                         rel_scores=np.array(predictions["enhanced_deprel"]),
-                                         tree=tree,
-                                         root_label="ROOT",
-                                         vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
+            graph.sdp_to_dag_deps(arc_scores=np.array(predictions["enhanced_head"]),
+                                  rel_scores=np.array(predictions["enhanced_deprel"]),
+                                  tree_tokens=tree_tokens,
+                                  root_label="ROOT",
+                                  vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
 
         return tree, predictions["sentence_embedding"]
 
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index fd59e97b38ebaf247f74c801ac09609c7dee6582..814341b5fb473d1ec33dc93ca8b4cef5cff33b3d 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -1,11 +1,11 @@
 """Based on https://github.com/emorynlp/iwpt-shared-task-2020."""
+from typing import List
+
 import numpy as np
-from conllu import TokenList
 
 
-def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_index=None):
+def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_label, vocab_index=None) -> None:
     # adding ROOT
-    tree_tokens = tree.tokens
     tree_heads = [0] + [t["head"] for t in tree_tokens]
     graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads,
                                                       root_label)
@@ -25,7 +25,7 @@ def sdp_to_dag_deps(arc_scores, rel_scores, tree: TokenList, root_label, vocab_i
         deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels))
         tree_tokens[i - 1]["deps"] = deps
         tree_tokens[i - 1]["deprel"] = deprel
-    return tree
+    return
 
 
 def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_labels, tree_heads, root_idx):
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 4a65182000d3595c71531d8835ff346fac41e451..4c5c7d33fbd2c11736f902892c72e83065866466 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -26,7 +26,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:yes", "1:yes"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -51,7 +51,7 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["2:ROOT", "3:graph_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(empty_graph, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
@@ -82,8 +82,8 @@ class GraphTest(unittest.TestCase):
         expected_deps = ["0:ROOT", "1:tree_label", "1:graph_label"]
 
         # when
-        tree = graph.sdp_to_dag_deps(arc_scores, graph_labels, tree, root_label)
+        graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_label)
         actual_deps = [t["deps"] for t in tree.tokens]
 
         # then
-        self.assertEqual(actual_deps, expected_deps)
\ No newline at end of file
+        self.assertEqual(actual_deps, expected_deps)