diff --git a/combo/predict.py b/combo/predict.py
index 3e81eb539ec5737e345b108619afe5528da058c4..310f48fd5f8c78804344926b655df29b3a3d630b 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,7 +1,6 @@
 import collections
 import logging
 import os
-import time
 from typing import List, Union
 
 import conllu
@@ -24,68 +23,74 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def __init__(self,
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
-                 tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
+                 tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
+                 batch_size: int = 500) -> None:
         super().__init__(model, dataset_reader)
-        self.batch_size = 1000
+        self.batch_size = batch_size
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
         self._tokenizer = tokenizer
 
-    @overrides
-    def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
-        sentence = json_dict["sentence"]
-        if isinstance(sentence, str):
-            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
-        elif isinstance(sentence, list):
-            tokens = sentence
-        else:
-            raise ValueError("Input must be either string or list of strings.")
-        tree = self._sentence_to_tree(tokens)
-        return self._dataset_reader.text_to_instance(tree)
+    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+        """Depending on the input, uses (or ignores) the tokenizer.
+        When the model isn't purely text-based, List[data.Sentence] is the only possible input.
 
-    @overrides
-    def load_line(self, line: str) -> common.JsonDict:
-        return self._to_input_json(line.replace("\n", "").strip())
+        * str - tokenizer is used
+        * List[str] - tokenizer is used for each string (treated as list of raw sentences)
+        * List[List[str]] - tokenizer isn't used (treated as list of tokenized sentences)
+        * List[data.Sentence] - tokenizer isn't used (treated as list of tokenized sentences)
 
-    @overrides
-    def dump_line(self, outputs: common.JsonDict) -> str:
-        # Check whether serialized (str) tree or token's list
-        # Serialized tree has already separators between lines
-        if type(outputs["tree"]) == str:
-            return str(outputs["tree"])
-        else:
-            return str(outputs["tree"]) + "\n"
+        :param sentence: sentence(s) representation
+        :return: Sentence or List[Sentence] depending on the input
+        """
+        return self.predict(sentence)
 
-    def predict(self, sentence: Union[str, List[str]]):
+    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
             return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
         elif isinstance(sentence, list):
-            sentences = []
-            for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
-                sentences.extend([data.Sentence.from_json(t) for t in trees])
-            return sentences
+            if len(sentence) == 0:
+                return []
+            example = sentence[0]
+            if isinstance(example, str) or isinstance(example, list):
+                sentences = []
+                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                    trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
+                    sentences.extend([data.Sentence.from_json(t) for t in trees])
+                return sentences
+            elif isinstance(example, data.Sentence):
+                sentences = []
+                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                    trees = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch],
+                                                        serialize=False)
+                    sentences.extend([data.Sentence.from_json(t) for t in trees])
+                return sentences
+            else:
+                raise ValueError("List must have either sentences as str, List[str] or Sentence object.")
         else:
             raise ValueError("Input must be either string or list of strings.")
 
-    def __call__(self, sentence: Union[str, List[str]]):
-        return self.predict(sentence)
-
     @overrides
-    def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
+    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True) -> List[
+        common.JsonDict]:
         trees = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance).serialize())
+            tree = self._predictions_as_tree(prediction, instance)
+            if serialize:
+                tree = tree.serialize()
+            tree_json = util.sanitize(tree)
             trees.append(collections.OrderedDict([
                 ("tree", tree_json),
             ]))
         return trees
 
     @overrides
-    def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
+    def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict:
         tree = self.predict_instance_as_tree(instance)
-        tree_json = util.sanitize(tree.serialize())
+        if serialize:
+            tree = tree.serialize()
+        tree_json = util.sanitize(tree)
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
@@ -95,9 +100,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
         trees = []
         instances = self._batch_json_to_instances(inputs)
-        predictions = super().predict_batch_json(inputs)
+        predictions = self.predict_batch_instance(instances, serialize=False)
         for prediction, instance in zip(predictions, instances):
-            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
+            tree = self._predictions_as_tree(prediction, instance)
+            tree_json = util.sanitize(tree)
             trees.append(collections.OrderedDict([
                 ("tree", tree_json),
             ]))
