diff --git a/combo/predict.py b/combo/predict.py
index 35dbbee3af5720b110c37cee4c4e12f72243c8c4..2e5693f75a635915d4e7ed4f7c063bcda21fa8e2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -79,7 +79,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         for prediction, instance in zip(predictions, instances):
             tree = self._predictions_as_tree(prediction, instance)
             if serialize:
-                tree = tree.serialize()
+                tree = self._serialize(tree)
             tree_json = util.sanitize(tree)
             trees.append(collections.OrderedDict([
                 ("tree", tree_json),
@@ -89,11 +89,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict:
         tree = self.predict_instance_as_tree(instance)
+        sentence_embedding = tree.metadata.get("sentence_embedding", [])
         if serialize:
-            tree = tree.serialize()
+            tree = self._serialize(tree)
         tree_json = util.sanitize(tree)
         result = collections.OrderedDict([
             ("tree", tree_json),
+            ("sentence_embedding", sentence_embedding)
         ])
         return result
 
@@ -118,6 +120,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
+        result["sentence_embedding"] = tree.metadata.get("sentence_embedding", [])
         return result
 
     def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
@@ -217,9 +220,15 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                         raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if self._dataset_reader and "sent" in self._dataset_reader._targets:
-            tree.metadata["sentence_embedding"] = str(predictions["sentence_embedding"])
+            tree.metadata["sentence_embedding"] = predictions["sentence_embedding"]
         return tree
 
+    @staticmethod
+    def _serialize(tree: conllu.TokenList):
+        if "sentence_embedding" in tree.metadata:
+            tree.metadata["sentence_embedding"] = str(tree.metadata["sentence_embedding"])
+        return tree.serialize()
+
     @classmethod
     def with_spacy_tokenizer(cls, model: models.Model,
                              dataset_reader: allen_data.DatasetReader):