diff --git a/combo/combo_model.py b/combo/combo_model.py
index d2312319a930ab43b7db1c8cc3a510fdb9a1b775..2747ad2dd904e40f218dcb2bb316d5cfbbecaf04 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -155,6 +155,8 @@ class ComboModel(Model, FromParameters):
             "semrel_token_embedding": semrel_output["embedding"],
             "feats_token_embedding": morpho_output["embedding"],
             "deprel_token_embedding": parser_output["embedding"],
+            "deprel_tree_distribution": parser_output["deprel_tree_distribution"],
+            "deprel_label_distribution": parser_output["deprel_label_distribution"]
         }
 
         if "rel_probability" in enhanced_parser_output:
diff --git a/combo/data/api.py b/combo/data/api.py
index ae4ef35a0ee4e8958b8663b0721bb064adb170d7..b04192f04b2db84981ea34cfcfa1297769078f26 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -3,6 +3,9 @@ import dataclasses
 import json
 from dataclasses import dataclass, field
 from typing import Optional, List, Dict, Any, Union, Tuple
+
+from conllu.models import Metadata
+
 from combo.data.tokenizers import Token
 
 import conllu
@@ -12,13 +15,17 @@ from overrides import overrides
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list, repr=False)
+    relation_distribution: List[List[float]] = field(default_factory=list, repr=False)
+    relation_label_distribution: List[List[float]] = field(default_factory=list, repr=False)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
     def to_json(self):
         return json.dumps({
             "tokens": [dataclasses.asdict(t) for t in self.tokens],
             "sentence_embedding": self.sentence_embedding,
-            "metadata": self.metadata,
+            "relation_distribution": self.relation_distribution,
+            "relation_label_distribution": self.relation_label_distribution,
+            "metadata": self.metadata
         })
 
     def __len__(self):
@@ -46,7 +53,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode
                 if len(dep) > 1 and type(dep[1]) == list:
                     dep[1] = tuple(dep[1])
     return _TokenList(tokens=tokens,
-                      metadata=sentence.metadata)
+                      metadata=sentence.metadata if sentence.metadata is not None else Metadata())
 
 
 def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
@@ -60,13 +67,23 @@ def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
 
 def conllu2sentence(tokens: List[Token],
                     sentence_embedding=None, embeddings=None,
-                    metadata=None) -> Sentence:
-    if embeddings is None:
-        embeddings = {}
+                    metadata=None,
+                    relation_distribution=None,
+                    relation_label_distribution=None) -> Sentence:
+    embeddings = embeddings or {}
+    if relation_distribution is None:
+        relation_distribution = []
+    if relation_label_distribution is None:
+        relation_label_distribution = []
     if sentence_embedding is None:
         sentence_embedding = []
+    if embeddings:
+        for token in tokens:
+            token.embeddings = embeddings[token["idx"]]
     return Sentence(
         tokens=tokens,
         sentence_embedding=sentence_embedding,
+        relation_distribution=relation_distribution,
+        relation_label_distribution=relation_label_distribution,
         metadata=metadata
     )
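
`Sentence` now carries both distributions and serializes them in `to_json()`. An illustrative construction with toy values (in practice `conllu2sentence` fills these fields from the predictor, one row of probabilities per token):

```python
from combo.data.api import Sentence

sentence = Sentence(
    tokens=[],
    sentence_embedding=[0.1, 0.2],
    relation_distribution=[[0.7, 0.3], [0.4, 0.6]],        # toy head probabilities
    relation_label_distribution=[[0.9, 0.1], [0.2, 0.8]],  # toy label probabilities
)
print(sentence.to_json())  # output now includes both distribution fields
```
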
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 3993a55c84b9d9486826f756144ef846bedde481..46d8aba42b6a0ca90d9fac80e74da60d18c0c8f9 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -17,7 +17,7 @@ from combo.data.fields.metadata_field import MetadataField
 from combo.data.fields.sequence_label_field import SequenceLabelField
 from combo.data.fields.text_field import TextField
 from combo.data.token_indexers import TokenIndexer
-from combo.models import parser
+from combo.modules import parser
 from combo.utils import checks, pad_sequence_to_length
 
 logger = logging.getLogger(__name__)
diff --git a/combo/modules/parser.py b/combo/modules/parser.py
index ef885f28daf469d86735899ceaf4435c2a047deb..e034566f0fab7913ee348d0bcb519966385f681f 100644
--- a/combo/modules/parser.py
+++ b/combo/modules/parser.py
@@ -174,6 +174,8 @@ class DependencyRelationModel(Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
         output["embedding"] = dep_rel_pred
+        output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1)
+        output["deprel_tree_distribution"] = head_pred_soft
 
         if self.training:
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
diff --git a/combo/predict.py b/combo/predict.py
index c4218f58e5f2cb3bd88433fa934d0420f9027310..e30889d33eaa59956d08ea73a888b351e00f8103 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -105,8 +105,12 @@ class COMBO(PredictorModule):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance)
-            sentence = conllu2sentence(tree, sentence_embedding, embeddings)
+            (tree, sentence_embedding, embeddings,
+             relation_distribution, relation_label_distribution) = self._predictions_as_tree(prediction, instance)
+            sentence = conllu2sentence(
+                tree, sentence_embedding, embeddings,
+                relation_distribution, relation_label_distribution
+            )
             sentences.append(sentence)
         return sentences
 
@@ -119,8 +123,9 @@ class COMBO(PredictorModule):
     @overrides
     def predict_instance(self, instance: Instance, serialize: bool = True) -> data.Sentence:
         predictions = super().predict_instance(instance)
-        tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance)
-        return conllu2sentence(tree, sentence_embedding, embeddings)
+        (tree, sentence_embedding, embeddings,
+         relation_distribution, relation_label_distribution) = self._predictions_as_tree(predictions, instance)
+        return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
 
     @overrides
     def predict_json(self, inputs: JsonDict) -> data.Sentence:
@@ -166,9 +171,16 @@ class COMBO(PredictorModule):
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["idx"], int)]
         embeddings = {t["idx"]: {} for t in tree}
+        deprel_tree_distribution = None
+        deprel_label_distribution = None
         for field_name in field_names:
             if field_name not in predictions:
                 continue
+            if field_name == "deprel":
+                sentence_length = len(tree_tokens)
+                deprel_tree_distribution = np.array(
+                    predictions["deprel_tree_distribution"])[:sentence_length + 1, :sentence_length + 1]
+                deprel_label_distribution = np.array(predictions["deprel_label_distribution"])[:sentence_length, :]
             field_predictions = predictions[field_name]
             for idx, token in enumerate(tree_tokens):
                 if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
@@ -242,7 +254,9 @@ class COMBO(PredictorModule):
             empty_tokens = graph.restore_collapse_edges(tree_tokens)
             tree.tokens.extend(empty_tokens)
 
-        return tree, predictions["sentence_embedding"], embeddings
+        return (tree, predictions["sentence_embedding"], embeddings,
+                deprel_tree_distribution, deprel_label_distribution)
+
 
     @classmethod
     def with_spacy_tokenizer(cls, model: Model,
@@ -250,7 +264,9 @@ class COMBO(PredictorModule):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
     @classmethod
-    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+    def from_pretrained(cls,
+                        path: str,
+                        tokenizer=tokenizers.SpacyTokenizer(),
                         batch_size: int = 1024,
                         cuda_device: int = -1):
         if os.path.exists(path):
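
Taken together, these changes expose both distributions on every predicted `Sentence`. An end-to-end sketch following the project's usual `from_pretrained` usage (the model name and input sentence are placeholders):

```python
from combo.predict import COMBO

nlp = COMBO.from_pretrained("polish-herbert-base")  # placeholder model name
sentence = nlp("Ala ma kota.")
tree_dist = sentence.relation_distribution          # (tokens + ROOT) x (tokens + ROOT)
label_dist = sentence.relation_label_distribution   # tokens x relation labels
```
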