diff --git a/combo/common/util.py b/combo/common/util.py
index 5238f0b35c6841a06de2c272d801a16f4118335d..dfb2e7c9bd67ea79a5bb79d7aacb4fe746f3af41 100644
--- a/combo/common/util.py
+++ b/combo/common/util.py
@@ -1,8 +1,10 @@
-from typing import Union, Iterable, TypeVar, List
+from typing import Union, Iterable, TypeVar, List, Iterator, Any
+from itertools import islice
 
+import numpy
+import spacy
 import torch
 
-
 A = TypeVar("A")
 
 
@@ -23,3 +25,61 @@ def ensure_list(iterable: Iterable[A]) -> List[A]:
         return iterable
     else:
         return list(iterable)
+
+
+def lazy_groups_of(iterable: Iterable[A], group_size: int) -> Iterator[List[A]]:
+    """
+    Takes an iterable and batches the individual instances into lists of the
+    specified size. The last list may be smaller if there are instances left over.
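+
+    Illustrative example:
+        >>> list(lazy_groups_of(range(5), 2))
+        [[0, 1], [2, 3], [4]]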
+    """
+    iterator = iter(iterable)
+    while True:
+        s = list(islice(iterator, group_size))
+        if len(s) > 0:
+            yield s
+        else:
+            break
+
+
+def sanitize(x: Any) -> Any:
+    """
+    Sanitize turns PyTorch and Numpy types into basic Python types so they
+    can be serialized into JSON.
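+
+    Illustrative example (values chosen for clarity; uses this module's `numpy` import):
+        >>> sanitize({"label": numpy.int64(3), "mask": numpy.bool_(True), "tags": ("NOUN", None)})
+        {'label': 3, 'mask': True, 'tags': ['NOUN', 'None']}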
+    """
+    # Import here to avoid circular references
+    from combo.data.tokenizers import TokenizerToken
+
+    if isinstance(x, (str, float, int, bool)):
+        # x is already serializable
+        return x
+    elif isinstance(x, torch.Tensor):
+        # tensor needs to be converted to a list (and moved to cpu if necessary)
+        return x.cpu().tolist()
+    elif isinstance(x, numpy.ndarray):
+        # array needs to be converted to a list
+        return x.tolist()
+    elif isinstance(x, numpy.number):
+        # NumPy numbers need to be converted to Python numbers
+        return x.item()
+    elif isinstance(x, dict):
+        # Dicts need their values sanitized
+        return {key: sanitize(value) for key, value in x.items()}
+    elif isinstance(x, numpy.bool_):
+        # Numpy bool_ needs to be converted to a Python bool.
+        return bool(x)
+    elif isinstance(x, (spacy.tokens.Token, TokenizerToken)):
+        # Tokens get sanitized to just their text.
+        return x.text
+    elif isinstance(x, (list, tuple, set)):
+        # Lists, tuples, and sets need their values sanitized
+        return [sanitize(x_i) for x_i in x]
+    elif x is None:
+        return "None"
+    elif hasattr(x, "to_json"):
+        return x.to_json()
+    else:
+        raise ValueError(
+            f"Cannot sanitize {x} of type {type(x)}. "
+            "If this is your own custom class, add a `to_json(self)` method "
+            "that returns a JSON-like object."
+        )
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 8a64dafc5fd338069dc66102a2347c9333800c47..7708f3899ee3388adb9c7045d41da4e0d1bb5f7f 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,4 +1,6 @@
-from .api import Token
+from .api import (Token, Sentence, sentence2conllu, tokens2conllu, conllu2sentence)
 from .vocabulary import Vocabulary
 from .samplers import TokenCountBatchSampler
-from .instance import Instance
\ No newline at end of file
+from .instance import Instance
+from .tokenizers import (Tokenizer, TokenizerToken, CharacterTokenizer, PretrainedTransformerTokenizer,
+                         SpacyTokenizer, WhitespaceTokenizer)
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index de4805aee7ba1373e6fcb2a892605eb5a95ef553..3b16c149b7abd9e2aba6ac50a4b0b71740d4f5d7 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -1,16 +1,252 @@
+import copy
 import logging
+import pathlib
+from dataclasses import dataclass
+from typing import List, Any, Dict, Iterable, Optional, Tuple
+
+import conllu
+import torch
+from overrides import overrides
 
 from combo import data
+from combo.data import Vocabulary, fields, Instance, Token, TokenizerToken
+from combo.data.dataset_readers.dataset_reader import DatasetReader
+from combo.data.fields import Field
+from combo.data.fields.adjacency_field import AdjacencyField
+from combo.data.fields.metadata_field import MetadataField
+from combo.data.fields.sequence_label_field import SequenceLabelField
+from combo.data.fields.text_field import TextField
+from combo.data.token_indexers import TokenIndexer
+from combo.models import parser
+from combo.utils import checks, pad_sequence_to_length
 
 logger = logging.getLogger(__name__)
 
 
-class DatasetReader:
-    pass
+@dataclass(init=False, repr=False)
+class _Token(TokenizerToken):
+    __slots__ = TokenizerToken.__slots__ + ['feats_']
+
+    feats_: Optional[str]
+
+    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
+                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
+                 feats_: str = None) -> None:
+        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
+        self.feats_ = feats_
 
 
 class UniversalDependenciesDatasetReader(DatasetReader):
-    pass
+    def __init__(
+            self,
+            token_indexers: Dict[str, TokenIndexer] = None,
+            lemma_indexers: Dict[str, TokenIndexer] = None,
+            features: List[str] = None,
+            targets: List[str] = None,
+            use_sem: bool = False,
+            **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        if features is None:
+            features = ["token", "char"]
+        if targets is None:
+            targets = ["head", "deprel", "upostag", "xpostag", "lemma", "feats"]
+
+        if "token" not in features and "char" not in features:
+            raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!")
+
+        if "deps" in targets and not ("head" in targets and "deprel" in targets):
+            raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!")
+
+        intersection = set(features).intersection(set(targets))
+        if len(intersection) != 0:
+            raise checks.ConfigurationError(
+                "Features and targets cannot share elements! "
+                "Remove {} from either features or targets.".format(intersection)
+            )
+        self.use_sem = use_sem
+
+        # *.conllu readers configuration
+        fields = list(parser.DEFAULT_FIELDS)
+        fields[1] = "token"  # use 'token' instead of 'form'
+        field_parsers = parser.DEFAULT_FIELD_PARSERS
+        # Do not make xpostag nullable
+        field_parsers.pop("xpostag", None)
+        # Skip parsing of the misc column
+        field_parsers.pop("misc", None)
+        if self.use_sem:
+            fields = list(fields)
+            fields.append("semrel")
+            field_parsers["semrel"] = lambda line, i: line[i]
+        self.field_parsers = field_parsers
+        self.fields = tuple(fields)
+
+        self._token_indexers = token_indexers
+        self._lemma_indexers = lemma_indexers
+        self._targets = targets
+        self._features = features
+        self.generate_labels = True
+        # Filter out token indexers that are not required to avoid
+        # a "Mismatched token keys" ConfigurationError
+        for indexer_name in list(self._token_indexers.keys()):
+            if indexer_name not in self._features:
+                del self._token_indexers[indexer_name]
+
+    @overrides
+    def _read(self, file_path: str) -> Iterable[Instance]:
+        # Support a comma-separated list of paths; a single path becomes a one-element list.
+        file_path = file_path.split(",")
+
+        for conllu_file in file_path:
+            file = pathlib.Path(conllu_file)
+            assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exist!"
+            with file.open("r", encoding="utf-8") as f:
+                for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
+                    yield self.text_to_instance(annotation)
+
+    # @overrides is disabled on purpose: the base DatasetReader declares text_to_instance(*inputs),
+    # so the decorator rejects this keyword signature (TypeError: `inputs` must be present).
+    # @overrides
+    def text_to_instance(self, tree: conllu.TokenList) -> Instance:
+        fields_: Dict[str, Field] = {}
+        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        tokens = [_Token(t["token"],
+                         pos_=t.get("upostag"),
+                         tag_=t.get("xpostag"),
+                         lemma_=t.get("lemma"),
+                         feats_=t.get("feats"))
+                  for t in tree_tokens]
+
+        # features
+        text_field = TextField(tokens, self._token_indexers)
+        fields_["sentence"] = text_field
+
+        # targets
+        if self.generate_labels:
+            for target_name in self._targets:
+                if target_name != "sent":
+                    target_values = [t[target_name] for t in tree_tokens]
+                    if target_name == "lemma":
+                        target_values = [TokenizerToken(v) for v in target_values]
+                        fields_[target_name] = TextField(target_values, self._lemma_indexers)
+                    elif target_name == "feats":
+                        target_values = self._feat_values(tree_tokens)
+                        fields_[target_name] = fields.SequenceMultiLabelField(target_values,
+                                                                              self._feats_indexer,
+                                                                              self._feats_as_tensor_wrapper,
+                                                                              text_field,
+                                                                              label_namespace="feats_labels")
+                    elif target_name == "head":
+                        target_values = [0 if v == "_" else int(v) for v in target_values]
+                        fields_[target_name] = SequenceLabelField(target_values, text_field,
+                                                                  label_namespace=target_name + "_labels")
+                    elif target_name == "deps":
+                        # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField).
+                        text_field_deps = TextField([_Token("ROOT")] + copy.deepcopy(tokens), self._token_indexers)
+                        enhanced_heads: List[Tuple[int, int]] = []
+                        enhanced_deprels: List[str] = []
+                        for idx, t in enumerate(tree_tokens):
+                            t_deps = t["deps"]
+                            if t_deps and t_deps != "_":
+                                for rel, head in t_deps:
+                                    # EmoryNLP skips the first edge if there are two edges between
+                                    # the same pair of nodes, so that one ends up in the tree and the
+                                    # other in the graph. This snippet follows that approach.
+                                    if enhanced_heads and enhanced_heads[-1] == (idx, head):
+                                        enhanced_heads.pop()
+                                        enhanced_deprels.pop()
+                                    enhanced_heads.append((idx, head))
+                                    enhanced_deprels.append(rel)
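+                        # e.g. deps [("nsubj", 2), ("obj", 4)] on the token at idx 0 yields
+                        # enhanced_heads [(0, 2), (0, 4)] and enhanced_deprels ["nsubj", "obj"]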
+                        fields_["enhanced_heads"] = AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            label_namespace="enhanced_heads_labels",
+                            padding_value=0,
+                        )
+                        fields_["enhanced_deprels"] = AdjacencyField(
+                            indices=enhanced_heads,
+                            sequence_field=text_field_deps,
+                            labels=enhanced_deprels,
+                            # Label namespace matches regular tree parsing.
+                            label_namespace="enhanced_deprel_labels",
+                            padding_value=0,
+                        )
+                    else:
+                        fields_[target_name] = SequenceLabelField(target_values, text_field,
+                                                                  label_namespace=target_name + "_labels")
+
+        # Restore the feats field to its string representation
+        # (parser.serialize_field doesn't handle keys without values)
+        for token in tree.tokens:
+            if "feats" in token:
+                feats = token["feats"]
+                if feats:
+                    feats_values = []
+                    for k, v in feats.items():
+                        feats_values.append('='.join((k, v)) if v else k)
+                    field = "|".join(feats_values)
+                else:
+                    field = "_"
+                token["feats"] = field
+
+        # metadata
+        fields_["metadata"] = MetadataField({"input": tree,
+                                             "field_names": self.fields,
+                                             "tokens": tokens})
+
+        return Instance(fields_)
+
+    @staticmethod
+    def _feat_values(tree: List[Dict[str, Any]]):
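+        # Illustrative example: for [{"feats": {"Case": "Nom", "Foreign": None}}, {"feats": None}]
+        # this returns [["Case=Nom", "Foreign"], []].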
+        features = []
+        for token in tree:
+            token_features = []
+            if token["feats"] is not None:
+                for feat, value in token["feats"].items():
+                    if feat in ["_", "__ROOT__"]:
+                        continue
+                    # Handle the case where a feature is binary (has no associated value)
+                    if value:
+                        token_features.append(feat + "=" + value)
+                    else:
+                        token_features.append(feat)
+            features.append(token_features)
+        return features
+
+    @staticmethod
+    def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField):
+        def as_tensor(padding_lengths):
+            desired_num_tokens = padding_lengths["num_tokens"]
+            assert len(field._indexed_multi_labels) > 0
+            classes_count = len(field._indexed_multi_labels[0])
+            default_value = [0.0] * classes_count
+            padded_tags = pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens,
+                                                 lambda: default_value)
+            tensor = torch.tensor(padded_tags, dtype=torch.long)
+            return tensor
+
+        return as_tensor
+
+    @staticmethod
+    def _feats_indexer(vocab: Vocabulary):
+        label_namespace = "feats_labels"
+        vocab_size = vocab.get_vocab_size(label_namespace)
+        slices = get_slices_if_not_provided(vocab)
+
+        def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]:
+            one_hot_encoding = [0] * vocab_size
+            for cat, cat_indices in slices.items():
+                if cat not in ["__PAD__", "_"]:
+                    label_from_cat = [label for label in multi_label if cat == label.split("=")[0]]
+                    if label_from_cat:
+                        label_from_cat = label_from_cat[0]
+                        index = vocab.get_token_index(label_from_cat, label_namespace)
+                    else:
+                        # Get Cat=None index
+                        index = vocab.get_token_index(cat + "=None", label_namespace)
+                    one_hot_encoding[index] = 1
+            return one_hot_encoding
+
+        return _m_from_n_ones_encoding
 
 
 def get_slices_if_not_provided(vocab: data.Vocabulary):
diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
index 796496aa4029528e5627d3b0693c648ff3d2825d..6f0fbd16bd772d08b8dc7d724389bce4df6e38f4 100644
--- a/combo/data/dataset_readers/dataset_reader.py
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -169,7 +169,7 @@ class DatasetReader:
                 self.apply_token_indexers(instance)
             yield instance
 
-    def _read(self, file_path) -> Iterable[Instance]:
+    def _read(self, file_path: str) -> Iterable[Instance]:
         """
         Reads the instances from the given `file_path` and returns them as an
         `Iterable`.
diff --git a/combo/data/fields/adjacency_field.py b/combo/data/fields/adjacency_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..492b5b9b2f28653f2f7d35e7217d9304edc3e243
--- /dev/null
+++ b/combo/data/fields/adjacency_field.py
@@ -0,0 +1,159 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/adjacency_field.py
+"""
+from typing import Dict, List, Set, Tuple, Optional
+import logging
+import textwrap
+
+
+import torch
+
+from combo.data import Vocabulary
+from combo.data.fields import Field
+from combo.data.fields.sequence_field import SequenceField
+from combo.utils import ConfigurationError
+
+logger = logging.getLogger(__name__)
+
+
+class AdjacencyField(Field[torch.Tensor]):
+    """
+    An `AdjacencyField` defines directed adjacency relations between elements
+    in a :class:`~allennlp.data.fields.sequence_field.SequenceField`.
+    Because it's a labeling of some other field, we take that field as input here
+    and use it to determine our padding and other things.
+    This field will get converted into an array of shape (sequence_field_length, sequence_field_length),
+    where the (i, j)th array element is either a binary flag indicating there is an edge from i to j,
+    or an integer label k, indicating there is a label from i to j of type k.
+    # Parameters
+    indices : `List[Tuple[int, int]]`
+    sequence_field : `SequenceField`
+        A field containing the sequence that this `AdjacencyField` is labeling.  Most often,
+        this is a `TextField`, for tagging edge relations between tokens in a sentence.
+    labels : `List[str]`, optional, (default = `None`)
+        Optional labels for the edges of the adjacency matrix.
+    label_namespace : `str`, optional (default=`'labels'`)
+        The namespace to use for converting tag strings into integers.  We convert tag strings to
+        integers for you, and this parameter tells the `Vocabulary` object which mapping from
+        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
+    padding_value : `int`, optional (default = `-1`)
+        The value to use as padding.
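+
+    Example (illustrative; assumes `text_field` is a `TextField` of length 3):
+        edges = AdjacencyField([(0, 1), (2, 0)], text_field)
+        edges.as_tensor({"num_tokens": 3})  # 3x3 tensor with 1 at (0, 1) and (2, 0), -1 elsewhere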
+    """
+
+    __slots__ = [
+        "indices",
+        "labels",
+        "sequence_field",
+        "_label_namespace",
+        "_padding_value",
+        "_indexed_labels",
+    ]
+
+    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
+    # This warning will be repeated for every instantiation of this class (i.e. for every data
+    # instance), spewing a lot of warnings so this class variable is used to only log a single
+    # warning per namespace.
+    _already_warned_namespaces: Set[str] = set()
+
+    def __init__(
+        self,
+        indices: List[Tuple[int, int]],
+        sequence_field: SequenceField,
+        labels: List[str] = None,
+        label_namespace: str = "labels",
+        padding_value: int = -1,
+    ) -> None:
+        self.indices = indices
+        self.labels = labels
+        self.sequence_field = sequence_field
+        self._label_namespace = label_namespace
+        self._padding_value = padding_value
+        self._indexed_labels: Optional[List[int]] = None
+
+        self._maybe_warn_for_namespace(label_namespace)
+        field_length = sequence_field.sequence_length()
+
+        if len(set(indices)) != len(indices):
+            raise ConfigurationError(f"Indices must be unique, but found {indices}")
+
+        if not all(
+            0 <= index[1] < field_length and 0 <= index[0] < field_length for index in indices
+        ):
+            raise ConfigurationError(
+                f"Label indices and sequence length "
+                f"are incompatible: {indices} and {field_length}"
+            )
+
+        if labels is not None and len(indices) != len(labels):
+            raise ConfigurationError(
+                f"Labelled indices were passed, but their lengths do not match: "
+                f" {labels}, {indices}"
+            )
+
+    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
+        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
+            if label_namespace not in self._already_warned_namespaces:
+                logger.warning(
+                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
+                    "default to your vocabulary.  See documentation for "
+                    "`non_padded_namespaces` parameter in Vocabulary.",
+                    self._label_namespace,
+                )
+                self._already_warned_namespaces.add(label_namespace)
+
+    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
+        if self._indexed_labels is None and self.labels is not None:
+            for label in self.labels:
+                counter[self._label_namespace][label] += 1  # type: ignore
+
+    def index(self, vocab: Vocabulary):
+        if self.labels is not None:
+            self._indexed_labels = [
+                vocab.get_token_index(label, self._label_namespace) for label in self.labels
+            ]
+
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {"num_tokens": self.sequence_field.sequence_length()}
+
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
+        desired_num_tokens = padding_lengths["num_tokens"]
+        tensor = torch.ones(desired_num_tokens, desired_num_tokens) * self._padding_value
+        labels = self._indexed_labels or [1 for _ in range(len(self.indices))]
+
+        for index, label in zip(self.indices, labels):
+            tensor[index] = label
+        return tensor
+
+    def empty_field(self) -> "AdjacencyField":
+
+        # The empty_list here is needed for mypy
+        empty_list: List[Tuple[int, int]] = []
+        adjacency_field = AdjacencyField(
+            empty_list, self.sequence_field.empty_field(), padding_value=self._padding_value
+        )
+        return adjacency_field
+
+    def __str__(self) -> str:
+        length = self.sequence_field.sequence_length()
+        formatted_labels = "".join(
+            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
+        )
+        formatted_indices = "".join(
+            "\t\t" + index + "\n" for index in textwrap.wrap(repr(self.indices), 100)
+        )
+        return (
+            f"AdjacencyField of length {length}\n"
+            f"\t\twith indices:\n {formatted_indices}\n"
+            f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+        )
+
+    def __len__(self):
+        return len(self.sequence_field)
+
+    def human_readable_repr(self):
+        ret = {"indices": self.indices}
+        if self.labels is not None:
+            ret["labels"] = self.labels
+        return ret
diff --git a/combo/data/fields/metadata_field.py b/combo/data/fields/metadata_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..73730197b5ed28af8bb6c552a7f995ae94d58c13
--- /dev/null
+++ b/combo/data/fields/metadata_field.py
@@ -0,0 +1,69 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/metadata_field.py
+"""
+from typing import Any, Dict, List, Mapping
+
+
+from combo.data.fields.field import DataArray, Field
+
+
+class MetadataField(Field[DataArray], Mapping[str, Any]):
+    """
+    A `MetadataField` is a `Field` that does not get converted into tensors.  It just carries
+    side information that might be needed later on, for computing some third-party metric, or
+    outputting debugging information, or whatever else you need.  We use this in the BiDAF model,
+    for instance, to keep track of question IDs and passage token offsets, so we can more easily
+    use the official evaluation script to compute metrics.
+    We don't try to do any kind of smart combination of this field for batched input - when you use
+    this `Field` in a model, you'll get a list of metadata objects, one for each instance in the
+    batch.
+    # Parameters
+    metadata : `Any`
+        Some object containing the metadata that you want to store.  It's likely that you'll want
+        this to be a dictionary, but it could be anything you want.
+    """
+
+    __slots__ = ["metadata"]
+
+    def __init__(self, metadata: Any) -> None:
+        self.metadata = metadata
+
+    def __getitem__(self, key: str) -> Any:
+        try:
+            return self.metadata[key]  # type: ignore
+        except TypeError:
+            raise TypeError("your metadata is not a dict")
+
+    def __iter__(self):
+        try:
+            return iter(self.metadata)
+        except TypeError:
+            raise TypeError("your metadata is not iterable")
+
+    def __len__(self):
+        try:
+            return len(self.metadata)
+        except TypeError:
+            raise TypeError("your metadata has no length")
+
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {}
+
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray:
+
+        return self.metadata  # type: ignore
+
+    def empty_field(self) -> "MetadataField":
+        return MetadataField(None)
+
+    def batch_tensors(self, tensor_list: List[DataArray]) -> List[DataArray]:  # type: ignore
+        return tensor_list
+
+    def __str__(self) -> str:
+        return "MetadataField (print field.metadata to see specific information)."
+
+    def human_readable_repr(self):
+        if hasattr(self.metadata, "human_readable_repr"):
+            return self.metadata.human_readable_repr()
+        return self.metadata
diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d0299d95f30d2e8a45f8133ccc8bec70f9c58ec
--- /dev/null
+++ b/combo/data/fields/sequence_label_field.py
@@ -0,0 +1,151 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/sequence_label_field.py
+"""
+
+from typing import Dict, List, Union, Set, Iterator
+import logging
+import textwrap
+
+
+import torch
+
+from combo.data import Vocabulary
+from combo.data.fields import Field
+from combo.data.fields.sequence_field import SequenceField
+from combo.utils import ConfigurationError, pad_sequence_to_length
+
+logger = logging.getLogger(__name__)
+
+
+class SequenceLabelField(Field[torch.Tensor]):
+    """
+    A `SequenceLabelField` assigns a categorical label to each element in a
+    :class:`~allennlp.data.fields.sequence_field.SequenceField`.
+    Because it's a labeling of some other field, we take that field as input here, and we use it to
+    determine our padding and other things.
+    This field will get converted into a list of integer class ids, representing the correct class
+    for each element in the sequence.
+    # Parameters
+    labels : `Union[List[str], List[int]]`
+        A sequence of categorical labels, encoded as strings or integers.  These could be POS tags like [NN,
+        JJ, ...], BIO tags like [B-PERS, I-PERS, O, O, ...], or any other categorical tag sequence. If the
+        labels are encoded as integers, they will not be indexed using a vocab.
+    sequence_field : `SequenceField`
+        A field containing the sequence that this `SequenceLabelField` is labeling.  Most often, this is a
+        `TextField`, for tagging individual tokens in a sentence.
+    label_namespace : `str`, optional (default=`'labels'`)
+        The namespace to use for converting tag strings into integers.  We convert tag strings to
+        integers for you, and this parameter tells the `Vocabulary` object which mapping from
+        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
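+
+    Example (illustrative; assumes `text_field` is a `TextField` of length 3):
+        tags = SequenceLabelField(["B-PERS", "I-PERS", "O"], text_field, label_namespace="ner_tags")
+        tags.get_padding_lengths()  # {"num_tokens": 3}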
+    """
+
+    __slots__ = [
+        "labels",
+        "sequence_field",
+        "_label_namespace",
+        "_indexed_labels",
+        "_skip_indexing",
+    ]
+
+    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
+    # This warning will be repeated for every instantiation of this class (i.e. for every data
+    # instance), spewing a lot of warnings so this class variable is used to only log a single
+    # warning per namespace.
+    _already_warned_namespaces: Set[str] = set()
+
+    def __init__(
+        self,
+        labels: Union[List[str], List[int]],
+        sequence_field: SequenceField,
+        label_namespace: str = "labels",
+    ) -> None:
+        self.labels = labels
+        self.sequence_field = sequence_field
+        self._label_namespace = label_namespace
+        self._indexed_labels = None
+        self._maybe_warn_for_namespace(label_namespace)
+        if len(labels) != sequence_field.sequence_length():
+            raise ConfigurationError(
+                "Label length and sequence length "
+                "don't match: %d and %d" % (len(labels), sequence_field.sequence_length())
+            )
+
+        self._skip_indexing = False
+        if all(isinstance(x, int) for x in labels):
+            self._indexed_labels = labels
+            self._skip_indexing = True
+
+        elif not all(isinstance(x, str) for x in labels):
+            raise ConfigurationError(
+                "SequenceLabelFields must be passed either all "
+                "strings or all ints. Found labels {} with "
+                "types: {}.".format(labels, [type(x) for x in labels])
+            )
+
+    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
+        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
+            if label_namespace not in self._already_warned_namespaces:
+                logger.warning(
+                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
+                    "default to your vocabulary.  See documentation for "
+                    "`non_padded_namespaces` parameter in Vocabulary.",
+                    self._label_namespace,
+                )
+                self._already_warned_namespaces.add(label_namespace)
+
+    # Sequence methods
+    def __iter__(self) -> Iterator[Union[str, int]]:
+        return iter(self.labels)
+
+    def __getitem__(self, idx: int) -> Union[str, int]:
+        return self.labels[idx]
+
+    def __len__(self) -> int:
+        return len(self.labels)
+
+    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
+        if self._indexed_labels is None:
+            for label in self.labels:
+                counter[self._label_namespace][label] += 1  # type: ignore
+
+    def index(self, vocab: Vocabulary):
+        if not self._skip_indexing:
+            self._indexed_labels = [
+                vocab.get_token_index(label, self._label_namespace)  # type: ignore
+                for label in self.labels
+            ]
+
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {"num_tokens": self.sequence_field.sequence_length()}
+
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
+        if self._indexed_labels is None:
+            raise ConfigurationError(
+                "You must call .index(vocabulary) on a field before calling .as_tensor()"
+            )
+        desired_num_tokens = padding_lengths["num_tokens"]
+        padded_tags = pad_sequence_to_length(self._indexed_labels, desired_num_tokens)
+        tensor = torch.LongTensor(padded_tags)
+        return tensor
+
+    def empty_field(self) -> "SequenceLabelField":
+        # The empty_list here is needed for mypy
+        empty_list: List[str] = []
+        sequence_label_field = SequenceLabelField(empty_list, self.sequence_field.empty_field())
+        sequence_label_field._indexed_labels = empty_list
+        return sequence_label_field
+
+    def __str__(self) -> str:
+        length = self.sequence_field.sequence_length()
+        formatted_labels = "".join(
+            "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.labels), 100)
+        )
+        return (
+            f"SequenceLabelField of length {length} with "
+            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+        )
+
+    def human_readable_repr(self) -> Union[List[str], List[int]]:
+        return self.labels
\ No newline at end of file
diff --git a/combo/data/tokenizers/__init__.py b/combo/data/tokenizers/__init__.py
index 6cec93f70bd219dae01923eeb67a3604d036c9a2..04e3c6eb9e11f73721445addebe9586e7e626cc0 100644
--- a/combo/data/tokenizers/__init__.py
+++ b/combo/data/tokenizers/__init__.py
@@ -2,3 +2,4 @@ from .tokenizer import Tokenizer, TokenizerToken
 from .character_tokenizer import CharacterTokenizer
 from .pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
 from .spacy_tokenizer import SpacyTokenizer
+from .whitespace_tokenizer import WhitespaceTokenizer
diff --git a/combo/data/tokenizers/whitespace_tokenizer.py b/combo/data/tokenizers/whitespace_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfaff11220f6ebcb4cfdffefc4dd61a254322428
--- /dev/null
+++ b/combo/data/tokenizers/whitespace_tokenizer.py
@@ -0,0 +1,20 @@
+from typing import List, Dict, Any
+
+from combo.data.tokenizers.tokenizer import Tokenizer, TokenizerToken
+
+
+class WhitespaceTokenizer(Tokenizer):
+    """
+    A `Tokenizer` that assumes you've already done your own tokenization somehow and have
+    separated the tokens by spaces.  We just split the input string on whitespace and return the
+    resulting list.
+    Note that we use `text.split()`, which means that the amount of whitespace between the
+    tokens does not matter.  This will never result in spaces being included as tokens.
+    Registered as a `Tokenizer` with name "whitespace" and "just_spaces".
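+
+    Illustrative example:
+        >>> [t.text for t in WhitespaceTokenizer().tokenize("Hello   world")]
+        ['Hello', 'world']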
+    """
+
+    def tokenize(self, text: str) -> List[TokenizerToken]:
+        return [TokenizerToken(t) for t in text.split()]
+
+    def _to_params(self) -> Dict[str, Any]:
+        return {"type": "whitespace"}
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 01e1c9e625ac8e423a3124fd570c2187679afdbe..66122cf32038cc71b431c46134df74cac80ac354 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -7,3 +7,4 @@ from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
 from .combo_model import ComboModel
 from .morpho import MorphologicalFeatures
+from .model import Model
diff --git a/combo/models/base.py b/combo/models/base.py
index fbf4d45fb7199217964c9a1141f31a59d47f9481..45eae041affc66f1bd34e3224b893e22285d0028 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -4,26 +4,18 @@ import torch
 import torch.nn as nn
 from overrides import overrides
 
-import combo.models.utils as utils
-import combo.models.combo_nn as combo_nn
+from combo.models.combo_nn import Activation
 import combo.utils.checks as checks
-from combo import data
-
-
-class Predictor(nn.Module):
-    def forward(self,
-                x: Union[torch.Tensor, List[torch.Tensor]],
-                mask: Optional[torch.BoolTensor] = None,
-                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError()
+from combo.data.vocabulary import Vocabulary
+from combo.models.utils import masked_cross_entropy
+from combo.predictors.predictor import Predictor
 
 
 class Linear(nn.Linear):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Optional[combo_nn.Activation] = None,
+                 activation: Optional[Activation] = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
         self.activation = activation if activation else self.identity
@@ -93,7 +85,7 @@ class FeedForward(torch.nn.Module):
             input_dim: int,
             num_layers: int,
             hidden_dims: Union[int, List[int]],
-            activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
+            activations: Union[Activation, List[Activation]],
             dropout: Union[float, List[float]] = 0.0,
     ) -> None:
 
@@ -184,18 +176,18 @@ class FeedForwardPredictor(Predictor):
         pred = pred.reshape(-1, CLASSES)
         true = true.reshape(-1)
         mask = mask.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask)
+        loss = masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
 
     @classmethod
     def from_vocab(cls,
-                   vocab: data.Vocabulary,
+                   vocab: Vocabulary,
                    vocab_namespace: str,
                    input_dim: int,
                    num_layers: int,
                    hidden_dims: List[int],
-                   activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
+                   activations: Union[Activation, List[Activation]],
                    dropout: Union[float, List[float]] = 0.0,
                    ):
         if len(hidden_dims) + 1 != num_layers:
@@ -218,6 +210,8 @@ class FeedForwardPredictor(Predictor):
 """
 Adapted from AllenNLP
 """
+
+
 class TimeDistributed(torch.nn.Module):
     """
     Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes
diff --git a/combo/models/combo_model.py b/combo/models/combo_model.py
index c6696333cd815652046705c3f799ff3ff1eab1d2..dc94077a12e2e4c26e3fcb636d43fed61a3ac4f8 100644
--- a/combo/models/combo_model.py
+++ b/combo/models/combo_model.py
@@ -47,7 +47,8 @@ class ComboModel(Model):
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
 
-    @overrides
+    # @overrides is disabled: the base Model declares forward(*inputs),
+    # so the decorator raises an error for this keyword signature.
+    # @overrides
     def forward(self,
                 sentence: Dict[str, Dict[str, torch.Tensor]],
                 metadata: List[Dict[str, Any]],
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 26cdca133bcd9012757b5140dade68a7bf928846..35b732ae5c4497893a28e67ecff36cbb383b3d79 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -212,7 +212,8 @@ class FeatsTokenEmbedder(_TorchEmbedder):
 
     @overrides
     def forward(self,
-                x: torch.Tensor) -> torch.Tensor:
+                x: torch.Tensor,
+                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
         mask = x.gt(0)
         x = super().forward(x)
         return x.sum(dim=-2)/(
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 97952a59d1a8ebc9c2ab86c626e73f3bda8b42c8..2d525fa1cc5be66a35ffc17c8a25afcc147212f7 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -1,6 +1,225 @@
-class StackedBidirectionalLstm:
-    pass
+"""
+Adapted parts from AllenNLP
+and COMBO (Author: Mateusz Klimaszewski)
+"""
+
+from typing import Optional, Tuple, List
+import torch
+import torch.nn.utils.rnn as rnn
+from overrides import overrides
+from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
+
+from combo.modules import input_variational_dropout
+from combo.modules.augmented_lstm import AugmentedLstm
+from combo.modules.input_variational_dropout import InputVariationalDropout
+from combo.utils import ConfigurationError
+
+TensorPair = Tuple[torch.Tensor, torch.Tensor]
+
+
+class StackedBidirectionalLstm(torch.nn.Module):
+    """
+    A standard stacked Bidirectional LSTM where the LSTM layers
+    are concatenated between each layer. The only difference between
+    this and a regular bidirectional LSTM is the application of
+    variational dropout to the hidden states and outputs of each layer apart
+    from the last layer of the LSTM. Note that this will be slower, as it
+    doesn't use CUDNN.
+
+    [0]: https://arxiv.org/abs/1512.05287
+
+    # Parameters
+
+    input_size : `int`, required
+        The dimension of the inputs to the LSTM.
+    hidden_size : `int`, required
+        The dimension of the outputs of the LSTM.
+    num_layers : `int`, required
+        The number of stacked Bidirectional LSTMs to use.
+    recurrent_dropout_probability : `float`, optional (default = `0.0`)
+        The recurrent dropout probability to be used in a dropout scheme as
+        stated in [A Theoretically Grounded Application of Dropout in Recurrent
+        Neural Networks][0].
+    layer_dropout_probability : `float`, optional (default = `0.0`)
+        The layer wise dropout probability to be used in a dropout scheme as
+        stated in [A Theoretically Grounded Application of Dropout in Recurrent
+        Neural Networks][0].
+    use_highway : `bool`, optional (default = `True`)
+        Whether or not to use highway connections between layers. This effectively involves
+        reparameterising the normal output of an LSTM as::
+
+            gate = sigmoid(W_x1 * x_t + W_h * h_t)
+            output = gate * h_t  + (1 - gate) * (W_x2 * x_t)
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        recurrent_dropout_probability: float = 0.0,
+        layer_dropout_probability: float = 0.0,
+        use_highway: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Required to be wrapped with a `PytorchSeq2SeqWrapper`.
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bidirectional = True
+
+        layers = []
+        lstm_input_size = input_size
+        for layer_index in range(num_layers):
+
+            forward_layer = AugmentedLstm(
+                lstm_input_size,
+                hidden_size,
+                go_forward=True,
+                recurrent_dropout_probability=recurrent_dropout_probability,
+                use_highway=use_highway,
+                use_input_projection_bias=False,
+            )
+            backward_layer = AugmentedLstm(
+                lstm_input_size,
+                hidden_size,
+                go_forward=False,
+                recurrent_dropout_probability=recurrent_dropout_probability,
+                use_highway=use_highway,
+                use_input_projection_bias=False,
+            )
+
+            lstm_input_size = hidden_size * 2
+            self.add_module("forward_layer_{}".format(layer_index), forward_layer)
+            self.add_module("backward_layer_{}".format(layer_index), backward_layer)
+            layers.append([forward_layer, backward_layer])
+        self.lstm_layers = layers
+        self.layer_dropout = InputVariationalDropout(layer_dropout_probability)
+
+    def forward(
+        self, inputs: PackedSequence, initial_state: Optional[TensorPair] = None
+    ) -> Tuple[PackedSequence, TensorPair]:
+        """
+        # Parameters
+
+        inputs : `PackedSequence`, required.
+            A batch first `PackedSequence` to run the stacked LSTM over.
+        initial_state : `Tuple[torch.Tensor, torch.Tensor]`, optional, (default = `None`)
+            A tuple (state, memory) representing the initial hidden state and memory
+            of the LSTM. Each tensor has shape (num_layers, batch_size, output_dimension * 2).
+
+        # Returns
+
+        output_sequence : `PackedSequence`
+            The encoded sequence of shape (batch_size, sequence_length, hidden_size * 2)
+        final_states: `torch.Tensor`
+            The per-layer final (state, memory) states of the LSTM, each with shape
+            (num_layers * 2, batch_size, hidden_size * 2).
+        """
+        if initial_state is None:
+            hidden_states: List[Optional[TensorPair]] = [None] * len(self.lstm_layers)
+        elif initial_state[0].size()[0] != len(self.lstm_layers):
+            raise ConfigurationError(
+                "Initial states were passed to forward() but the number of "
+                "initial states does not match the number of layers."
+            )
+        else:
+            hidden_states = list(zip(initial_state[0].split(1, 0), initial_state[1].split(1, 0)))
+
+        output_sequence = inputs
+        final_h = []
+        final_c = []
+        for i, state in enumerate(hidden_states):
+            forward_layer = getattr(self, "forward_layer_{}".format(i))
+            backward_layer = getattr(self, "backward_layer_{}".format(i))
+            # The state is duplicated to mirror the Pytorch API for LSTMs.
+            forward_output, final_forward_state = forward_layer(output_sequence, state)
+            backward_output, final_backward_state = backward_layer(output_sequence, state)
+
+            forward_output, lengths = pad_packed_sequence(forward_output, batch_first=True)
+            backward_output, _ = pad_packed_sequence(backward_output, batch_first=True)
+
+            output_sequence = torch.cat([forward_output, backward_output], -1)
+            # Apply layer wise dropout on each output sequence apart from the
+            # first (input) and last
+            if i < (self.num_layers - 1):
+                output_sequence = self.layer_dropout(output_sequence)
+            output_sequence = pack_padded_sequence(output_sequence, lengths, batch_first=True)
+
+            final_h.extend([final_forward_state[0], final_backward_state[0]])
+            final_c.extend([final_forward_state[1], final_backward_state[1]])
+
+        final_h = torch.cat(final_h, dim=0)
+        final_c = torch.cat(final_c, dim=0)
+        final_state_tuple = (final_h, final_c)
+        return output_sequence, final_state_tuple
+
+
+# TODO: merge into one
+class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm):
+
+    def __init__(self, input_size: int, hidden_size: int, num_layers: int, recurrent_dropout_probability: float,
+                 layer_dropout_probability: float, use_highway: bool = False):
+        super().__init__(input_size=input_size,
+                         hidden_size=hidden_size,
+                         num_layers=num_layers,
+                         recurrent_dropout_probability=recurrent_dropout_probability,
+                         layer_dropout_probability=layer_dropout_probability,
+                         use_highway=use_highway)
+
+    @overrides
+    def forward(self,
+                inputs: rnn.PackedSequence,
+                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+                ) -> Tuple[rnn.PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
+        """Changes compared to stacked_bidirectional_lstm.StackedBidirectionalLstm:
+        * dropout is also applied to the last layer
+        * accepts a BxTxD tensor
+        * the state of layer n-1 is used as the initial state of layer n
+
+        :param inputs:
+        :param initial_state:
+        :return:
+        """
+        output_sequence = inputs
+        state_fwd = None
+        state_bwd = None
+        for i in range(self.num_layers):
+            forward_layer = getattr(self, f"forward_layer_{i}")
+            backward_layer = getattr(self, f"backward_layer_{i}")
+
+            forward_output, state_fwd = forward_layer(output_sequence, state_fwd)
+            backward_output, state_bwd = backward_layer(output_sequence, state_bwd)
+
+            forward_output, lengths = rnn.pad_packed_sequence(forward_output, batch_first=True)
+            backward_output, _ = rnn.pad_packed_sequence(backward_output, batch_first=True)
+
+            output_sequence = torch.cat([forward_output, backward_output], -1)
+
+            output_sequence = self.layer_dropout(output_sequence)
+            output_sequence = rnn.pack_padded_sequence(output_sequence, lengths, batch_first=True)
+
+        return output_sequence, (state_fwd, state_bwd)
 
 
 class ComboEncoder:
-    pass
+    """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
+
+    This implementation uses Variational Dropout on the input and on the outputs of each BiLSTM layer
+    (instead of the Gaussian Dropout and Gaussian Noise used in the original).
+    """
+
+    def __init__(self,
+                 stacked_bilstm: ComboStackedBidirectionalLSTM,
+                 layer_dropout_probability: float):
+        super().__init__(stacked_bilstm, stateful=False)
+        self.layer_dropout = input_variational_dropout.InputVariationalDropout(p=layer_dropout_probability)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                mask: torch.BoolTensor,
+                hidden_state: torch.Tensor = None) -> torch.Tensor:
+        x = self.layer_dropout(inputs)
+        x = super().forward(x, mask)
+        return self.layer_dropout(x)
diff --git a/combo/models/model.py b/combo/models/model.py
index 09a50ed48cfcecb2897660bca8e3219025bc76ae..9482ea546e69803d1c95815a955e259523bfd80f 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -452,11 +452,6 @@ class Model(torch.nn.Module):
         return model
 
 
-# We can't decorate `Model` with `Model.register()`, because `Model` hasn't been defined yet.  So we
-# put this down here.
-Model.register("from_archive", constructor="from_archive")(Model)
-
-
 def remove_weights_related_keys_from_params(
     params: Params, keys: List[str] = ["pretrained_file", "initializer"]
 ):
diff --git a/combo/models/parser.py b/combo/models/parser.py
index f28b07c9d79c6e48566ae8afd457182eab1576e6..42d2efb3944a3e81188d6cf8f673dcb2cfb75002 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -113,7 +113,7 @@ class HeadPredictionModel(base.Predictor):
         return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
 
 
-@base.Predictor.register("combo_dependency_parsing_from_vocab", constructor="from_vocab")
+
 class DependencyRelationModel(base.Predictor):
     """Dependency relation parsing model."""
 
diff --git a/combo/modules/augmented_lstm.py b/combo/modules/augmented_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e33d869d5713f63209f8a14ca1f05fddaed18d2
--- /dev/null
+++ b/combo/modules/augmented_lstm.py
@@ -0,0 +1,291 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/augmented_lstm.py
+"""
+from typing import Optional, Tuple
+
+import torch
+from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
+
+from combo.nn.util import get_dropout_mask
+
+from combo.nn.initializers import block_orthogonal
+from combo.utils import ConfigurationError
+
+
+class AugmentedLSTMCell(torch.nn.Module):
+    """
+    `AugmentedLSTMCell` implements an AugmentedLSTM cell.
+
+    # Parameters
+
+    embed_dim : `int`
+        The number of expected features in the input.
+    lstm_dim : `int`
+        Number of features in the hidden state of the LSTM.
+    use_highway : `bool`, optional (default = `True`)
+        If `True` we append a highway network to the outputs of the LSTM.
+    use_bias : `bool`, optional (default = `True`)
+        If `True` we use a bias in our LSTM calculations, otherwise we don't.
+
+    # Attributes
+
+    input_linearity : `nn.Module`
+        Fused weight matrix which computes a linear function over the input.
+    state_linearity : `nn.Module`
+        Fused weight matrix which computes a linear function over the states.
+    """
+
+    def __init__(
+        self, embed_dim: int, lstm_dim: int, use_highway: bool = True, use_bias: bool = True
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.lstm_dim = lstm_dim
+        self.use_highway = use_highway
+        self.use_bias = use_bias
+
+        if use_highway:
+            self._highway_inp_proj_start = 5 * self.lstm_dim
+            self._highway_inp_proj_end = 6 * self.lstm_dim
+
+            # fused linearity of input to input_gate,
+            # forget_gate, memory_init, output_gate, highway_gate,
+            # and the actual highway value
+            self.input_linearity = torch.nn.Linear(
+                self.embed_dim, self._highway_inp_proj_end, bias=self.use_bias
+            )
+            # fused linearity of input to input_gate,
+            # forget_gate, memory_init, output_gate, highway_gate
+            self.state_linearity = torch.nn.Linear(
+                self.lstm_dim, self._highway_inp_proj_start, bias=True
+            )
+        else:
+            # If there's no highway layer then we have a standard
+            # LSTM. The 4 comes from fusing input, forget, memory, output
+            # gates/inputs.
+            self.input_linearity = torch.nn.Linear(
+                self.embed_dim, 4 * self.lstm_dim, bias=self.use_bias
+            )
+            self.state_linearity = torch.nn.Linear(self.lstm_dim, 4 * self.lstm_dim, bias=True)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # Use sensible default initializations for parameters.
+        block_orthogonal(self.input_linearity.weight.data, [self.lstm_dim, self.embed_dim])
+        block_orthogonal(self.state_linearity.weight.data, [self.lstm_dim, self.lstm_dim])
+
+        self.state_linearity.bias.data.fill_(0.0)
+        # Initialize forget gate biases to 1.0 as per An Empirical
+        # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015).
+        self.state_linearity.bias.data[self.lstm_dim : 2 * self.lstm_dim].fill_(1.0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        states: Tuple[torch.Tensor, torch.Tensor],
+        variational_dropout_mask: Optional[torch.BoolTensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        !!! Warning
+            DO NOT USE THIS LAYER DIRECTLY, instead use the AugmentedLSTM class
+
+        # Parameters
+
+        x : `torch.Tensor`
+            Input tensor of shape (bsize x input_dim).
+        states : `Tuple[torch.Tensor, torch.Tensor]`
+            Tuple of tensors containing
+            the hidden state and the cell state of each element in
+            the batch. Each of these tensors have a dimension of
+            (bsize x nhid). Defaults to `None`.
+
+        # Returns
+
+        `Tuple[torch.Tensor, torch.Tensor]`
+            Returned states. Shape of each state is (bsize x nhid).
+
+        """
+        hidden_state, memory_state = states
+
+        # In PyText this was done as the last step of the cell,
+        # but in the original AugmentedLSTM from AllenNLP it was done before the processing.
+        if variational_dropout_mask is not None and self.training:
+            hidden_state = hidden_state * variational_dropout_mask
+
+        projected_input = self.input_linearity(x)
+        projected_state = self.state_linearity(hidden_state)
+
+        input_gate = forget_gate = memory_init = output_gate = highway_gate = None
+        if self.use_highway:
+            fused_op = projected_input[:, : 5 * self.lstm_dim] + projected_state
+            fused_chunked = torch.chunk(fused_op, 5, 1)
+            (input_gate, forget_gate, memory_init, output_gate, highway_gate) = fused_chunked
+            highway_gate = torch.sigmoid(highway_gate)
+        else:
+            fused_op = projected_input + projected_state
+            input_gate, forget_gate, memory_init, output_gate = torch.chunk(fused_op, 4, 1)
+        input_gate = torch.sigmoid(input_gate)
+        forget_gate = torch.sigmoid(forget_gate)
+        memory_init = torch.tanh(memory_init)
+        output_gate = torch.sigmoid(output_gate)
+        memory = input_gate * memory_init + forget_gate * memory_state
+        timestep_output: torch.Tensor = output_gate * torch.tanh(memory)
+
+        if self.use_highway:
+            highway_input_projection = projected_input[
+                :, self._highway_inp_proj_start : self._highway_inp_proj_end
+            ]
+            timestep_output = (
+                highway_gate * timestep_output
+                + (1 - highway_gate) * highway_input_projection  # type: ignore
+            )
+
+        return timestep_output, memory
+
+
+class AugmentedLstm(torch.nn.Module):
+    """
+    `AugmentedLstm` implements a one-layer, unidirectional
+    AugmentedLSTM layer. AugmentedLSTM is an LSTM which optionally
+    appends a highway network to the output layer. The
+    `recurrent_dropout_probability` parameter controls the amount of variational dropout applied.
+
+    # Parameters
+
+    input_size : `int`
+        The number of expected features in the input.
+    hidden_size : `int`
+        Number of features in the hidden state of the LSTM.
+    go_forward : `bool`
+        Whether to compute features left to right (forward)
+        or right to left (backward).
+    recurrent_dropout_probability : `float`
+        Variational dropout probability to use. Defaults to 0.0.
+    use_highway : `bool`
+        If `True`, we append a highway network to the outputs of the LSTM.
+    use_input_projection_bias : `bool`
+        If `True`, a bias is used in the input projection of the LSTM cell; otherwise it is omitted.
+
+    # Attributes
+
+    cell : `AugmentedLSTMCell`
+        `AugmentedLSTMCell` that is applied at every timestep.
+
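+    # Example
+
+    An illustrative usage sketch (not part of the original docs), assuming
+    `pack_padded_sequence`/`pad_packed_sequence` are imported from `torch.nn.utils.rnn`:
+
+    ```
+    lstm = AugmentedLstm(input_size=50, hidden_size=100)
+    embeddings = torch.randn(2, 7, 50)                      # (batch, seq_len, input_dim)
+    packed = pack_padded_sequence(embeddings, torch.tensor([7, 5]), batch_first=True)
+    packed_output, (h, c) = lstm(packed)                    # h, c: (1, batch, hidden_size)
+    output, lengths = pad_packed_sequence(packed_output, batch_first=True)
+    ```
+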
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        go_forward: bool = True,
+        recurrent_dropout_probability: float = 0.0,
+        use_highway: bool = True,
+        use_input_projection_bias: bool = True,
+    ):
+        super().__init__()
+
+        self.embed_dim = input_size
+        self.lstm_dim = hidden_size
+
+        self.go_forward = go_forward
+        self.use_highway = use_highway
+        self.recurrent_dropout_probability = recurrent_dropout_probability
+
+        self.cell = AugmentedLSTMCell(
+            self.embed_dim, self.lstm_dim, self.use_highway, use_input_projection_bias
+        )
+
+    def forward(
+        self, inputs: PackedSequence, states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Warning: in a typical model it is preferable to use the `BiAugmentedLstm` class rather than this layer directly.
+
+        Given an input batch of sequential data such as word embeddings, produces a single layer unidirectional
+        AugmentedLSTM representation of the sequential input and new state tensors.
+
+        # Parameters
+
+        inputs : `PackedSequence`
+            `bsize` sequences of shape `(len, input_dim)` each, in PackedSequence format
+        states : `Tuple[torch.Tensor, torch.Tensor]`
+            Tuple of tensors containing the initial hidden state and
+            the cell state of each element in the batch. Each of these tensors has a shape of
+            (1 x bsize x nhid). Defaults to `None`.
+
+        # Returns
+
+        `Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]`
+            AugmentedLSTM representation of input and the state of the LSTM `t = seq_len`.
+            Shape of representation is (bsize x seq_len x representation_dim).
+            Shape of each state is (1 x bsize x nhid).
+
+        """
+        if not isinstance(inputs, PackedSequence):
+            raise ConfigurationError("inputs must be PackedSequence but got %s" % (type(inputs)))
+
+        sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
+        batch_size = sequence_tensor.size()[0]
+        total_timesteps = sequence_tensor.size()[1]
+        output_accumulator = sequence_tensor.new_zeros(batch_size, total_timesteps, self.lstm_dim)
+        if states is None:
+            full_batch_previous_memory = sequence_tensor.new_zeros(batch_size, self.lstm_dim)
+            full_batch_previous_state = sequence_tensor.data.new_zeros(batch_size, self.lstm_dim)
+        else:
+            full_batch_previous_state = states[0].squeeze(0)
+            full_batch_previous_memory = states[1].squeeze(0)
+        current_length_index = batch_size - 1 if self.go_forward else 0
+        if self.recurrent_dropout_probability > 0.0:
+            dropout_mask = get_dropout_mask(
+                self.recurrent_dropout_probability, full_batch_previous_memory
+            )
+        else:
+            dropout_mask = None
+
+        for timestep in range(total_timesteps):
+            index = timestep if self.go_forward else total_timesteps - timestep - 1
+
+            if self.go_forward:
+                while batch_lengths[current_length_index] <= index:
+                    current_length_index -= 1
+            # If we're going backwards, we are _picking up_ more indices.
+            else:
+                # First conditional: Are we already at the maximum
+                # number of elements in the batch?
+                # Second conditional: Does the next shortest
+                # sequence beyond the current batch
+                # index require computation at this timestep?
+                while (
+                    current_length_index < (len(batch_lengths) - 1)
+                    and batch_lengths[current_length_index + 1] > index
+                ):
+                    current_length_index += 1
+
+            previous_memory = full_batch_previous_memory[0 : current_length_index + 1].clone()
+            previous_state = full_batch_previous_state[0 : current_length_index + 1].clone()
+            timestep_input = sequence_tensor[0 : current_length_index + 1, index]
+            timestep_output, memory = self.cell(
+                timestep_input,
+                (previous_state, previous_memory),
+                dropout_mask[0 : current_length_index + 1] if dropout_mask is not None else None,
+            )
+            full_batch_previous_memory = full_batch_previous_memory.data.clone()
+            full_batch_previous_state = full_batch_previous_state.data.clone()
+            full_batch_previous_memory[0 : current_length_index + 1] = memory
+            full_batch_previous_state[0 : current_length_index + 1] = timestep_output
+            output_accumulator[0 : current_length_index + 1, index, :] = timestep_output
+
+        output_accumulator = pack_padded_sequence(
+            output_accumulator, batch_lengths, batch_first=True
+        )
+
+        # Mimic the pytorch API by returning state in the following shape:
+        # (num_layers * num_directions, batch_size, lstm_dim). As this
+        # LSTM cannot be stacked, the first dimension here is just 1.
+        final_state = (
+            full_batch_previous_state.unsqueeze(0),
+            full_batch_previous_memory.unsqueeze(0),
+        )
+        return output_accumulator, final_state
diff --git a/combo/modules/input_variational_dropout.py b/combo/modules/input_variational_dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..5744642be710cc136cf82eaff443b96cb4e66352
--- /dev/null
+++ b/combo/modules/input_variational_dropout.py
@@ -0,0 +1,36 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/input_variational_dropout.py
+"""
+
+import torch
+
+
+class InputVariationalDropout(torch.nn.Dropout):
+    """
+    Apply the dropout technique in Gal and Ghahramani, [Dropout as a Bayesian Approximation:
+    Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) to a
+    3D tensor.
+    This module accepts a 3D tensor of shape `(batch_size, num_timesteps, embedding_dim)`
+    and samples a single dropout mask of shape `(batch_size, embedding_dim)` and applies
+    it to every time step.
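+
+    A minimal usage sketch (illustrative, not from the original module):
+
+    ```
+    dropout = InputVariationalDropout(p=0.5)
+    x = torch.randn(2, 5, 8)     # (batch_size, num_timesteps, embedding_dim)
+    y = dropout(x)               # one (batch_size, embedding_dim) mask reused at every timestep
+    ```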
+    """
+
+    def forward(self, input_tensor):
+        """
+        Apply dropout to input tensor.
+        # Parameters
+        input_tensor : `torch.FloatTensor`
+            A tensor of shape `(batch_size, num_timesteps, embedding_dim)`
+        # Returns
+        output : `torch.FloatTensor`
+            A tensor of shape `(batch_size, num_timesteps, embedding_dim)` with dropout applied.
+        """
+        ones = input_tensor.data.new_ones(input_tensor.shape[0], input_tensor.shape[-1])
+        dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
+        if self.inplace:
+            input_tensor *= dropout_mask.unsqueeze(1)
+            return None
+        else:
+            return dropout_mask.unsqueeze(1) * input_tensor
diff --git a/combo/modules/seq2seq_encoder.py b/combo/modules/seq2seq_encoder.py
index 7f6d201c6b4adb177c3212bd8fb606f36784982d..71413f3c5f0539caf337e10e7c05c09e50dd5c19 100644
--- a/combo/modules/seq2seq_encoder.py
+++ b/combo/modules/seq2seq_encoder.py
@@ -1,7 +1,4 @@
-from combo.models.encoder import Encoder
-
-
-class Seq2SeqEncoder(Encoder):
+class Seq2SeqEncoder:
     """
     A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a
     modified sequence of vectors.  Input shape : `(batch_size, sequence_length, input_dim)`; output
diff --git a/combo/nn/initializers.py b/combo/nn/initializers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdf80a0f0aebaef9b9e857e863fbaf60b728071d
--- /dev/null
+++ b/combo/nn/initializers.py
@@ -0,0 +1,50 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/nn/initializers.py
+"""
+import itertools
+
+import torch
+from typing import List
+from combo.utils import ConfigurationError
+
+
+def block_orthogonal(tensor: torch.Tensor, split_sizes: List[int], gain: float = 1.0) -> None:
+    """
+    An initializer which allows initializing model parameters in "blocks". This is helpful
+    in the case of recurrent models which use multiple gates applied to linear projections,
+    which can be computed efficiently if they are concatenated together. However, they are
+    separate parameters which should be initialized independently.
+    # Parameters
+    tensor : `torch.Tensor`, required.
+        A tensor to initialize.
+    split_sizes : `List[int]`, required.
+        A list of length `tensor.ndim()` specifying the size of the
+        blocks along that particular dimension. E.g. `[10, 20]` would
+        result in the tensor being split into chunks of size 10 along the
+        first dimension and 20 along the second.
+    gain : `float`, optional (default = `1.0`)
+        The gain (scaling) applied to the orthogonal initialization.
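+    # Example
+    A hedged sketch of typical usage (mirroring how `AugmentedLSTMCell.reset_parameters`
+    initializes its fused gate projections):
+    ```
+    weight = torch.empty(4 * 20, 10)      # fused gate matrix: (4 * lstm_dim, embed_dim)
+    block_orthogonal(weight, [20, 10])    # each 20 x 10 gate block is initialized independently
+    ```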
+    """
+    data = tensor.data
+    sizes = list(tensor.size())
+    if any(a % b != 0 for a, b in zip(sizes, split_sizes)):
+        raise ConfigurationError(
+            "tensor dimensions must be divisible by their respective "
+            "split_sizes. Found size: {} and split_sizes: {}".format(sizes, split_sizes)
+        )
+    indexes = [list(range(0, max_size, split)) for max_size, split in zip(sizes, split_sizes)]
+    # Iterate over all possible blocks within the tensor.
+    for block_start_indices in itertools.product(*indexes):
+        # A list of tuples containing the index to start at for this block
+        # and the appropriate step size (i.e. split_sizes[i] for dimension i).
+        index_and_step_tuples = zip(block_start_indices, split_sizes)
+        # This is a tuple of slices corresponding to:
+        # tensor[index: index + step_size, ...]. This is
+        # required because we could have an arbitrary number
+        # of dimensions. The actual slices we need are the
+        # start_index: start_index + step for each dimension in the tensor.
+        block_slice = tuple(
+            slice(start_index, start_index + step) for start_index, step in index_and_step_tuples
+        )
+        data[block_slice] = torch.nn.init.orthogonal_(tensor[block_slice].contiguous(), gain=gain)
\ No newline at end of file
diff --git a/combo/nn/util.py b/combo/nn/util.py
index 7fb485aa18f6382f5f6d9655195123088e5c98f5..69c8d017760606cfdbf4b09999d813e1932b71f7 100644
--- a/combo/nn/util.py
+++ b/combo/nn/util.py
@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
 """
-from typing import Union
+from typing import Union, Dict, Optional, List, Any
 
 import torch
 
@@ -159,3 +159,101 @@ def get_text_field_mask(
         return (character_tensor != padding_id).any(dim=-1)
     else:
         raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim))
+
+
+def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.Tensor):
+    """
+    Computes and returns an element-wise dropout mask for a given tensor, where
+    each element in the mask is dropped out with probability dropout_probability.
+    Note that the mask is NOT applied to the tensor - the tensor is passed to retain
+    the correct CUDA tensor type for the mask.
+    # Parameters
+    dropout_probability : `float`, required.
+        Probability of dropping a dimension of the input.
+    tensor_for_masking : `torch.Tensor`, required.
+    # Returns
+    `torch.FloatTensor`
+        A torch.FloatTensor consisting of the binary mask scaled by 1 / (1 - dropout_probability).
+        This scaling ensures expected values and variances of the output of applying this mask
+        and the original tensor are the same.
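+    # Example
+    An illustrative sketch (variable names are made up for the example):
+    ```
+    hidden = torch.randn(4, 20)
+    mask = get_dropout_mask(0.25, hidden)   # same shape; entries are 0.0 or 1 / 0.75
+    hidden = hidden * mask                  # the mask is NOT applied by the function itself
+    ```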
+    """
+    binary_mask = (torch.rand(tensor_for_masking.size()) > dropout_probability).to(
+        tensor_for_masking.device
+    )
+    # Scale mask by 1/keep_prob to preserve output statistics.
+    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
+    return dropout_mask
+
+
+def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
+    """
+    Takes a model (typically an AllenNLP `Model`, but this works for any `torch.nn.Module`) and
+    makes a best guess about which module is the embedding layer.  For typical AllenNLP models,
+    this often is the `TextFieldEmbedder`, but if you're using a pre-trained contextualizer, we
+    really want layer 0 of that contextualizer, not the output.  So there are a bunch of hacks in
+    here for specific pre-trained contextualizers.
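+    For example (illustrative), for a BERT-based model this returns the wordpiece
+    `word_embeddings` module:
+    ```
+    embedding_layer = find_embedding_layer(model)   # e.g. a torch.nn.Embedding over wordpieces
+    ```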
+    """
+    # We'll look for a few special cases in a first pass, then fall back to just finding a
+    # TextFieldEmbedder in a second pass if we didn't find a special case.
+    from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+    from transformers.models.bert.modeling_bert import BertEmbeddings
+    from transformers.models.albert.modeling_albert import AlbertEmbeddings
+    from transformers.models.roberta.modeling_roberta import RobertaEmbeddings
+
+    for module in model.modules():
+        if isinstance(module, BertEmbeddings):
+            return module.word_embeddings
+        if isinstance(module, RobertaEmbeddings):
+            return module.word_embeddings
+        if isinstance(module, AlbertEmbeddings):
+            return module.word_embeddings
+        if isinstance(module, GPT2Model):
+            return module.wte
+
+    # The TextFieldEmbedder fallback below has not been ported yet; if no known
+    # embedding layer was found above, fail loudly so callers can handle it.
+
+    # for module in model.modules():
+    #     if isinstance(module, TextFieldEmbedder):
+    #
+    #         if isinstance(module, BasicTextFieldEmbedder):
+    #             # We'll have a check for single Embedding cases, because we can be more efficient
+    #             # in cases like this.  If this check fails, then for something like hotflip we need
+    #             # to actually run the text field embedder and construct a vector for each token.
+    #             if len(module._token_embedders) == 1:
+    #                 embedder = list(module._token_embedders.values())[0]
+    #                 if isinstance(embedder, Embedding):
+    #                     if embedder._projection is None:
+    #                         # If there's a projection inside the Embedding, then we need to return
+    #                         # the whole TextFieldEmbedder, because there's more computation that
+    #                         # needs to be done than just multiply by an embedding matrix.
+    #                         return embedder
+    #         return module
+    raise RuntimeError("No embedding module found!")
+
+
+def get_token_offsets_from_text_field_inputs(
+    text_field_inputs: List[Any],
+) -> Optional[torch.Tensor]:
+    """
+    Given a list of inputs to a TextFieldEmbedder, tries to find token offsets from those inputs, if
+    there are any.  You will have token offsets if you are using a mismatched token embedder; if
+    you're not, the return value from this function should be None.  This function is intended to be
+    called from a `forward_hook` attached to a `TextFieldEmbedder`, so the inputs are formatted just
+    as a list.
+    It's possible in theory that you could have multiple offsets as inputs to a single call to a
+    `TextFieldEmbedder`, but that's an extremely rare use case (I can't really imagine anyone
+    wanting to do that).  In that case, we'll only return the first one.  If you need different
+    behavior for your model, open an issue on github describing what you're doing.
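+    For illustration, an input shaped like the following (the "tokens" key is hypothetical)
+    would yield the `offsets` tensor:
+    ```
+    offsets_tensor = torch.tensor([[[0, 1], [2, 2]]])
+    text_field_inputs = [{"tokens": {"offsets": offsets_tensor}}]
+    get_token_offsets_from_text_field_inputs(text_field_inputs)   # returns offsets_tensor
+    ```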
+    """
+    for input_index, text_field_input in enumerate(text_field_inputs):
+        if not isinstance(text_field_input, dict):
+            continue
+        for input_value in text_field_input.values():
+            if not isinstance(input_value, dict):
+                continue
+            for embedder_arg_name, embedder_arg_value in input_value.items():
+                if embedder_arg_name == "offsets":
+                    return embedder_arg_value
+    return None
+
diff --git a/combo/predict.py b/combo/predict.py
index bdf979763e349e4524b6e2ee7f8b8729275a2273..8672ff1b2bb2a11abbbf120fb79e8a8261b6fcec 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -3,13 +3,253 @@ import os
 import sys
 from typing import List, Union, Dict, Any
 
-from combo import data
-from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
-from combo.models.base import Predictor
+import numpy as np
+from overrides import overrides
+
+from combo import data, models
+from combo.common import util
+from combo.data import sentence2conllu, tokens2conllu, conllu2sentence, tokenizers, Instance
+from combo.data.dataset_readers.dataset_reader import DatasetReader
+from combo.data.instance import JsonDict
+from combo.predictors.predictor import Predictor
 from combo.utils import download, graph
 
 logger = logging.getLogger(__name__)
 
 
 class COMBO(Predictor):
-    pass
+
+    def __init__(self,
+                 model: models.Model,
+                 dataset_reader: DatasetReader,
+                 tokenizer: data.Tokenizer = tokenizers.WhitespaceTokenizer(),
+                 batch_size: int = 1024,
+                 line_to_conllu: bool = True) -> None:
+        super().__init__(model, dataset_reader)
+        self.batch_size = batch_size
+        self.vocab = model.vocab
+        self.dataset_reader = self._dataset_reader
+        self.dataset_reader.generate_labels = False
+        self.dataset_reader.lazy = True
+        self._tokenizer = tokenizer
+        self.without_sentence_embedding = False
+        self.line_to_conllu = line_to_conllu
+
+    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+        """Depending on the input uses (or ignores) tokenizer.
+        When model isn't only text-based only List[data.Sentence] is possible input.
+
+        * str - tokenizer is used
+        * List[str] - tokenizer is used for each string (treated as list of raw sentences)
+        * List[List[str]] - tokenizer isn't used (treated as list of tokenized sentences)
+        * List[data.Sentence] - tokenizer isn't used (treated as list of tokenized sentences)
+
+        :param sentence: sentence(s) representation
+        :return: Sentence or List[Sentence] depending on the input
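+
+        Example (illustrative; assumes `nlp` is an already loaded COMBO predictor):
+
+        ```
+        sentence = nlp("This is a sentence.")               # str -> data.Sentence
+        sentences = nlp(["First sentence.", "Second one."]) # List[str] -> List[data.Sentence]
+        ```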
+        """
+        try:
+            return self.predict(sentence)
+        except Exception as e:
+            logger.error(e)
+            logger.error('Exiting.')
+            sys.exit(1)
+
+    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+        if isinstance(sentence, str):
+            return self.predict_json({"sentence": sentence})
+        elif isinstance(sentence, list):
+            if len(sentence) == 0:
+                return []
+            example = sentence[0]
+            sentences = sentence
+            if isinstance(example, str) or isinstance(example, list):
+                result = []
+                sentences = [self._to_input_json(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_json(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
+            elif isinstance(example, data.Sentence):
+                result = []
+                sentences = [self._to_input_instance(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_instance(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
+            else:
+                raise ValueError("List must have either sentences as str, List[str] or Sentence object.")
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+
+    @overrides
+    def predict_batch_instance(self, instances: List[Instance]) -> List[data.Sentence]:
+        sentences = []
+        predictions = super().predict_batch_instance(instances)
+        for prediction, instance in zip(predictions, instances):
+            tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance)
+            sentence = conllu2sentence(tree, sentence_embedding, embeddings)
+            sentences.append(sentence)
+        return sentences
+
+    @overrides
+    def predict_batch_json(self, inputs: List[JsonDict]) -> List[data.Sentence]:
+        instances = self._batch_json_to_instances(inputs)
+        sentences = self.predict_batch_instance(instances)
+        return sentences
+
+    @overrides
+    def predict_instance(self, instance: Instance, serialize: bool = True) -> data.Sentence:
+        predictions = super().predict_instance(instance)
+        tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance)
+        return conllu2sentence(tree, sentence_embedding, embeddings)
+
+    @overrides
+    def predict_json(self, inputs: JsonDict) -> data.Sentence:
+        instance = self._json_to_instance(inputs)
+        return self.predict_instance(instance)
+
+    @overrides
+    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
+        sentence = json_dict["sentence"]
+        if isinstance(sentence, str):
+            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+        elif isinstance(sentence, list):
+            tokens = sentence
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+        return self.dataset_reader.text_to_instance(tokens2conllu(tokens))
+
+    @overrides
+    def load_line(self, line: str) -> JsonDict:
+        return self._to_input_json(line.replace("\n", "").strip())
+
+    # outputs should be api.Sentence
+    @overrides
+    def dump_line(self, outputs: Any) -> str:
+        # Check whether the output is a serialized (str) tree or a list of tokens.
+        # A serialized tree already contains separators between lines.
+        if self.without_sentence_embedding:
+            outputs.sentence_embedding = []
+        if self.line_to_conllu:
+            return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize()
+        else:
+            return outputs.to_json()
+
+    @staticmethod
+    def _to_input_json(sentence: str):
+        return {"sentence": sentence}
+
+    def _to_input_instance(self, sentence: data.Sentence) -> Instance:
+        return self.dataset_reader.text_to_instance(sentence2conllu(sentence))
+
+    def _predictions_as_tree(self, predictions: Dict[str, Any], instance: Instance):
+        tree = instance.fields["metadata"]["input"]
+        field_names = instance.fields["metadata"]["field_names"]
+        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        embeddings = {t["id"]: {} for t in tree}
+        for field_name in field_names:
+            if field_name not in predictions:
+                continue
+            field_predictions = predictions[field_name]
+            for idx, token in enumerate(tree_tokens):
+                if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
+                    value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
+                    token[field_name] = value
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
+                elif field_name == "head":
+                    token[field_name] = int(field_predictions[idx])
+                elif field_name == "deps":
+                    # Handled after every other decoding
+                    continue
+
+                elif field_name == "feats":
+                    slices = self._model.morphological_feat.slices
+                    features = []
+                    prediction = field_predictions[idx]
+                    for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
+                        if cat not in ["__PAD__", "_"]:
+                            value = self.vocab.get_token_from_index(cat_indices[pred_idx],
+                                                                    field_name + "_labels")
+                            # Exclude auxiliary values
+                            if "=None" not in value:
+                                features.append(value)
+                    if len(features) == 0:
+                        field_value = "_"
+                    else:
+                        lowercase_features = [f.lower() for f in features]
+                        arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                        field_value = "|".join(np.array(features)[arg_indices].tolist())
+
+                    token[field_name] = field_value
+                    embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx]
+                elif field_name == "lemma":
+                    prediction = field_predictions[idx]
+                    word_chars = []
+                    for char_idx in prediction[1:-1]:
+                        pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
+
+                        if pred_char == "__END__":
+                            break
+                        elif pred_char == "__PAD__":
+                            continue
+                        elif "_" in pred_char:
+                            pred_char = "?"
+
+                        word_chars.append(pred_char)
+                    token[field_name] = "".join(word_chars)
+                else:
+                    raise NotImplementedError(f"Unknown field name {field_name}!")
+
+        if "enhanced_head" in predictions and predictions["enhanced_head"]:
+            # TODO off-by-one hotfix, refactor
+            sentence_length = len(tree_tokens)
+            h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
+            h = np.concatenate((h[-1:], h[:-1]))
+            r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
+            r = np.concatenate((r[-1:], r[:-1]))
+
+            graph.graph_and_tree_merge(
+                tree_arc_scores=predictions["head"][:sentence_length],
+                tree_rel_scores=predictions["deprel"][:sentence_length],
+                graph_arc_scores=h,
+                graph_rel_scores=r,
+                idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
+                label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"),
+                graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"),
+                graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"),
+                tokens=tree_tokens
+            )
+
+            empty_tokens = graph.restore_collapse_edges(tree_tokens)
+            tree.tokens.extend(empty_tokens)
+
+        return tree, predictions["sentence_embedding"], embeddings
+
+    @classmethod
+    def with_spacy_tokenizer(cls, model: models.Model,
+                             dataset_reader: DatasetReader):
+        return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
+
+    # @classmethod
+    # def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+    #                     batch_size: int = 1024,
+    #                     cuda_device: int = -1):
+    #     util.import_module_and_submodules("combo.commands")
+    #     util.import_module_and_submodules("combo.models")
+    #     util.import_module_and_submodules("combo.training")
+    #
+    #     if os.path.exists(path):
+    #         model_path = path
+    #     else:
+    #         try:
+    #             logger.debug("Downloading model.")
+    #             model_path = download.download_file(path)
+    #         except Exception as e:
+    #             logger.error(e)
+    #             raise e
+    #
+    #     archive = models.load_archive(model_path, cuda_device=cuda_device)
+    #     model = archive.model
+    #     dataset_reader = DatasetReader.from_params(
+    #         archive.config["dataset_reader"])
+    #     return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/combo/predictors/__init__.py b/combo/predictors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/combo/predictors/predictor.py b/combo/predictors/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4917a7c8466d4fda3214043f4109a1703d3e639
--- /dev/null
+++ b/combo/predictors/predictor.py
@@ -0,0 +1,401 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/predictors/predictor.py
+"""
+
+from typing import List, Iterator, Dict, Tuple, Any, Type, Union
+import logging
+import json
+import re
+from contextlib import contextmanager
+from pathlib import Path
+
+import numpy
+import torch
+from torch.utils.hooks import RemovableHandle
+from torch import Tensor
+from torch import backends
+
+from combo.common.util import sanitize
+from combo.data.batch import Batch
+from combo.data.dataset_readers.dataset_reader import DatasetReader
+from combo.data.instance import JsonDict, Instance
+from combo.models.model import Model
+from combo.nn import util
+
+logger = logging.getLogger(__name__)
+
+
+class Predictor:
+    """
+    A `Predictor` is a thin wrapper around a model that handles JSON -> JSON predictions
+    that can be used for serving models through the web API or making predictions in bulk.
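+
+    A minimal subclass sketch (hypothetical field name, shown only to illustrate the contract;
+    the concrete COMBO predictor in `combo/predict.py` builds a CoNLL-U tree instead):
+
+    ```
+    class SentencePredictor(Predictor):
+        def _json_to_instance(self, json_dict: JsonDict) -> Instance:
+            return self._dataset_reader.text_to_instance(json_dict["sentence"])
+    ```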
+    """
+
+    def __init__(self, model: Model, dataset_reader: DatasetReader, frozen: bool = True) -> None:
+        if frozen:
+            model.eval()
+        self._model = model
+        self._dataset_reader = dataset_reader
+        self.cuda_device = next(self._model.named_parameters())[1].get_device()
+        self._token_offsets: List[Tensor] = []
+
+    def load_line(self, line: str) -> JsonDict:
+        """
+        If your inputs are not in JSON-lines format (e.g. you have a CSV)
+        you can override this function to parse them correctly.
+        """
+        return json.loads(line)
+
+    def dump_line(self, outputs: Any) -> Any:
+        """
+        If you don't want your outputs in JSON-lines format
+        you can override this function to output them differently.
+        """
+        return json.dumps(outputs) + "\n"
+
+    def predict_json(self, inputs: JsonDict) -> Any:
+        instance = self._json_to_instance(inputs)
+        return self.predict_instance(instance)
+
+    def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
+        """
+        Converts incoming JSON to an [`Instance`](../data/instance.md),
+        runs the model on the newly created instance, and adds labels to the
+        `Instance`s given by the model's output.
+
+        # Returns
+
+        `List[Instance]`
+            A list of `Instance`s.
+        """
+
+        instance = self._json_to_instance(inputs)
+        outputs = self._model.forward_on_instance(instance)
+        new_instances = self.predictions_to_labeled_instances(instance, outputs)
+        return new_instances
+
+    def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """
+        Gets the gradients of the loss with respect to the model inputs.
+
+        # Parameters
+
+        instances : `List[Instance]`
+
+        # Returns
+
+        `Tuple[Dict[str, Any], Dict[str, Any]]`
+            The first item is a Dict of gradient entries for each input.
+            The keys have the form  `{grad_input_1: ..., grad_input_2: ... }`
+            up to the number of inputs given. The second item is the model's output.
+
+        # Notes
+
+        Takes a `JsonDict` representing the inputs of the model and converts
+        them to [`Instances`](../data/instance.md), sends these through
+        the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
+        layer of the model. Calls `backward` on the loss and then removes the
+        hooks.
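+
+        Example (illustrative):
+
+        ```
+        grads, outputs = predictor.get_gradients([instance])
+        grads["grad_input_1"]   # numpy array of gradients w.r.t. the first embedded input
+        ```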
+        """
+        # set requires_grad to true for all parameters, but save original values to
+        # restore them later
+        original_param_name_to_requires_grad_dict = {}
+        for param_name, param in self._model.named_parameters():
+            original_param_name_to_requires_grad_dict[param_name] = param.requires_grad
+            param.requires_grad = True
+
+        embedding_gradients: List[Tensor] = []
+        hooks: List[RemovableHandle] = self._register_embedding_gradient_hooks(embedding_gradients)
+
+        dataset = Batch(instances)
+        dataset.index_instances(self._model.vocab)
+        dataset_tensor_dict = util.move_to_device(dataset.as_tensor_dict(), self.cuda_device)
+        # To bypass "RuntimeError: cudnn RNN backward can only be called in training mode"
+        with backends.cudnn.flags(enabled=False):
+            outputs = self._model.make_output_human_readable(
+                self._model.forward(**dataset_tensor_dict)  # type: ignore
+            )
+
+            loss = outputs["loss"]
+            # Zero gradients.
+            # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
+            # because it avoids a read op when the gradients are first updated below.
+            for p in self._model.parameters():
+                p.grad = None
+            loss.backward()
+
+        for hook in hooks:
+            hook.remove()
+
+        grad_dict = dict()
+        for idx, grad in enumerate(embedding_gradients):
+            key = "grad_input_" + str(idx + 1)
+            grad_dict[key] = grad.detach().cpu().numpy()
+
+        # restore the original requires_grad values of the parameters
+        for param_name, param in self._model.named_parameters():
+            param.requires_grad = original_param_name_to_requires_grad_dict[param_name]
+
+        return grad_dict, outputs
+
+    def get_interpretable_layer(self) -> torch.nn.Module:
+        """
+        Returns the input/embedding layer of the model.
+        If the predictor wraps around a non-AllenNLP model,
+        this function should be overridden to specify the correct input/embedding layer.
+        For the cases where the input layer _is_ an embedding layer, this should be
+        layer 0 of the embedder.
+        """
+        try:
+            return util.find_embedding_layer(self._model)
+        except RuntimeError:
+            raise RuntimeError(
+                "If the model does not use `TextFieldEmbedder`, please override "
+                "`get_interpretable_layer` in your predictor to specify the embedding layer."
+            )
+
+    def get_interpretable_text_field_embedder(self) -> torch.nn.Module:
+        """
+        Returns the first `TextFieldEmbedder` of the model.
+        If the predictor wraps around a non-AllenNLP model,
+        this function should be overridden to specify the correct embedder.
+        """
+        try:
+            return util.find_text_field_embedder(self._model)
+        except RuntimeError:
+            raise RuntimeError(
+                "If the model does not use `TextFieldEmbedder`, please override "
+                "`get_interpretable_text_field_embedder` in your predictor to specify "
+                "the embedding layer."
+            )
+
+    def _register_embedding_gradient_hooks(self, embedding_gradients):
+        """
+        Registers a backward hook on the embedding layer of the model.  Used to save the gradients
+        of the embeddings for use in get_gradients()
+
+        When there are multiple inputs (e.g., a passage and question), the hook
+        will be called multiple times. We append all the embeddings gradients
+        to a list.
+
+        We additionally add a hook on the _forward_ pass of the model's `TextFieldEmbedder` to save
+        token offsets, if there are any.  Having token offsets means that you're using a mismatched
+        token indexer, so we need to aggregate the gradients across wordpieces in a token.  We do
+        that by averaging the wordpiece gradients for each token.
+        """
+
+        def hook_layers(module, grad_in, grad_out):
+            grads = grad_out[0]
+            if self._token_offsets:
+                # If you have a mismatched indexer with multiple TextFields, it's quite possible
+                # that the order we deal with the gradients is wrong.  We'll just take items from
+                # the list one at a time, and try to aggregate the gradients.  If we got the order
+                # wrong, we should crash, so you'll know about it.  If you get an error because of
+                # that, open an issue on github, and we'll see what we can do.  The intersection of
+                # multiple TextFields and mismatched indexers is pretty small (currently empty, that
+                # I know of), so we'll ignore this corner case until it's needed.
+                offsets = self._token_offsets.pop(0)
+                span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
+                span_mask = span_mask.unsqueeze(-1)
+                span_grads *= span_mask  # zero out paddings
+
+                span_grads_sum = span_grads.sum(2)
+                span_grads_len = span_mask.sum(2)
+                # Shape: (batch_size, num_orig_tokens, embedding_size)
+                grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)
+
+                # All the places where the span length is zero, write in zeros.
+                grads[(span_grads_len == 0).expand(grads.shape)] = 0
+
+            embedding_gradients.append(grads)
+
+        def get_token_offsets(module, inputs, outputs):
+            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
+            if offsets is not None:
+                self._token_offsets.append(offsets)
+
+        hooks = []
+        text_field_embedder = self.get_interpretable_text_field_embedder()
+        hooks.append(text_field_embedder.register_forward_hook(get_token_offsets))
+        embedding_layer = self.get_interpretable_layer()
+        hooks.append(embedding_layer.register_backward_hook(hook_layers))
+        return hooks
+
+    @contextmanager
+    def capture_model_internals(self, module_regex: str = ".*") -> Iterator[dict]:
+        """
+        Context manager that captures the internal-module outputs of
+        this predictor's model. The idea is that you could use it as follows:
+
+        ```
+            with predictor.capture_model_internals() as internals:
+                outputs = predictor.predict_json(inputs)
+
+            return {**outputs, "model_internals": internals}
+        ```
+        """
+        results = {}
+        hooks = []
+
+        # First we'll register hooks to add the outputs of each module to the results dict.
+        def add_output(idx: int):
+            def _add_output(mod, _, outputs):
+                results[idx] = {"name": str(mod), "output": sanitize(outputs)}
+
+            return _add_output
+
+        regex = re.compile(module_regex)
+        for idx, (name, module) in enumerate(self._model.named_modules()):
+            if regex.fullmatch(name) and module != self._model:
+                hook = module.register_forward_hook(add_output(idx))
+                hooks.append(hook)
+
+        # If you capture the return value of the context manager, you get the results dict.
+        yield results
+
+        # And then when you exit the context we remove all the hooks.
+        for hook in hooks:
+            hook.remove()
+
+    def predict_instance(self, instance: Instance) -> Any:
+        outputs = self._model.forward_on_instance(instance)
+        return sanitize(outputs)
+
+    def predictions_to_labeled_instances(
+        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
+    ) -> List[Instance]:
+        """
+        This function takes a model's outputs for an Instance, and it labels that instance according
+        to the output. For example, in classification this function labels the instance according
+        to the class with the highest probability. This function is used to compute gradients
+        of what the model predicted. The return type is a list because in some tasks there are
+        multiple predictions in the output (e.g., in NER a model predicts multiple spans). In this
+        case, each instance in the returned list of Instances contains an individual
+        entity prediction as the label.
+        """
+
+        raise RuntimeError("implement this method for model interpretations or attacks")
+
+    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
+        """
+        Converts a JSON object into an [`Instance`](../data/instance.md)
+        and a `JsonDict` of information which the `Predictor` should pass through,
+        such as tokenised inputs.
+        """
+        raise NotImplementedError
+
+    def predict_batch_json(self, inputs: List[JsonDict]) -> List[Any]:
+        instances = self._batch_json_to_instances(inputs)
+        return self.predict_batch_instance(instances)
+
+    def predict_batch_instance(self, instances: List[Instance]) -> List[Any]:
+        outputs = self._model.forward_on_instances(instances)
+        return sanitize(outputs)
+
+    def _batch_json_to_instances(self, json_dicts: List[JsonDict]) -> List[Instance]:
+        """
+        Converts a list of JSON objects into a list of `Instance`s.
+        By default, this expects that a "batch" consists of a list of JSON blobs which would
+        individually be predicted by `predict_json`. In order to use this method for
+        batch prediction, `_json_to_instance` should be implemented by the subclass, or
+        if the instances have some dependency on each other, this method should be overridden
+        directly.
+        """
+        instances = []
+        for json_dict in json_dicts:
+            instances.append(self._json_to_instance(json_dict))
+        return instances
+    #
+    # @classmethod
+    # def from_path(
+    #     cls,
+    #     archive_path: Union[str, Path],
+    #     predictor_name: str = None,
+    #     cuda_device: int = -1,
+    #     dataset_reader_to_load: str = "validation",
+    #     frozen: bool = True,
+    #     import_plugins: bool = True,
+    #     overrides: Union[str, Dict[str, Any]] = "",
+    # ) -> "Predictor":
+    #     """
+    #     Instantiate a `Predictor` from an archive path.
+    #
+    #     If you need more detailed configuration options, such as overrides,
+    #     please use `from_archive`.
+    #
+    #     # Parameters
+    #
+    #     archive_path : `Union[str, Path]`
+    #         The path to the archive.
+    #     predictor_name : `str`, optional (default=`None`)
+    #         Name that the predictor is registered as, or None to use the
+    #         predictor associated with the model.
+    #     cuda_device : `int`, optional (default=`-1`)
+    #         If `cuda_device` is >= 0, the model will be loaded onto the
+    #         corresponding GPU. Otherwise it will be loaded onto the CPU.
+    #     dataset_reader_to_load : `str`, optional (default=`"validation"`)
+    #         Which dataset reader to load from the archive, either "train" or
+    #         "validation".
+    #     frozen : `bool`, optional (default=`True`)
+    #         If we should call `model.eval()` when building the predictor.
+    #     import_plugins : `bool`, optional (default=`True`)
+    #         If `True`, we attempt to import plugins before loading the predictor.
+    #         This comes with additional overhead, but means you don't need to explicitly
+    #         import the modules that your predictor depends on as long as those modules
+    #         can be found by `allennlp.common.plugins.import_plugins()`.
+    #     overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`)
+    #         JSON overrides to apply to the unarchived `Params` object.
+    #
+    #     # Returns
+    #
+    #     `Predictor`
+    #         A Predictor instance.
+    #     """
+    #     if import_plugins:
+    #         plugins.import_plugins()
+    #     return Predictor.from_archive(
+    #         load_archive(archive_path, cuda_device=cuda_device, overrides=overrides),
+    #         predictor_name,
+    #         dataset_reader_to_load=dataset_reader_to_load,
+    #         frozen=frozen,
+    #     )
+    #
+    # @classmethod
+    # def from_archive(
+    #     cls,
+    #     archive: Archive,
+    #     predictor_name: str = None,
+    #     dataset_reader_to_load: str = "validation",
+    #     frozen: bool = True,
+    # ) -> "Predictor":
+    #     """
+    #     Instantiate a `Predictor` from an [`Archive`](../models/archival.md);
+    #     that is, from the result of training a model. Optionally specify which `Predictor`
+    #     subclass; otherwise, we try to find a corresponding predictor in `DEFAULT_PREDICTORS`, or if
+    #     one is not found, the base class (i.e. `Predictor`) will be used. Optionally specify
+    #     which [`DatasetReader`](../data/dataset_readers/dataset_reader.md) should be loaded;
+    #     otherwise, the validation one will be used if it exists followed by the training dataset reader.
+    #     Optionally specify if the loaded model should be frozen, meaning `model.eval()` will be called.
+    #     """
+    #     # Duplicate the config so that the config inside the archive doesn't get consumed
+    #     config = archive.config.duplicate()
+    #
+    #     if not predictor_name:
+    #         model_type = config.get("model").get("type")
+    #         model_class, _ = Model.resolve_class_name(model_type)
+    #         predictor_name = model_class.default_predictor
+    #     predictor_class: Type[Predictor] = (
+    #         Predictor.by_name(predictor_name) if predictor_name is not None else cls  # type: ignore
+    #     )
+    #
+    #     if dataset_reader_to_load == "validation":
+    #         dataset_reader = archive.validation_dataset_reader
+    #     else:
+    #         dataset_reader = archive.dataset_reader
+    #
+    #     model = archive.model
+    #     if frozen:
+    #         model.eval()
+    #
+    #     return predictor_class(model, dataset_reader)