@@ -117,10 +123,35 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         predictions = super().predict_instance(instance)
         return self._predictions_as_tree(predictions, instance)
 
+    @overrides
+    def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
+        sentence = json_dict["sentence"]
+        if isinstance(sentence, str):
+            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+        elif isinstance(sentence, list):
+            tokens = sentence
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+        tree = self._sentence_to_tree(tokens)
+        return self._dataset_reader.text_to_instance(tree)
+
+    @overrides
+    def load_line(self, line: str) -> common.JsonDict:
+        return self._to_input_json(line.replace("\n", "").strip())
+
+    @overrides
+    def dump_line(self, outputs: common.JsonDict) -> str:
+        # Check whether the tree is serialized (str) or a token list.
+        # A serialized tree already has separators between lines.
+        if type(outputs["tree"]) == str:
+            return str(outputs["tree"])
+        else:
+            return str(outputs["tree"]) + "\n"
+
     @staticmethod
     def _sentence_to_tree(sentence: List[str]):
         d = collections.OrderedDict
-        return conllu.TokenList(
+        return _TokenList(
             [d({"id": idx, "token": token}) for
              idx, token
              in enumerate(sentence)],
@@ -131,6 +162,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def _to_input_json(sentence: str):
         return {"sentence": sentence}
 
+    def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance:
+        tree = _TokenList([t.__dict__ for t in sentence.tokens])
+        return self._dataset_reader.text_to_instance(tree)
+
     def _predictions_as_tree(self, predictions, instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
@@ -190,7 +225,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
     @classmethod
-    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer()):
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -209,4 +244,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
             archive.config["dataset_reader"])
-        return cls(model, dataset_reader, tokenizer)
+        return cls(model, dataset_reader, tokenizer, batch_size)
+
+
+class _TokenList(conllu.TokenList):
+
+    @overrides
+    def __repr__(self):
+        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 039312e5ce6036083fe4992c51312149bff10547..1125392e17d71f9db09c5236e1fb27ac0968d410 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -3,3 +3,4 @@
 1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
+
diff --git a/tests/fixtures/model.tar.gz b/tests/fixtures/model.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..ca7cf263d1a2e7b85e4a4810a8d993712d53ea6c
Binary files /dev/null and b/tests/fixtures/model.tar.gz differ
diff --git a/tests/test_predict.py b/tests/test_predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e536561f9cfe68e5afa3d339d491230a73c63429
--- /dev/null
+++ b/tests/test_predict.py
@@ -0,0 +1,47 @@
+import os
+import pathlib
+import shutil
+import unittest
+from unittest import mock
+
+import combo.data as data
+import combo.predict as predict
+
+
+class PredictionTest(unittest.TestCase):
+    PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
+    MODULE_ROOT = PROJECT_ROOT / "combo"
+    TESTS_ROOT = PROJECT_ROOT / "tests"
+    FIXTURES_ROOT = TESTS_ROOT / "fixtures"
+
+    def setUp(self) -> None:
+        def _cleanup_archive_dir_without_logging(path: str):
+            if os.path.exists(path):
+                shutil.rmtree(path)
+
+        self.patcher = mock.patch(
+            "allennlp.models.archival._cleanup_archive_dir", _cleanup_archive_dir_without_logging
+        )
+        self.mock_cleanup_archive_dir = self.patcher.start()
+
+    def test_prediction_are_equal_given_the_same_input_in_different_form(self):
+        # given
+        raw_sentence = "Test."
+        raw_sentence_collection = ["Test."]
+        tokenized_sentence_collection = [["Test", "."]]
+        wrapped_tokenized_sentence = [data.Sentence(tokens=[
+            data.Token(id=0, token="Test"),
+            data.Token(id=1, token=".")
+        ])]
+        nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
+
+        # when
+        results = [
+            nlp(raw_sentence),
+            nlp(raw_sentence_collection)[0],
+            nlp(tokenized_sentence_collection)[0],
+            nlp(wrapped_tokenized_sentence)[0]
+        ]
+
+        # then
+        self.assertTrue(all(x == results[0] for x in results))