diff --git a/combo/models/parser.py b/combo/models/parser.py
index 486b2481b96bf17bb19fd8557916f21dcb6c4584..4b5b12606874c91e9303060d1c8bb68b3d3ac016 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -1,4 +1,5 @@
 """Dependency parsing models."""
+import math
 from typing import Tuple, Dict, Optional, Union, List
 
 import numpy as np
@@ -115,11 +116,13 @@ class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 root_idx: int,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
                  dependency_projection_layer: base.Linear,
                  relation_prediction_layer: base.Linear):
         super().__init__()
+        self.root_idx = root_idx
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -130,6 +133,7 @@ class DependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
+        device = x.device
         if mask is not None:
             mask = mask[:, 1:]
         relations_labels, head_labels = None, None
@@ -151,7 +155,23 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
 
-        output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        if self.training:
+            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
+        else:
+            # Mask root label whenever head is not 0.
+            relation_prediction_output = relation_prediction[:, 1:]
+            mask = (head_output["prediction"] == 0)
+            vocab_size = relation_prediction_output.size(-1)
+            root_idx = torch.tensor([self.root_idx], device=device)
+            relation_prediction_output[mask] = (relation_prediction_output
+                                                .masked_select(mask.unsqueeze(-1))
+                                                .reshape(-1, vocab_size)
+                                                .index_fill(-1, root_idx, 10e10))
+            relation_prediction_output[~mask] = (relation_prediction_output
+                                                 .masked_select(~(mask.unsqueeze(-1)))
+                                                 .reshape(-1, vocab_size)
+                                                 .index_fill(-1, root_idx, -10e10))
+            output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
 
         if labels is not None and labels[0] is not None:
             if sample_weights is None:
@@ -195,5 +215,6 @@ class DependencyRelationModel(base.Predictor):
             head_predictor=head_predictor,
             head_projection_layer=head_projection_layer,
             dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
+            relation_prediction_layer=relation_prediction_layer,
+            root_idx=vocab.get_token_index("root", vocab_namespace)
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 32e7dd944999c9c2f7c8e800289eb2f35d14a4bc..651c14a7d79b7ea3c277b9466f5e050435a7a01b 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -110,5 +110,6 @@ def restore_collapse_edges(tree_tokens):
                         "deps": f"{head}:{empty_node_relation}"
                     }
                 )
-        token["deps"] = "|".join(deps)
+        deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0]))
+        token["deps"] = "|".join([f"{k}:{v}" for k, v in deps])
     return empty_tokens