From 60d5f10703e1aa61999269d6aada5dd10a3a3dfc Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 19 Sep 2023 20:31:37 +1000
Subject: [PATCH] Modify custom dependency injection

---
 .idea/combolightning.iml                      |   3 +
 combo/commands/train.py                       |   2 +-
 combo/common/cached_transformers.py           | 187 ++++-
 combo/config/__init__.py                      |   2 +
 combo/config/exceptions.py                    |   3 +
 combo/config/from_parameters.py               | 139 ++++
 combo/config/registry.py                      |  61 ++
 combo/data/__init__.py                        |   2 +-
 combo/data/dataset_readers/__init__.py        |   2 +-
 .../dataset_readers/{conll.py => conllu.py}   |  11 +-
 combo/data/dataset_readers/dataset_reader.py  |  26 +-
 .../text_classification_json_reader.py        |   3 +
 .../universal_dependencies_dataset_reader.py  |   4 +-
 combo/data/fields/sequence_label_field.py     |   2 +-
 .../data/fields/sequence_multilabel_field.py  |   2 +-
 combo/data/fields/text_field.py               |   3 +-
 combo/data/instance.py                        |   2 +-
 combo/data/token_indexers/__init__.py         |   1 +
 ...ed_transformer_fixed_mismatched_indexer.py |   5 +-
 .../pretrained_transformer_indexer.py         |  10 +-
 ...etrained_transformer_mismatched_indexer.py |   7 +
 .../token_indexers/single_id_token_indexer.py |   9 +-
 .../token_characters_indexer.py               |  25 +-
 .../token_const_padding_characters_indexer.py |   8 +-
 .../token_indexers/token_features_indexer.py  |  16 +-
 combo/data/token_indexers/token_indexer.py    |  10 +-
 combo/data/tokenizers/character_tokenizer.py  |   2 +
 combo/data/tokenizers/lambo_tokenizer.py      |   2 +
 .../pretrained_transformer_tokenizer.py       |   2 +
 combo/data/tokenizers/spacy_tokenizer.py      |   4 +-
 combo/data/tokenizers/tokenizer.py            |   4 +-
 combo/data/tokenizers/whitespace_tokenizer.py |   2 +
 combo/data/vocabulary.py                      | 486 ++++++++++---
 combo/example.ipynb                           |  79 +-
 combo/main.py                                 |   6 +-
 combo/models/__init__.py                      |  12 +-
 combo/models/archival.py                      |  11 +-
 combo/models/combo_model.py                   |  47 +-
 combo/models/combo_nn.py                      |  14 -
 combo/models/encoder.py                       |  57 +-
 combo/modules/__init__.py                     |   2 +
 combo/modules/augmented_lstm.py               |   8 +-
 combo/{models => modules}/dilated_cnn.py      |  11 +-
 combo/modules/encoder.py                      |   1 +
 combo/{models => modules}/graph_parser.py     |  39 +-
 combo/modules/input_variational_dropout.py    |   5 +-
 combo/{models => modules}/lemma.py            |  82 +--
 combo/{models => modules}/model.py            |  45 +-
 combo/modules/module.py                       |  49 ++
 combo/{models => modules}/morpho.py           |  66 +-
 combo/{models => modules}/parser.py           |  53 +-
 combo/modules/scalar_mix.py                   | 101 +++
 combo/modules/seq2seq_encoder.py              |  33 -
 combo/modules/seq2seq_encoders/__init__.py    |   0
 .../seq2seq_encoders/seq2seq_encoder.py       | 103 +++
 .../modules/text_field_embedders/__init__.py  |   2 +
 .../basic_text_field_embedder.py              | 117 +++
 .../text_field_embedder.py                    |  55 ++
 combo/modules/time_distributed.py             |  73 ++
 combo/modules/token_embedders/__init__.py     |   7 +
 .../character_token_embedder.py               |  53 ++
 combo/modules/token_embedders/embedding.py    | 675 ++++++++++++++++++
 .../modules/token_embedders/empty_embedder.py |  28 +
 .../pretrained_transformer_embedder.py        | 415 +++++++++++
 ...trained_transformer_mismatched_embedder.py | 183 +++++
 .../projected_words_embedder.py               |  50 ++
 .../token_embedders/token_embedder.py}        | 136 ++--
 .../transformers_words_embeddings.py          |  67 ++
 combo/nn/__init__.py                          |   3 +-
 combo/nn/activations.py                       | 100 +++
 combo/{models => nn}/base.py                  |  84 +--
 combo/nn/util.py                              | 139 +++-
 combo/predict.py                              |  14 +-
 combo/predictors/predictor.py                 |  10 +-
 tests/config/__init__.py                      |   0
 tests/config/test_configuration.py            |  31 +
 76 files changed, 3571 insertions(+), 542 deletions(-)
 create mode 100644 combo/config/exceptions.py
 create mode 100644 combo/config/from_parameters.py
 create mode 100644 combo/config/registry.py
 rename combo/data/dataset_readers/{conll.py => conllu.py} (96%)
 delete mode 100644 combo/models/combo_nn.py
 rename combo/{models => modules}/dilated_cnn.py (77%)
 rename combo/{models => modules}/graph_parser.py (89%)
 rename combo/{models => modules}/lemma.py (72%)
 rename combo/{models => modules}/model.py (93%)
 create mode 100644 combo/modules/module.py
 rename combo/{models => modules}/morpho.py (74%)
 rename combo/{models => modules}/parser.py (89%)
 create mode 100644 combo/modules/scalar_mix.py
 delete mode 100644 combo/modules/seq2seq_encoder.py
 create mode 100644 combo/modules/seq2seq_encoders/__init__.py
 create mode 100644 combo/modules/seq2seq_encoders/seq2seq_encoder.py
 create mode 100644 combo/modules/text_field_embedders/__init__.py
 create mode 100644 combo/modules/text_field_embedders/basic_text_field_embedder.py
 create mode 100644 combo/modules/text_field_embedders/text_field_embedder.py
 create mode 100644 combo/modules/time_distributed.py
 create mode 100644 combo/modules/token_embedders/__init__.py
 create mode 100644 combo/modules/token_embedders/character_token_embedder.py
 create mode 100644 combo/modules/token_embedders/embedding.py
 create mode 100644 combo/modules/token_embedders/empty_embedder.py
 create mode 100644 combo/modules/token_embedders/pretrained_transformer_embedder.py
 create mode 100644 combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
 create mode 100644 combo/modules/token_embedders/projected_words_embedder.py
 rename combo/{models/embeddings.py => modules/token_embedders/token_embedder.py} (62%)
 create mode 100644 combo/modules/token_embedders/transformers_words_embeddings.py
 create mode 100644 combo/nn/activations.py
 rename combo/{models => nn}/base.py (73%)
 create mode 100644 tests/config/__init__.py
 create mode 100644 tests/config/test_configuration.py

diff --git a/.idea/combolightning.iml b/.idea/combolightning.iml
index 332f523..d4c73d1 100644
--- a/.idea/combolightning.iml
+++ b/.idea/combolightning.iml
@@ -5,4 +5,7 @@
     <orderEntry type="jdk" jdkName="combo" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
+  </component>
 </module>
\ No newline at end of file
diff --git a/combo/commands/train.py b/combo/commands/train.py
index f4bb807..79b7f72 100644
--- a/combo/commands/train.py
+++ b/combo/commands/train.py
@@ -4,7 +4,7 @@ from pytorch_lightning import Trainer
 class FinetuningTrainModel(Trainer):
     """
     Class made only for finetuning,
-    the only difference is saving vocab from concatenated
+    the only difference is saving vocabulary from concatenated
     (archive and current) datasets
     """
     pass
\ No newline at end of file
diff --git a/combo/common/cached_transformers.py b/combo/common/cached_transformers.py
index 627879c..24db221 100644
--- a/combo/common/cached_transformers.py
+++ b/combo/common/cached_transformers.py
@@ -3,15 +3,198 @@ Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/common/cached_transformers.py
 """
 import logging
+import re
+import warnings
+
 import transformers
-from typing import Dict, Tuple
+from typing import Dict, Optional, Union, Tuple, NamedTuple
 
-from combo.common.util import hash_object
+from transformers import AutoModel, AutoConfig
 
+from combo.common.util import hash_object
+from combo.utils import ConfigurationError, cast
 
 logger = logging.getLogger(__name__)
 
+
+class TransformerSpec(NamedTuple):
+    model_name: str
+    override_weights_file: Optional[str] = None
+    override_weights_strip_prefix: Optional[str] = None
+    reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None
+
+
 _tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {}
+_model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {}
+
+def get(
+    model_name: str,
+    make_copy: bool,
+    override_weights_file: Optional[str] = None,
+    override_weights_strip_prefix: Optional[str] = None,
+    reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
+    load_weights: bool = True,
+    **kwargs,
+) -> transformers.PreTrainedModel:
+    """
+    Returns a transformer model from the cache.
+
+    # Parameters
+
+    model_name : `str`
+        The name of the transformer, for example `"bert-base-cased"`
+    make_copy : `bool`
+        If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the
+        parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but
+        make sure to `copy.deepcopy()` the bits you are keeping.
+    override_weights_file : `str`, optional (default = `None`)
+        If set, this specifies a file from which to load alternate weights that override the
+        weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
+        with `torch.save()`.
+    override_weights_strip_prefix : `str`, optional (default = `None`)
+        If set, strip the given prefix from the state dict when loading it.
+    reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
+        If this is an integer, the last `reinit_modules` layers of the transformer will be
+        re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
+        be re-initialized. Note, because the module structure of the transformer `model_name` can
+        differ, we cannot guarantee that providing an integer or tuple of integers will work. If
+        this fails, you can instead provide a tuple of strings, which will be treated as regexes and
+        any module with a name matching the regex will be re-initialized. Re-initializing the last
+        few layers of a pretrained transformer can reduce the instability of fine-tuning on small
+        datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
+        if `load_weights` is `False` or `override_weights_file` is not `None`.
+    load_weights : `bool`, optional (default = `True`)
+        If set to `False`, no weights will be loaded. This is helpful when you only
+        want to initialize the architecture, like when you've already fine-tuned a model
+        and are going to load the weights from a state dict elsewhere.
+    """
+    global _model_cache
+    spec = TransformerSpec(
+        model_name,
+        override_weights_file,
+        override_weights_strip_prefix,
+        reinit_modules,
+    )
+    transformer = _model_cache.get(spec, None)
+    if transformer is None:
+        if not load_weights:
+            if override_weights_file is not None:
+                warnings.warn(
+                    "You specified an 'override_weights_file' in allennlp.common.cached_transformers.get(), "
+                    "but 'load_weights' is set to False, so 'override_weights_file' will be ignored.",
+                    UserWarning,
+                )
+            if reinit_modules is not None:
+                warnings.warn(
+                    "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
+                    "but 'load_weights' is set to False, so 'reinit_modules' will be ignored.",
+                    UserWarning,
+                )
+            transformer = AutoModel.from_config(
+                AutoConfig.from_pretrained(
+                    model_name,
+                    **kwargs,
+                )
+            )
+        elif override_weights_file is not None:
+            if reinit_modules is not None:
+                warnings.warn(
+                    "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
+                    "but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.",
+                    UserWarning,
+                )
+            import torch
+            from combo.utils.file_utils import cached_path
+
+            override_weights_file = cached_path(override_weights_file)
+            override_weights = torch.load(override_weights_file)
+            if override_weights_strip_prefix is not None:
+
+                def strip_prefix(s):
+                    if s.startswith(override_weights_strip_prefix):
+                        return s[len(override_weights_strip_prefix) :]
+                    else:
+                        return s
+
+                valid_keys = {
+                    k
+                    for k in override_weights.keys()
+                    if k.startswith(override_weights_strip_prefix)
+                }
+                if len(valid_keys) > 0:
+                    logger.info(
+                        "Loading %d tensors from %s", len(valid_keys), override_weights_file
+                    )
+                else:
+                    raise ValueError(
+                        f"Specified prefix of '{override_weights_strip_prefix}' means no tensors "
+                        f"will be loaded from {override_weights_file}."
+                    )
+                override_weights = {strip_prefix(k): override_weights[k] for k in valid_keys}
+
+            transformer = AutoModel.from_config(
+                AutoConfig.from_pretrained(
+                    model_name,
+                    **kwargs,
+                )
+            )
+            # When DistributedDataParallel or DataParallel is used, the state dict of the
+            # DistributedDataParallel/DataParallel wrapper prepends "module." to all parameters
+            # of the actual model, since the actual model is stored within the module field.
+            # This accounts for the case where a pretrained model was saved without removing the
+            # DistributedDataParallel/DataParallel wrapper.
+            if hasattr(transformer, "module"):
+                transformer.module.load_state_dict(override_weights)
+            else:
+                transformer.load_state_dict(override_weights)
+        elif reinit_modules is not None:
+            transformer = AutoModel.from_pretrained(
+                model_name,
+                **kwargs,
+            )
+            num_layers = transformer.config.num_hidden_layers
+            if isinstance(reinit_modules, int):
+                reinit_modules = tuple(range(num_layers - reinit_modules, num_layers))
+            if all(isinstance(x, int) for x in reinit_modules):
+                # This type cast is necessary to avoid a mypy error.
+                reinit_modules = cast(Tuple[int], reinit_modules)
+                if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in reinit_modules):
+                    raise ValueError(
+                        f"A layer index in reinit_modules ({reinit_modules}) is invalid."
+                        f" Must be between 0 and the maximum layer index ({num_layers - 1}.)"
+                    )
+                # Some transformer models organize their modules differently, so if this fails,
+                # raise an error with a helpful message.
+                try:
+                    for layer_idx in reinit_modules:
+                        transformer.encoder.layer[layer_idx].apply(transformer._init_weights)
+                except AttributeError:
+                    raise ConfigurationError(
+                        f"Unable to re-initialize the layers of transformer model"
+                        f" {model_name} using layer indices. Please provide a tuple of"
+                        " strings corresponding to the names of the layers to re-initialize."
+                    )
+            elif all(isinstance(x, str) for x in reinit_modules):
+                for regex in reinit_modules:
+                    for name, module in transformer.named_modules():
+                        if re.search(str(regex), name):
+                            module.apply(transformer._init_weights)
+            else:
+                raise ValueError(
+                    "reinit_modules must be either an integer, a tuple of strings, or a tuple of integers."
+                )
+        else:
+            transformer = AutoModel.from_pretrained(
+                model_name,
+                **kwargs,
+            )
+        _model_cache[spec] = transformer
+    if make_copy:
+        import copy
+
+        return copy.deepcopy(transformer)
+    else:
+        return transformer
 
 def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer:
 
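
Illustrative usage of the new `get()` cache above — a minimal sketch assuming a standard HuggingFace checkpoint name ("bert-base-cased"), which is not defined anywhere in this patch:

    from combo.common import cached_transformers

    # Re-initialize the last two encoder layers and take a private copy that is
    # safe to fine-tune without mutating the cached instance.
    model = cached_transformers.get("bert-base-cased", make_copy=True, reinit_modules=2)

    # A second call with the same spec hits the cache instead of reloading weights.
    shared = cached_transformers.get("bert-base-cased", make_copy=False, reinit_modules=2)
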
diff --git a/combo/config/__init__.py b/combo/config/__init__.py
index e69de29..778e873 100644
--- a/combo/config/__init__.py
+++ b/combo/config/__init__.py
@@ -0,0 +1,2 @@
+from .from_parameters import FromParameters
+from .registry import Registry
diff --git a/combo/config/exceptions.py b/combo/config/exceptions.py
new file mode 100644
index 0000000..4e0222f
--- /dev/null
+++ b/combo/config/exceptions.py
@@ -0,0 +1,3 @@
+class RegistryException(Exception):
+    def __init__(self, message):
+        super().__init__(message)
diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py
new file mode 100644
index 0000000..7d7d507
--- /dev/null
+++ b/combo/config/from_parameters.py
@@ -0,0 +1,139 @@
+import inspect
+import warnings
+from typing import Any, Dict, List, Optional, Type, Union
+import typing
+
+import torch
+
+from combo.common.params import Params
+from combo.config.exceptions import RegistryException
+from combo.config.registry import (Registry)
+from combo.nn import Activation
+from combo.utils import ConfigurationError
+
+
+def resolve_optional(type_annotation):
+    arg1, arg2 = typing.get_args(type_annotation)
+    # Optional
+    if (arg1 is None and arg2 is not None) or (arg1 is not None and arg2 is None):
+        type_annotation = arg1 or arg2
+    return type_annotation
+
+
+def resolve_union(type_annotation, v):
+    arg1, arg2 = typing.get_args(type_annotation)
+    if arg1 is type(v) or typing.get_origin(arg1) is type(v):
+        type_annotation = arg1
+    else:
+        type_annotation = arg2
+    return type_annotation
+
+
+def _resolve(type: Type[object],
+             values: typing.Union[Dict[str, Any], str],
+             pass_to_subclasses: Dict[str, Any],
+             default_name: str = 'base'):
+    if isinstance(values, Params):
+        values = values.as_dict()
+    if isinstance(values, dict):
+        try:
+            if 'type' in values.keys():
+                return Registry.resolve(
+                    type,
+                    values['type']
+                ).from_parameters({**values, **pass_to_subclasses}, pass_to_subclasses)
+            else:
+                return Registry.get_default(type).from_parameters({**values, **pass_to_subclasses}, pass_to_subclasses)
+        except RegistryException:
+            if type is torch.nn.Module:
+                return torch.nn.Module(**values)
+            if type.__base__ is not object:
+                return _resolve(type.__base__, values, pass_to_subclasses, default_name)
+            warnings.warn(f'No classes of type {type} (values: {values}) in Registry!')
+            return values
+    elif type is Activation and isinstance(values, str):
+        try:
+            return Registry.resolve(type, values)()
+        except RegistryException:
+            warnings.warn(f'No Activation of class name {values} in Registry!')
+            return values
+    else:
+        return values
+
+
+def _init_from_list(init_func_type_annotation: Type,
+                    value: List[Any],
+                    pass_to_subclasses: Dict[str, Any] = None,
+                    default_name: str = 'base') -> List[Any]:
+    # Check the typing of the arguments
+    list_argument_annotation = typing.get_args(init_func_type_annotation)[0]
+    if isinstance(value, Params):
+        value = value.as_dict()
+    if typing.get_origin(list_argument_annotation) is list:
+        return [_init_from_list(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value]
+    elif typing.get_origin(list_argument_annotation) is dict:
+        return [_init_from_dict(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value]
+    else:
+        # A single object:
+        return [_resolve(list_argument_annotation, v, pass_to_subclasses, default_name) for v in value]
+
+
+def _init_from_dict(init_func_type_annotation: Type,
+                    value: Dict[Any, Any],
+                    pass_to_subclasses: Dict[str, Any] = None,
+                    default_name: str = 'base') -> Dict[Any, Any]:
+    key_annotation, value_annotation = typing.get_args(init_func_type_annotation)
+    if isinstance(value, Params):
+        value = value.as_dict()
+    if typing.get_origin(value_annotation) is list:
+        return {k: _init_from_list(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()}
+    elif typing.get_origin(value_annotation) is dict:
+        return {k: _init_from_dict(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()}
+    else:
+        return {k: _resolve(value_annotation, v, pass_to_subclasses, default_name) for k, v in value.items()}
+
+
+class FromParameters:
+    @classmethod
+    def from_parameters(cls,
+                        parameters: Dict[str, Any],
+                        pass_to_subclasses: Dict[str, Any] = None,
+                        default_name: str = 'base'):
+        parameters_to_pass = {}
+        pass_to_subclasses = pass_to_subclasses or {}
+        arguments = inspect.signature(cls.__init__).parameters
+        argument_names = list(arguments.keys())
+        for k, v in parameters.items():
+            if isinstance(v, Params):
+                v = v.as_dict()
+            if k not in argument_names:
+                # Only pass on the configuration items whose names match the
+                # arguments of the init function
+                continue
+            # Check the type
+            type_annotation = arguments[k].annotation
+
+            if typing.get_origin(type_annotation) is Union:
+                type_annotation = resolve_union(resolve_optional(type_annotation), v)
+
+            # Check if list
+            if typing.get_origin(type_annotation) is list:
+                parameters_to_pass[k] = _init_from_list(type_annotation,
+                                                        v, pass_to_subclasses, default_name)
+            elif typing.get_origin(type_annotation) is dict:
+                parameters_to_pass[k] = _init_from_dict(type_annotation,
+                                                        v, pass_to_subclasses, default_name)
+            elif typing.get_origin(type_annotation) is None and type_annotation.__base__ is not object:
+                try:
+                    parameters_to_pass[k] = _resolve(type_annotation,
+                                                     v, pass_to_subclasses, default_name)
+                except ConfigurationError:
+                    warnings.warn(f'An object of type {type_annotation} is not in the registry!')
+                    parameters_to_pass[k] = v
+            else:
+                parameters_to_pass[k] = v
+
+        return cls(**parameters_to_pass)
+
+    def _to_params(self) -> Dict[str, str]:
+        return {}
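
To illustrate the injection flow that `from_parameters` and the `Registry` (next file) implement, a rough sketch with hypothetical classes; only `FromParameters`, `Registry` and their methods come from this patch:

    from combo.config import FromParameters, Registry


    class Embedder(FromParameters):                      # hypothetical base category
        pass


    @Registry.register(Embedder, 'char', default=True)   # hypothetical subclass
    class CharEmbedder(Embedder):
        def __init__(self, dim: int = 32):
            self.dim = dim


    class Model(FromParameters):
        # `embedder` is annotated with a registered base class, so from_parameters()
        # resolves the nested dict through the Registry before calling __init__.
        def __init__(self, embedder: Embedder, dropout: float = 0.0):
            self.embedder = embedder
            self.dropout = dropout


    model = Model.from_parameters({'embedder': {'type': 'char', 'dim': 64}, 'dropout': 0.1})
    assert isinstance(model.embedder, CharEmbedder) and model.embedder.dim == 64
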
diff --git a/combo/config/registry.py b/combo/config/registry.py
new file mode 100644
index 0000000..c7bfce8
--- /dev/null
+++ b/combo/config/registry.py
@@ -0,0 +1,61 @@
+from collections import defaultdict
+from typing import Any, Optional, Type, Union, Dict, List
+
+from combo.config.exceptions import RegistryException
+
+
+class Registry:
+    __classes = defaultdict(dict)
+    __defaults: Dict[Type[object], Type[object]] = {}
+
+    @classmethod
+    def classes(cls) -> Dict[Type[object], Dict[str, Type]]:
+        return cls.__classes
+
+    @classmethod
+    def defaults(cls) -> Dict[Type[object], Type[object]]:
+        return cls.__defaults
+
+    @classmethod
+    def register_base_class(cls, registry_name: str, default: bool = False):
+        def decorator(clz):
+            nonlocal registry_name, default
+            # if category not in REGISTRY_CLASSES:
+            #     warnings.warn(f'Category {str(category)} is not in REGISTRY_CLASSES')
+            cls.__classes[clz][registry_name.lower()] = clz
+            if default:
+                cls.__defaults[clz] = clz
+            return clz
+
+        return decorator
+
+    @classmethod
+    def register(cls, category: Type[object], registry_name: str, default: bool = False):
+        def decorator(clz):
+            nonlocal category, registry_name, default
+            # if category not in REGISTRY_CLASSES:
+            #     warnings.warn(f'Category {str(category)} is not in REGISTRY_CLASSES')
+            cls.__classes[category][registry_name.lower()] = clz
+            if default:
+                cls.__defaults[category] = clz
+            return clz
+
+        return decorator
+
+    @classmethod
+    def resolve(cls, category: Union[Type[object], str], class_name: str) -> Optional[Type]:
+        try:
+            return cls.__classes[category][class_name.lower()]
+        except KeyError:
+            try:
+                return cls.__defaults[category]
+            except KeyError:
+                raise RegistryException(f'No {str(category)} with registered name {class_name.lower()}, no default value either')
+
+    @classmethod
+    def get_default(cls, category: Type[object]) -> Optional[Type]:
+        try:
+            return cls.__defaults[category]
+        except KeyError:
+            raise RegistryException(
+                f'No default value for {str(category)}')
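
A short sketch of the fallback behaviour of `resolve()`: an unknown name falls back to the category default before raising `RegistryException` (class names below are hypothetical):

    from combo.config import Registry


    class Segmenter:                                     # hypothetical category
        pass


    @Registry.register(Segmenter, 'whitespace', default=True)
    class WhitespaceSegmenter(Segmenter):
        pass


    assert Registry.resolve(Segmenter, 'whitespace') is WhitespaceSegmenter
    # Unknown names resolve to the registered default instead of failing outright.
    assert Registry.resolve(Segmenter, 'no-such-segmenter') is WhitespaceSegmenter
    assert Registry.get_default(Segmenter) is WhitespaceSegmenter
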
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index f75b7d8..3331bb6 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,4 +1,3 @@
-from .api import (Token, Sentence, sentence2conllu, tokens2conllu, conllu2sentence)
 from .vocabulary import Vocabulary
 from .samplers import TokenCountBatchSampler
 from .instance import Instance
@@ -7,3 +6,4 @@ from .tokenizers import (Tokenizer, Token, CharacterTokenizer, PretrainedTransfo
                          SpacyTokenizer, WhitespaceTokenizer, LamboTokenizer)
 from .dataset_readers import (ConllDatasetReader, DatasetReader,
                               TextClassificationJSONReader, UniversalDependenciesDatasetReader)
+from .api import (Sentence, tokens2conllu, conllu2sentence, sentence2conllu)
diff --git a/combo/data/dataset_readers/__init__.py b/combo/data/dataset_readers/__init__.py
index 2ec1049..b8088ab 100644
--- a/combo/data/dataset_readers/__init__.py
+++ b/combo/data/dataset_readers/__init__.py
@@ -1,4 +1,4 @@
 from .dataset_reader import DatasetReader
 from .text_classification_json_reader import TextClassificationJSONReader
 from .universal_dependencies_dataset_reader import UniversalDependenciesDatasetReader
-from .conll import ConllDatasetReader
+from .conllu import ConllDatasetReader
diff --git a/combo/data/dataset_readers/conll.py b/combo/data/dataset_readers/conllu.py
similarity index 96%
rename from combo/data/dataset_readers/conll.py
rename to combo/data/dataset_readers/conllu.py
index d8c2c5d..18b3eaa 100644
--- a/combo/data/dataset_readers/conll.py
+++ b/combo/data/dataset_readers/conllu.py
@@ -7,6 +7,8 @@ import itertools
 import logging
 from typing import Dict, List, Optional, Sequence, Iterable
 
+from overrides import overrides
+
 from combo.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
 from combo.data.token_indexers.token_indexer import TokenIndexer, Token
 from combo.utils import ConfigurationError
@@ -14,6 +16,8 @@ from .dataset_reader import DatasetReader
 from .dataset_utils.span_utils import to_bioul
 from .. import Instance
 from ..fields import MetadataField, TextField, Field, SequenceLabelField
+from ...config import Registry
+from ...config.from_parameters import FromParameters
 from ...utils.file_utils import cached_path
 
 logger = logging.getLogger(__name__)
@@ -32,7 +36,8 @@ def _is_divider(line: str) -> bool:
 
 
 # TODO: maybe one should note whether the format is IOB1 or IOB2 in the processed dataset?
-class ConllDatasetReader(DatasetReader):
+@Registry.register(DatasetReader, 'conll2003')
+class ConllDatasetReader(DatasetReader, FromParameters):
     """
     Reads instances from a pretokenised file where each line is in the following format:
     ```
@@ -93,7 +98,7 @@ class ConllDatasetReader(DatasetReader):
 
         super().__init__(**kwargs)
 
-        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
+        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
         if tag_label is not None and tag_label not in self._VALID_LABELS:
             raise ConfigurationError("unknown tag label type: {}".format(tag_label))
         for label in feature_labels:
@@ -219,4 +224,4 @@ class ConllDatasetReader(DatasetReader):
         return self
 
     def apply_token_indexers(self, instance: Instance) -> None:
-        instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore
+        instance.fields["tokens"]._token_indexers = self.token_indexers  # type: ignore
diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
index 092c7a0..cad9f4b 100644
--- a/combo/data/dataset_readers/dataset_reader.py
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -9,6 +9,7 @@ from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List
 from overrides import overrides
 from torch.utils.data import IterableDataset
 
+from combo.config import FromParameters
 from combo.data import SpacyTokenizer, SingleIdTokenIndexer
 from combo.data.instance import Instance
 from combo.data.tokenizers import Tokenizer
@@ -21,7 +22,7 @@ PathOrStr = Union[PathLike, str]
 DatasetReaderInput = Union[PathOrStr, List[PathOrStr], Dict[str, PathOrStr]]
 
 
-class DatasetReader(IterableDataset):
+class DatasetReader(IterableDataset, FromParameters):
     """
     A `DatasetReader` knows how to turn a file containing a dataset into a collection
     of `Instance`s.
@@ -50,6 +51,10 @@ class DatasetReader(IterableDataset):
     def token_indexers(self) -> Optional[Dict[str, TokenIndexer]]:
         return self.__token_indexers
 
+    @token_indexers.setter
+    def token_indexers(self, token_indexers: Dict[str, TokenIndexer]):
+        self.__token_indexers = token_indexers
+
     @overrides
     def __getitem__(self, item, **kwargs) -> Instance:
         raise NotImplementedError
@@ -69,3 +74,22 @@ class DatasetReader(IterableDataset):
 
     def apply_token_indexers(self, instance: Instance) -> None:
         pass
+
+    def text_to_instance(self, *inputs) -> Instance:
+        """
+        Does whatever tokenization or processing is necessary to go from textual input to an
+        `Instance`.  The primary intended use for this is with a
+        :class:`~combo.predictors.predictor.Predictor`, which gets text input as a JSON
+        object and needs to process it to be input to a model.
+
+        The intent here is to share code between :func:`_read` and what happens at
+        model serving time, or any other time you want to make a prediction from new data.  We need
+        to process the data in the same way it was done at training time.  Allowing the
+        `DatasetReader` to process new text lets us accomplish this, as we can just call
+        `DatasetReader.text_to_instance` when serving predictions.
+
+        The input type here is rather vaguely specified, unfortunately.  The `Predictor` will
+        have to make some assumptions about the kind of `DatasetReader` that it's using, in order
+        to pass it the right information.
+        """
+        raise NotImplementedError
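
A minimal, hypothetical subclass sketch of the `text_to_instance` contract described above; the field, tokenizer and registry-name choices are illustrative, not part of the patch:

    from combo.config import Registry
    from combo.data import Instance
    from combo.data.dataset_readers import DatasetReader
    from combo.data.fields import TextField
    from combo.data.token_indexers import SingleIdTokenIndexer
    from combo.data.tokenizers import WhitespaceTokenizer


    @Registry.register(DatasetReader, 'plain_text')      # hypothetical name
    class PlainTextReader(DatasetReader):
        def text_to_instance(self, text: str) -> Instance:
            # Tokenize the raw text and wrap it in a single TextField.
            tokens = WhitespaceTokenizer().tokenize(text)
            field = TextField(tokens, {'tokens': SingleIdTokenIndexer()})
            return Instance({'tokens': field})
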
diff --git a/combo/data/dataset_readers/text_classification_json_reader.py b/combo/data/dataset_readers/text_classification_json_reader.py
index 95455de..781d310 100644
--- a/combo/data/dataset_readers/text_classification_json_reader.py
+++ b/combo/data/dataset_readers/text_classification_json_reader.py
@@ -14,6 +14,7 @@ from .. import Instance, Tokenizer, TokenIndexer
 from ..fields import Field, ListField
 from ..fields.label_field import LabelField
 from ..fields.text_field import TextField
+from ...config import Registry
 from ...utils import ConfigurationError
 
 
@@ -22,6 +23,7 @@ def _is_sentence_segmenter(sentence_segmenter: Optional[Tokenizer]) -> bool:
     return callable(split_sentences_method)
 
 
+@Registry.register(DatasetReader, 'text_classification_json')
 class TextClassificationJSONReader(DatasetReader):
     def __init__(self,
                  tokenizer: Optional[Tokenizer] = None,
@@ -110,6 +112,7 @@ class TextClassificationJSONReader(DatasetReader):
                         label = str(label)
                 yield self.text_to_instance(text, label)
 
+
     def text_to_instance(self,
                          text: str,
                          label: Optional[Union[str, int]] = None) -> Instance:
diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index 33bad6d..d06852c 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -12,6 +12,7 @@ import torch
 from overrides import overrides
 
 from combo import data
+from combo.config import Registry
 from combo.data import Vocabulary, fields, Instance, Token
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.fields import Field
@@ -24,6 +25,7 @@ from conllu import parser
 from combo.utils import checks, pad_sequence_to_length
 
 
+@Registry.register(DatasetReader, 'conllu')
 class UniversalDependenciesDatasetReader(DatasetReader, ABC):
     def __init__(
             self,
@@ -120,7 +122,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
 
     def text_to_instance(self, tree: conllu.TokenList) -> Instance:
         fields_: Dict[str, Field] = {}
-        tree_tokens = [t for t in tree if isinstance(t["id"], int)]
+        tree_tokens = [t for t in tree if isinstance(t["idx"], int)]
 
         tokens = [Token(text=t["token"],
                         upostag=t.get("upostag"),
diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py
index 60d30c6..b20bb6f 100644
--- a/combo/data/fields/sequence_label_field.py
+++ b/combo/data/fields/sequence_label_field.py
@@ -30,7 +30,7 @@ class SequenceLabelField(Field[torch.Tensor]):
     labels : `Union[List[str], List[int]]`
         A sequence of categorical labels, encoded as strings or integers.  These could be POS tags like [NN,
         JJ, ...], BIO tags like [B-PERS, I-PERS, O, O, ...], or any other categorical tag sequence. If the
-        labels are encoded as integers, they will not be indexed using a vocab.
+        labels are encoded as integers, they will not be indexed using a vocabulary.
     sequence_field : `SequenceField`
         A field containing the sequence that this `SequenceLabelField` is labeling.  Most often, this is a
         `TextField`, for tagging individual tokens in a sentence.
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
index 4938d43..55b579a 100644
--- a/combo/data/fields/sequence_multilabel_field.py
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -36,7 +36,7 @@ class SequenceMultiLabelField(Field[torch.Tensor]):
 
     multi_labels : `List[List[str]]`
     multi_label_indexer : `Callable[[data.Vocabulary], Callable[[List[str]], List[int]]]`
-        Nested callable which based on vocab and sequence length maps values of the fields in the sequence
+        Nested callable which based on vocabulary and sequence length maps values of the fields in the sequence
         from strings to indexed, int values.
     as_tensor: `Callable[["SequenceMultiLabelField"], Callable[[Dict[str, int]], torch.Tensor]]`
         Nested callable which based on the field itself, maps indexed data to a tensor.
diff --git a/combo/data/fields/text_field.py b/combo/data/fields/text_field.py
index c4de273..da42252 100644
--- a/combo/data/fields/text_field.py
+++ b/combo/data/fields/text_field.py
@@ -19,7 +19,6 @@ import torch
 # produced by a given TokenIndexer, which will be input to a particular TokenEmbedder's forward()
 # method.  We label these as tensors, because that's what they typically are, though they could in
 # reality have arbitrary type.
-from combo.data import Vocabulary
 from combo.data.fields.sequence_field import SequenceField
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
 from combo.data.tokenizers import Token
@@ -103,7 +102,7 @@ class TextField(SequenceField[TextFieldTensors]):
             for token in self.tokens:
                 indexer.count_vocab_items(token, counter)
 
-    def index(self, vocab: Vocabulary):
+    def index(self, vocab):
         self._indexed_tokens = {}
         for indexer_name, indexer in self.token_indexers.items():
             self._indexed_tokens[indexer_name] = indexer.tokens_to_indices(self.tokens, vocab)
diff --git a/combo/data/instance.py b/combo/data/instance.py
index 34771ae..1c0ab26 100644
--- a/combo/data/instance.py
+++ b/combo/data/instance.py
@@ -50,7 +50,7 @@ class Instance(Mapping[str, Field]):
         """
         Add the field to the existing fields mapping.
         If we have already indexed the Instance, then we also index `field`, so
-        it is necessary to supply the vocab.
+        it is necessary to supply the vocabulary.
         """
         self.fields[field_name] = field
         if self.indexed and vocab is not None:
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 75df1b3..d3edf27 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -4,3 +4,4 @@ from .single_id_token_indexer import SingleIdTokenIndexer
 from .pretrained_transformer_indexer import PretrainedTransformerIndexer
 from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
 from .pretrained_transformer_fixed_mismatched_indexer import PretrainedTransformerFixedMismatchedIndexer
+from .token_const_padding_characters_indexer import TokenConstPaddingCharactersIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
index af7bb0a..71e3667 100644
--- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -3,16 +3,19 @@ Adapted from COMBO
 Authors: Mateusz Klimaszewski, Lukasz Pszenny
 """
 
-from typing import Optional, Dict, Any, List, Tuple
+from typing import Optional, Dict, Any
 
 from overrides import overrides
 
+from combo.config import Registry
 from combo.data import Vocabulary
+from combo.data.token_indexers.token_indexer import TokenIndexer
 from combo.data.token_indexers import IndexedTokenList
 from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
 from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
 
 
+@Registry.register(TokenIndexer, 'pretrained_transformer_mismatched_fixed')
 class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatchedIndexer):
 
     def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py
index 6580f4b..d2a8a62 100644
--- a/combo/data/token_indexers/pretrained_transformer_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_indexer.py
@@ -6,7 +6,9 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/pretr
 from typing import Dict, List, Optional, Tuple, Any
 import logging
 import torch
+from overrides import overrides
 
+from combo.config import Registry
 from combo.data import Vocabulary
 from combo.data.tokenizers import Token
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
@@ -16,6 +18,7 @@ from combo.utils import pad_sequence_to_length
 logger = logging.getLogger(__name__)
 
 
+@Registry.register(TokenIndexer, 'pretrained_transformer')
 class PretrainedTransformerIndexer(TokenIndexer):
     """
     This `TokenIndexer` assumes that Tokens already have their indexes in them (see `text_id` field).
@@ -79,7 +82,7 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
     def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None:
         """
-        Copies tokens from ```transformers``` model's vocab to the specified _namespace.
+        Copies tokens from ```transformers``` model's vocabulary to the specified _namespace.
         """
         if self._added_to_vocabulary:
             return
@@ -88,10 +91,12 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
         self._added_to_vocabulary = True
 
+    @overrides
     def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
         # If we only use pretrained models, we don't need to do anything here.
         pass
 
+    @overrides
     def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
         self._add_encoding_to_vocabulary_if_needed(vocabulary)
 
@@ -105,6 +110,7 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
         return self._postprocess_output(output)
 
+    @overrides
     def indices_to_tokens(
         self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
     ) -> List[Token]:
@@ -203,12 +209,14 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
         return output
 
+    @overrides
     def get_empty_token_list(self) -> IndexedTokenList:
         output: IndexedTokenList = {"token_ids": [], "mask": [], "type_ids": []}
         if self._max_length is not None:
             output["segment_concat_mask"] = []
         return output
 
+    @overrides
     def as_padded_tensor_dict(
         self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
     ) -> Dict[str, torch.Tensor]:
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 6d660e7..e41407b 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -7,7 +7,9 @@ from typing import Dict, List, Any, Optional
 import logging
 
 import torch
+from overrides import overrides
 
+from combo.config import Registry, FromParameters
 from combo.data import Vocabulary
 from combo.data.tokenizers import Token
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
@@ -17,6 +19,7 @@ from combo.utils import pad_sequence_to_length
 logger = logging.getLogger(__name__)
 
 
+@Registry.register(TokenIndexer, 'pretrained_transformer_mismatched')
 class PretrainedTransformerMismatchedIndexer(TokenIndexer):
     """
     Use this indexer when (for whatever reason) you are not using a corresponding
@@ -67,9 +70,11 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
         self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
         self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
 
+    @overrides
     def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
         return self._matched_indexer.count_vocab_items(token, counter)
 
+    @overrides
     def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
         self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
 
@@ -91,12 +96,14 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
 
         return self._matched_indexer._postprocess_output(output)
 
+    @overrides
     def get_empty_token_list(self) -> IndexedTokenList:
         output = self._matched_indexer.get_empty_token_list()
         output["offsets"] = []
         output["wordpiece_mask"] = []
         return output
 
+    @overrides
     def as_padded_tensor_dict(
         self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
     ) -> Dict[str, torch.Tensor]:
diff --git a/combo/data/token_indexers/single_id_token_indexer.py b/combo/data/token_indexers/single_id_token_indexer.py
index 143c786..4f49bd1 100644
--- a/combo/data/token_indexers/single_id_token_indexer.py
+++ b/combo/data/token_indexers/single_id_token_indexer.py
@@ -6,6 +6,9 @@ https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c8
 from typing import Dict, List, Optional, Any
 import itertools
 
+from overrides import overrides
+
+from combo.config import FromParameters, Registry
 from combo.data import Vocabulary
 from combo.data.tokenizers import Token
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
@@ -13,6 +16,7 @@ from combo.data.token_indexers import TokenIndexer, IndexedTokenList
 _DEFAULT_VALUE = "THIS IS A REALLY UNLIKELY VALUE THAT HAS TO BE A STRING"
 
 
+@Registry.register(TokenIndexer, 'single_id')
 class SingleIdTokenIndexer(TokenIndexer):
     """
     This :class:`TokenIndexer` represents tokens as single integers.
@@ -65,6 +69,7 @@ class SingleIdTokenIndexer(TokenIndexer):
         self._feature_name = feature_name
         self._default_value = default_value
 
+    @overrides
     def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
         if self._namespace is not None:
             text = self._get_feature_value(token)
@@ -72,9 +77,10 @@ class SingleIdTokenIndexer(TokenIndexer):
                 text = text.lower()
             counter[self._namespace][text] += 1
 
+    @overrides
     def tokens_to_indices(
             self, tokens: List[Token], vocabulary: Vocabulary
-    ) -> Dict[str, List[int]]:
+    ) -> IndexedTokenList:
         indices: List[int] = []
 
         for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
@@ -89,6 +95,7 @@ class SingleIdTokenIndexer(TokenIndexer):
 
         return {"tokens": indices}
 
+    @overrides
     def get_empty_token_list(self) -> IndexedTokenList:
         return {"tokens": []}
 
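
Since `TokenIndexer` now mixes in `FromParameters` (see `token_indexer.py` below), registered indexers such as this one can be built from plain dicts; a sketch, assuming the AllenNLP-style `namespace`/`lowercase_tokens` arguments are kept:

    from combo.config import Registry
    from combo.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

    assert Registry.resolve(TokenIndexer, 'single_id') is SingleIdTokenIndexer

    # Keys that do not match __init__ argument names are silently ignored.
    indexer = SingleIdTokenIndexer.from_parameters({'namespace': 'tokens', 'lowercase_tokens': True})
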
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index 1227f2e..56f84f6 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -10,12 +10,14 @@ import warnings
 from overrides import overrides
 import torch
 
+from combo.config import Registry
 from combo.data import Vocabulary
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
 from combo.data.tokenizers import Token, CharacterTokenizer
 from combo.utils import ConfigurationError, pad_sequence_to_length
 
 
+@Registry.register(TokenIndexer, 'token_characters')
 class TokenCharactersIndexer(TokenIndexer):
     """
     This :class:`TokenIndexer` represents tokens as lists of character indices.
@@ -27,7 +29,7 @@ class TokenCharactersIndexer(TokenIndexer):
     _namespace : `str`, optional (default=`token_characters`)
         We will use this _namespace in the :class:`Vocabulary` to map the characters in each token
         to indices.
-    character_tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`)
+    tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`)
         We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has
         options for byte encoding and other things.  The default here is to instantiate a
         `CharacterTokenizer` with its default parameters, which uses unicode characters and
@@ -46,7 +48,7 @@ class TokenCharactersIndexer(TokenIndexer):
     def __init__(
         self,
         namespace: str = "token_characters",
-        character_tokenizer: CharacterTokenizer = CharacterTokenizer(),
+        tokenizer: CharacterTokenizer = CharacterTokenizer(),
         start_tokens: List[str] = None,
         end_tokens: List[str] = None,
         min_padding_length: int = 0,
@@ -64,7 +66,7 @@ class TokenCharactersIndexer(TokenIndexer):
             )
         self._min_padding_length = min_padding_length
         self._namespace = namespace
-        self._character_tokenizer = character_tokenizer
+        self._character_tokenizer = tokenizer
 
         self._start_tokens = [Token(st) for st in (start_tokens or [])]
         self._end_tokens = [Token(et) for et in (end_tokens or [])]
@@ -75,14 +77,14 @@ class TokenCharactersIndexer(TokenIndexer):
             raise ConfigurationError("TokenCharactersIndexer needs a tokenizer that retains text")
         for character in self._character_tokenizer.tokenize(token.text):
             # If `text_id` is set on the character token (e.g., if we're using byte encoding), we
-            # will not be using the vocab for this character.
+            # will not be using the vocabulary for this character.
             if getattr(character, "text_id", None) is None:
                 counter[self._namespace][character.text] += 1
 
     @overrides
     def tokens_to_indices(
         self, tokens: List[Token], vocabulary: Vocabulary
-    ) -> Dict[str, List[List[int]]]:
+    ) -> IndexedTokenList:
         indices: List[List[int]] = []
         for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
             token_indices: List[int] = []
@@ -92,7 +94,7 @@ class TokenCharactersIndexer(TokenIndexer):
                 )
             for character in self._character_tokenizer.tokenize(token.text):
                 if getattr(character, "text_id", None) is not None:
-                    # `text_id` being set on the token means that we aren't using the vocab, we just
+                    # `text_id` being set on the token means that we aren't using the vocabulary, we just
                     # use this idx instead.
                     index = character.text_id
                 else:
@@ -147,3 +149,14 @@ class TokenCharactersIndexer(TokenIndexer):
     @overrides
     def get_empty_token_list(self) -> IndexedTokenList:
         return {"token_characters": []}
+
+    @overrides
+    def _to_params(self) -> Dict[str, str]:
+        return {
+            'token_min_padding_length': str(self._token_min_padding_length),
+            'min_padding_length': str(self._min_padding_length),
+            'namespace': str(self._namespace),
+            'tokenizer': str(self._character_tokenizer),
+            'start_tokens': str(self._start_tokens),
+            'end_tokens': str(self._end_tokens)
+        }
diff --git a/combo/data/token_indexers/token_const_padding_characters_indexer.py b/combo/data/token_indexers/token_const_padding_characters_indexer.py
index 6b7440d..d193258 100644
--- a/combo/data/token_indexers/token_const_padding_characters_indexer.py
+++ b/combo/data/token_indexers/token_const_padding_characters_indexer.py
@@ -7,7 +7,10 @@ import itertools
 from typing import List, Dict
 
 import torch
+
+from combo.config import Registry
 from combo.data.token_indexers import IndexedTokenList
+from combo.data.token_indexers.token_indexer import TokenIndexer
 from overrides import overrides
 
 from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer
@@ -15,17 +18,18 @@ from combo.data.tokenizers import CharacterTokenizer
 from combo.utils import pad_sequence_to_length
 
 
+@Registry.register(TokenIndexer, 'characters_const_padding')
 class TokenConstPaddingCharactersIndexer(TokenCharactersIndexer):
     """Wrapper around allennlp token indexer with const padding."""
 
     def __init__(self,
                  namespace: str = "token_characters",
-                 character_tokenizer: CharacterTokenizer = CharacterTokenizer(),
+                 tokenizer: CharacterTokenizer = CharacterTokenizer(),
                  start_tokens: List[str] = None,
                  end_tokens: List[str] = None,
                  min_padding_length: int = 0,
                  token_min_padding_length: int = 0):
-        super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length,
+        super().__init__(namespace, tokenizer, start_tokens, end_tokens, min_padding_length,
                          token_min_padding_length)
 
     @overrides
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index c0c1f21..cd88be6 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -3,19 +3,19 @@ Adapted from COMBO.
 Author: Mateusz Klimaszewski
 """
 import collections
-from abc import ABC
 from typing import List, Dict
 
 import torch
 from overrides import overrides
 
-from combo.data import Vocabulary
+from combo.config import Registry
 from combo.data.tokenizers.tokenizer import Token
 from combo.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
 from combo.utils import pad_sequence_to_length
 
 
-class TokenFeatsIndexer(TokenIndexer, ABC):
+@Registry.register(TokenIndexer, 'feats_indexer')
+class TokenFeatsIndexer(TokenIndexer):
 
     def __init__(
             self,
@@ -34,7 +34,7 @@ class TokenFeatsIndexer(TokenIndexer, ABC):
             counter[self.namespace][feat] += 1
 
     @overrides
-    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+    def tokens_to_indices(self, tokens: List[Token], vocabulary) -> IndexedTokenList:
         indices: List[List[int]] = []
         vocab_size = vocabulary.get_vocab_size(self.namespace)
         for token in tokens:
@@ -79,3 +79,11 @@ class TokenFeatsIndexer(TokenIndexer, ABC):
                                   )
             tensor_dict[key] = tensor
         return tensor_dict
+
+    @overrides
+    def _to_params(self) -> Dict[str, str]:
+        return {
+            'token_min_padding_length': str(self._token_min_padding_length),
+            'namespace': str(self.namespace),
+            'feature_name': str(self._feature_name)
+        }
diff --git a/combo/data/token_indexers/token_indexer.py b/combo/data/token_indexers/token_indexer.py
index 31835da..2c701ff 100644
--- a/combo/data/token_indexers/token_indexer.py
+++ b/combo/data/token_indexers/token_indexer.py
@@ -6,7 +6,9 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/token
 from typing import Any, Dict, List
 
 import torch
+from overrides import overrides
 
+from combo.config import FromParameters
 from combo.data.tokenizers.tokenizer import Token
 from combo.data.vocabulary import Vocabulary
 from combo.utils import pad_sequence_to_length
@@ -19,7 +21,7 @@ from combo.utils import pad_sequence_to_length
 IndexedTokenList = Dict[str, List[Any]]
 
 
-class TokenIndexer:
+class TokenIndexer(FromParameters):
     """
     A `TokenIndexer` determines how string tokens get represented as arrays of indices in a model.
     This class both converts strings into numerical values, with the help of a
@@ -123,3 +125,9 @@ class TokenIndexer:
         if isinstance(self, other.__class__):
             return self.__dict__ == other.__dict__
         return NotImplemented
+
+    @overrides
+    def _to_params(self) -> Dict[str, str]:
+        return {
+            'token_min_padding_length': str(self._token_min_padding_length)
+        }
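
A sketch of the parameter round trip this base `_to_params` enables; note that values are serialized as strings, so numeric arguments come back as `str` unless converted by the caller:

    from combo.data.token_indexers import SingleIdTokenIndexer

    indexer = SingleIdTokenIndexer()
    params = indexer._to_params()                 # e.g. {'token_min_padding_length': '0'}
    rebuilt = SingleIdTokenIndexer.from_parameters(params)
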
diff --git a/combo/data/tokenizers/character_tokenizer.py b/combo/data/tokenizers/character_tokenizer.py
index 1c873da..51dc557 100644
--- a/combo/data/tokenizers/character_tokenizer.py
+++ b/combo/data/tokenizers/character_tokenizer.py
@@ -5,10 +5,12 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/character
 
 from typing import List, Union, Dict, Any
 
+from combo.config import Registry
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
 
 
+@Registry.register(Tokenizer, 'character')
 class CharacterTokenizer(Tokenizer):
     """
     A `CharacterTokenizer` splits strings into character tokens.
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index 43a5ac6..b3cec7c 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -1,10 +1,12 @@
 from typing import List, Dict, Any
 
+from combo.config import Registry
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
 from lambo.segmenter.lambo import Lambo
 
 
+@Registry.register(Tokenizer, 'lambo')
 class LamboTokenizer(Tokenizer):
 
     def __init__(
diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
index 6672460..2db074a 100644
--- a/combo/data/tokenizers/pretrained_transformer_tokenizer.py
+++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple, Iterable
 
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
+from combo.config import Registry
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
 from combo.utils import sanitize_wordpiece
@@ -17,6 +18,7 @@ from combo.utils import sanitize_wordpiece
 logger = logging.getLogger(__name__)
 
 
+@Registry.register(Tokenizer, 'pretrained_transformer')
 class PretrainedTransformerTokenizer(Tokenizer):
     """
     A `PretrainedTransformerTokenizer` uses a model from HuggingFace's
diff --git a/combo/data/tokenizers/spacy_tokenizer.py b/combo/data/tokenizers/spacy_tokenizer.py
index 927da8a..6844aa5 100644
--- a/combo/data/tokenizers/spacy_tokenizer.py
+++ b/combo/data/tokenizers/spacy_tokenizer.py
@@ -8,11 +8,13 @@ from typing import List, Optional
 import spacy
 from spacy.tokens import Doc
 
+from combo.config import Registry
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
 from combo.utils.spacy import get_spacy_model
 
 
+@Registry.register(Tokenizer, 'spacy')
 class SpacyTokenizer(Tokenizer):
     """
     A `Tokenizer` that uses spaCy's tokenizer.  It's fast and reasonable - this is the
@@ -155,7 +157,7 @@ class _WhitespaceSpacyTokenizer:
     as follows:
     nlp = spacy.load("en_core_web_md")
     # hack to replace tokenizer with a whitespace tokenizer
     nlp.tokenizer = _WhitespaceSpacyTokenizer(nlp.vocab)
     ... use nlp("here is some text") as normal.
     """
 
diff --git a/combo/data/tokenizers/tokenizer.py b/combo/data/tokenizers/tokenizer.py
index b164501..5c4d66b 100644
--- a/combo/data/tokenizers/tokenizer.py
+++ b/combo/data/tokenizers/tokenizer.py
@@ -7,10 +7,12 @@ import logging
 from .token import Token
 from typing import List, Optional
 
+from ...config import FromParameters
+
 logger = logging.getLogger(__name__)
 
 
-class Tokenizer:
+class Tokenizer(FromParameters):
     """
     A `Tokenizer` splits strings of text into tokens.  Typically, this either splits text into
     word tokens or character tokens, and those are the two tokenizer subclasses we have implemented
diff --git a/combo/data/tokenizers/whitespace_tokenizer.py b/combo/data/tokenizers/whitespace_tokenizer.py
index a4802de..fe3413f 100644
--- a/combo/data/tokenizers/whitespace_tokenizer.py
+++ b/combo/data/tokenizers/whitespace_tokenizer.py
@@ -1,9 +1,11 @@
 from typing import List, Dict, Any
 
+from combo.config import Registry
 from combo.data.tokenizers.token import Token
 from combo.data.tokenizers.tokenizer import Tokenizer
 
 
+@Registry.register(Tokenizer, 'whitespace')
 class WhitespaceTokenizer(Tokenizer):
     """
     A `Tokenizer` that assumes you've already done your own tokenization somehow and have
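
All tokenizers above now share the `Registry.register(Tokenizer, ...)` pattern, so they can be looked up by the names 'character', 'lambo', 'pretrained_transformer', 'spacy' and 'whitespace'. A hedged usage sketch; the zero-argument construction and the `tokenize`/`Token.text` calls are assumptions about the `Tokenizer` interface, not something this patch changes:

```python
# Sketch: look up a tokenizer class by its registered name and use it.
from combo.config import Registry
from combo.data.tokenizers.tokenizer import Tokenizer

tokenizer_cls = Registry.resolve(Tokenizer, 'whitespace')   # -> WhitespaceTokenizer
tokenizer = tokenizer_cls()                                 # assumed to need no required arguments
tokens = tokenizer.tokenize("already tokenized text")       # assumed Tokenizer API
print([token.text for token in tokens])                     # assumes Token exposes .text
```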
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 8f455a9..aea54bf 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -3,13 +3,15 @@ import os
 import re
 import glob
 from collections import defaultdict
-from typing import Dict, Optional, Iterable, Set, List
+from typing import Dict, Optional, Iterable, Set, List, Union
 
 import logging
 
 from filelock import FileLock
 from transformers import PreTrainedTokenizer
 
+from combo.common import Tqdm
+from combo.config import FromParameters, Registry
 from combo.utils import ConfigurationError
 from combo.utils.file_utils import cached_path
 
@@ -34,6 +36,24 @@ def match_namespace(pattern: str, namespace: str):
     return False
 
 
+def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]:
+    # Moving this import to the top breaks everything (circular import).
+    from combo.data.token_embedders.embedding import EmbeddingsTextFile
+
+    logger.info("Reading pretrained tokens from: %s", embeddings_file_uri)
+    tokens: List[str] = []
+    with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file:
+        for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1):
+            token_end = line.find(" ")
+            if token_end >= 0:
+                token = line[:token_end]
+                tokens.append(token)
+            else:
+                line_begin = line[:20] + "..." if len(line) > 20 else line
+                logger.warning("Skipping line number %d: %s", line_number, line_begin)
+    return tokens
+
+
 class NamespaceVocabulary:
     def __init__(self,
                  padding_token: Optional[str] = None,
@@ -93,9 +113,16 @@ class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]):
         return value
 
 
-class Vocabulary:
+class Vocabulary(FromParameters):
     def __init__(self,
+                 counter: Dict[str, Dict[str, int]] = None,
+                 min_count: Dict[str, int] = None,
+                 max_vocab_size: Union[int, Dict[str, int]] = None,
                  non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+                 pretrained_files: Optional[Dict[str, str]] = None,
+                 only_include_pretrained_words: bool = False,
+                 tokens_to_add: Dict[str, List[str]] = None,
+                 min_pretrained_embeddings: Dict[str, int] = None,
                  padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
                  oov_token: Optional[str] = DEFAULT_OOV_TOKEN):
 
@@ -112,12 +139,309 @@ class Vocabulary:
         self._vocab = _NamespaceDependentDefaultDict(self._non_padded_namespaces,
                                                      self._padding_token,
                                                      self._oov_token)
+        self._retained_counter: Optional[Dict[str, Dict[str, int]]] = None
 
     def _extend(self,
-                tokens_to_add: Dict[str, Dict[str, int]]):
+                counter: Dict[str, Dict[str, int]] = None,
+                min_count: Dict[str, int] = None,
+                max_vocab_size: Union[int, Dict[str, int]] = None,
+                non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+                pretrained_files: Optional[Dict[str, str]] = None,
+                only_include_pretrained_words: bool = False,
+                tokens_to_add: Dict[str, Dict[str, int]] = None,
+                min_pretrained_embeddings: Dict[str, int] = None):
+        if min_count is not None:
+            for key in min_count:
+                if counter is None or key not in counter:
+                    raise ConfigurationError(
+                        f"The key '{key}' is present in min_count but not in counter"
+                    )
+        if not isinstance(max_vocab_size, dict):
+            int_max_vocab_size = max_vocab_size
+            max_vocab_size = defaultdict(lambda: int_max_vocab_size)  # type: ignore
+
+        min_count = min_count or {}
+        pretrained_files = pretrained_files or {}
+        min_pretrained_embeddings = min_pretrained_embeddings or {}
+        non_padded_namespaces = set(non_padded_namespaces)
+        counter = counter or {}
+        tokens_to_add = tokens_to_add or {}
+        self._retained_counter = counter
+        # Make sure vocabulary extension is safe.
+        current_namespaces = {*self._vocab}
+        extension_namespaces = {*counter, *tokens_to_add}
+
+        for namespace in current_namespaces & extension_namespaces:
+            # If a namespace is present in both the current and the extension
+            # vocabulary, the padding settings must agree: either both are padded
+            # or neither is.
+            original_padded = not any(
+                match_namespace(pattern, namespace) for pattern in self._non_padded_namespaces
+            )
+            extension_padded = not any(
+                match_namespace(pattern, namespace) for pattern in non_padded_namespaces
+            )
+            if original_padded != extension_padded:
+                raise ConfigurationError(
+                    "Common namespace {} has conflicting ".format(namespace)
+                    + "setting of padded = True/False. "
+                    + "Hence extension cannot be done."
+                )
+
+        self._vocab.add_non_padded_namespaces(non_padded_namespaces)
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
+        for namespace in counter:
+            pretrained_set: Optional[Set] = None
+            if namespace in pretrained_files:
+                pretrained_list = _read_pretrained_tokens(pretrained_files[namespace])
+                min_embeddings = min_pretrained_embeddings.get(namespace, 0)
+                if min_embeddings > 0 or min_embeddings == -1:
+                    tokens_old = tokens_to_add.get(namespace, [])
+                    tokens_new = (
+                        pretrained_list
+                        if min_embeddings == -1
+                        else pretrained_list[:min_embeddings]
+                    )
+                    tokens_to_add[namespace] = tokens_old + tokens_new
+                pretrained_set = set(pretrained_list)
+            token_counts = list(counter[namespace].items())
+            token_counts.sort(key=lambda x: x[1], reverse=True)
+            max_vocab: Optional[int]
+            try:
+                max_vocab = max_vocab_size[namespace]
+            except KeyError:
+                max_vocab = None
+            if max_vocab:
+                token_counts = token_counts[:max_vocab]
+            for token, count in token_counts:
+                if pretrained_set is not None:
+                    if only_include_pretrained_words:
+                        if token in pretrained_set and count >= min_count.get(namespace, 1):
+                            self.add_token_to_namespace(token, namespace)
+                    elif token in pretrained_set or count >= min_count.get(namespace, 1):
+                        self.add_token_to_namespace(token, namespace)
+                elif count >= min_count.get(namespace, 1):
+                    self.add_token_to_namespace(token, namespace)
+
         for namespace, tokens in tokens_to_add.items():
             self._vocab[namespace].append_tokens(tokens)
 
+    @classmethod
+    def from_files(cls,
+                   directory: Union[str, os.PathLike],
+                   padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+                   oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+                   ) -> "Vocabulary":
+        """
+                Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
+                a model archive file.
+
+                # Parameters
+
+                directory : `str`
+                    The directory or archive file containing the serialized vocabulary.
+                """
+        logger.info("Loading token dictionary from %s.", directory)
+        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
+        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
+
+        if not os.path.isdir(directory):
+            base_directory = cached_path(directory, extract_archive=True)
+            # For convenience we'll check for a 'vocabulary' subdirectory of the archive.
+            # That way you can use model archives directly.
+            vocab_subdir = os.path.join(base_directory, "vocabulary")
+            if os.path.isdir(vocab_subdir):
+                directory = vocab_subdir
+            elif os.path.isdir(base_directory):
+                directory = base_directory
+            else:
+                raise ConfigurationError(f"{directory} is neither a directory nor an archive")
+
+        files = glob.glob(os.path.join(directory, '*.txt'))
+
+        if len(files) == 0:
+            logger.warning("Directory %s is empty", directory)
+
+        with FileLock(os.path.join(directory, ".lock")):
+            with codecs.open(
+                    os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
+            ) as namespace_file:
+                non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
+
+            vocab = cls(
+                non_padded_namespaces=non_padded_namespaces,
+                padding_token=padding_token,
+                oov_token=oov_token,
+            )
+
+            for namespace_filename in os.listdir(directory):
+                if namespace_filename == NAMESPACE_PADDING_FILE:
+                    continue
+                if namespace_filename.startswith("."):
+                    continue
+                namespace = namespace_filename.replace(".txt", "")
+                if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces):
+                    is_padded = False
+                else:
+                    is_padded = True
+                filename = os.path.join(directory, namespace_filename)
+                vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
+
+        return vocab
+
+    @classmethod
+    def from_instances(
+            cls,
+            instances: Iterable["Instance"],
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            tokens_to_add: Dict[str, List[str]] = None,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+    ) -> "Vocabulary":
+        """
+                Constructs a vocabulary given a collection of `Instances` and some parameters.
+                We count all of the vocabulary items in the instances, then pass those counts
+                and the other parameters, to :func:`__init__`.  See that method for a description
+                of what the other parameters do.
+
+                The `instances` parameter does not get an entry in a typical AllenNLP configuration file,
+                but the other parameters do (if you want non-default parameters).
+                """
+        logger.info("Fitting token dictionary from dataset.")
+        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
+        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
+        namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        for instance in Tqdm.tqdm(instances, desc="building vocabulary"):
+            instance.count_vocab_items(namespace_token_counts)
+
+        return cls(
+            counter=namespace_token_counts,
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token,
+        )
+
+    @classmethod
+    def from_files_and_instances(
+            cls,
+            instances: Iterable["Instance"],
+            directory: str,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            tokens_to_add: Dict[str, List[str]] = None,
+            min_pretrained_embeddings: Dict[str, int] = None,
+    ) -> "Vocabulary":
+        """
+        Extends an already generated vocabulary using a collection of instances.
+
+        The `instances` parameter does not get an entry in a typical AllenNLP configuration file,
+        but the other parameters do (if you want non-default parameters).  See `__init__` for a
+        description of what the other parameters mean.
+        """
+        vocab = cls.from_files(directory, padding_token, oov_token)
+        logger.info("Fitting token dictionary from dataset.")
+        namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        for instance in Tqdm.tqdm(instances):
+            instance.count_vocab_items(namespace_token_counts)
+        vocab._extend(
+            counter=namespace_token_counts,
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+        )
+        return vocab
+
+    @classmethod
+    def from_pretrained_transformer_and_instances(
+            cls,
+            instances: Iterable["Instance"],
+            transformers: Dict[str, str],
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            tokens_to_add: Dict[str, List[str]] = None,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+    ) -> "Vocabulary":
+        """
+        Constructs a vocabulary given a collection of `Instance`s and some parameters, then extends
+        it with vocabularies generated from pretrained transformers.
+
+        The vocabulary from instances is constructed by passing parameters to :func:`from_instances`,
+        and then updated by merging in the vocabularies from
+        :func:`from_pretrained_transformer`. See the other methods for full descriptions of what the
+        other parameters do.
+
+        The `instances` parameter does not get an entry in a typical AllenNLP configuration file,
+        but the other parameters do (if you want non-default parameters).
+
+        # Parameters
+
+        transformers : `Dict[str, str]`
+            Dictionary mapping the vocabulary namespaces (keys) to a transformer model name (value).
+            Namespaces not included will be ignored.
+
+        # Examples
+
+        You can use this constructor by modifying the following example within your training
+        configuration.
+
+        ```jsonnet
+        {
+            vocabulary: {
+                type: 'from_pretrained_transformer_and_instances',
+                transformers: {
+                    'namespace1': 'bert-base-cased',
+                    'namespace2': 'roberta-base',
+                },
+            }
+        }
+        ```
+        """
+        vocab = cls.from_instances(
+            instances=instances,
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token,
+        )
+
+        for namespace, model_name in transformers.items():
+            transformer_vocab = cls.from_pretrained_transformer(
+                model_name=model_name, namespace=namespace
+            )
+            vocab.extend_from_vocab(transformer_vocab)
+
+        return vocab
+
     def save_to_files(self, directory: str) -> None:
         """
         Persist this Vocabulary to files, so it can be reloaded later.
@@ -133,7 +457,7 @@ class Vocabulary:
             logger.warning("Directory %s is not empty", directory)
 
         # We use a lock file to avoid race conditions where multiple processes
-        # might be reading/writing from/to the same vocab files at once.
+        # might be reading/writing from/to the same vocabulary files at once.
         with FileLock(os.path.join(directory, ".lock")):
             with codecs.open(
                     os.path.join(directory, NAMESPACE_PADDING_FILE), "w", "utf-8"
@@ -190,7 +514,7 @@ class Vocabulary:
             self, tokenizer: PreTrainedTokenizer, namespace: str = "tokens"
     ) -> None:
         """
-        Copies tokens from a transformer tokenizer's vocab into the given namespace.
+        Copies tokens from a transformer tokenizer's vocabulary into the given namespace.
         """
         try:
             vocab_items = tokenizer.get_vocab().items()
@@ -269,87 +593,75 @@ class Vocabulary:
 
                 self._vocab[namespace].insert_token(token, index)
         if is_padded:
-            assert self._oov_token in self._vocab[namespace].get_itos(), "OOV token not found!"
-
-
-class PretrainedTransformerVocabulary(Vocabulary):
-    def __init__(self,
-                 model_name: str,
-                 namespace: str = "tokens",
-                 oov_token: Optional[str] = None):
-        """
-        Initialize a vocabulary from the vocabulary of a pretrained transformer model.
-        If `oov_token` is not given, we will try to infer it from the transformer tokenizer.
-        """
-        from combo.common import cached_transformers
-
-        tokenizer = cached_transformers.get_tokenizer(model_name)
-        if oov_token is None:
-            if hasattr(tokenizer, "_unk_token"):
-                oov_token = tokenizer._unk_token
-            elif hasattr(tokenizer, "special_tokens_map"):
-                oov_token = tokenizer.special_tokens_map.get("unk_token")
-
-        super().__init__(non_padded_namespaces=[namespace],
-                         oov_token=oov_token)
-        self.add_transformer_vocab(tokenizer, namespace)
-
-
-class FromFilesVocabulary(Vocabulary):
-    def __init__(self,
-                 directory: str,
-                 padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
-                 oov_token: Optional[str] = DEFAULT_OOV_TOKEN) -> None:
+            assert self._oov_token in self._vocab[namespace].get_itos().values(), "OOV token not found!"
+
+
+def get_slices_if_not_provided(vocab: Vocabulary):
+    if hasattr(vocab, "slices"):
+        return vocab.slices
+
+    if "feats_labels" in vocab.get_namespaces():
+        idx2token = vocab.get_index_to_token_vocabulary("feats_labels")
+        for _, v in dict(idx2token).items():
+            if v not in ["_", "__PAD__"]:
+                empty_value = v.split("=")[0] + "=None"
+                vocab.add_token_to_namespace(empty_value, "feats_labels")
+
+        slices = {}
+        for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items():
+            # There are two types of features: with an assignment (e.g. Case=Acc)
+            # or without one (None). Here we group their indices by name
+            # (the part before the assignment sign).
+            name = name.split("=")[0]
+            if name in slices:
+                slices[name].append(idx)
+            else:
+                slices[name] = [idx]
+        vocab.slices = slices
+        return vocab.slices
+
+
+@Registry.register(Vocabulary, "from_instances_extended")
+class FromInstancesVocabulary(Vocabulary, FromParameters):
+    @classmethod
+    def from_instances_extended(
+            cls,
+            instances: Iterable["Instance"],
+            min_count: Dict[str, int] = None,
+            max_vocab_size: Union[int, Dict[str, int]] = None,
+            non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
+            pretrained_files: Optional[Dict[str, str]] = None,
+            only_include_pretrained_words: bool = False,
+            min_pretrained_embeddings: Dict[str, int] = None,
+            padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
+            oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
+    ) -> "Vocabulary":
         """
-        Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
-
-        :param directory:
-        :param padding_token:
-        :param oov_token:
-        :return:
+        Extension of `from_instances` that manually fills in missing 'feats_labels' entries.
         """
-        logger.info("Loading token dictionary from %s.", directory)
-        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
-        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
-
-        if not os.path.isdir(directory):
-            base_directory = cached_path(directory, extract_archive=True)
-            # For convenience we'll check for a 'vocabulary' subdirectory of the archive.
-            # That way you can use model archives directly.
-            vocab_subdir = os.path.join(base_directory, "vocabulary")
-            if os.path.isdir(vocab_subdir):
-                directory = vocab_subdir
-            elif os.path.isdir(base_directory):
-                directory = base_directory
-            else:
-                raise ConfigurationError(f"{directory} is neither a directory nor an archive")
-
-        files = [file for file in glob.glob(os.path.join(directory, '*.txt'))]
-
-        if len(files) == 0:
-            logger.warning(f'Directory %s is empty' % directory)
-
-        with FileLock(os.path.join(directory, ".lock")):
-            with codecs.open(
-                    os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8"
-            ) as namespace_file:
-                non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
-
-            super().__init__(
-                non_padded_namespaces=non_padded_namespaces,
-                padding_token=padding_token,
-                oov_token=oov_token,
-            )
-
-            for namespace_filename in os.listdir(directory):
-                if namespace_filename == NAMESPACE_PADDING_FILE:
-                    continue
-                if namespace_filename.startswith("."):
-                    continue
-                namespace = namespace_filename.replace(".txt", "")
-                if any(match_namespace(pattern, namespace) for pattern in non_padded_namespaces):
-                    is_padded = False
-                else:
-                    is_padded = True
-                filename = os.path.join(directory, namespace_filename)
-                self.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
+        # Manually load tokens from the pretrained embeddings file, using a different
+        # strategy: add every word from the embeddings file, without checking whether
+        # it was seen in any dataset.
+        tokens_to_add = None
+        if pretrained_files and "tokens" in pretrained_files:
+            pretrained_set = set(_read_pretrained_tokens(pretrained_files["tokens"]))
+            tokens_to_add = {"tokens": list(pretrained_set)}
+            pretrained_files = None
+
+        vocab = super().from_instances(
+            instances=instances,
+            min_count=min_count,
+            max_vocab_size=max_vocab_size,
+            non_padded_namespaces=non_padded_namespaces,
+            pretrained_files=pretrained_files,
+            only_include_pretrained_words=only_include_pretrained_words,
+            tokens_to_add=tokens_to_add,
+            min_pretrained_embeddings=min_pretrained_embeddings,
+            padding_token=padding_token,
+            oov_token=oov_token
+        )
+        # Extend the vocabulary with features that do not show up explicitly.
+        # To know all the features, the full dataset has to be read first.
+        # An auxiliary '=None' feature is added for each category, which is
+        # needed to perform classification.
+        get_slices_if_not_provided(vocab)
+        return vocab
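
The `Vocabulary` above now exposes several constructors (`from_instances`, `from_files`, `from_files_and_instances`, `from_pretrained_transformer_and_instances`) plus the registered `from_instances_extended` variant. A minimal sketch of the most common round trip, where `train_instances` is a placeholder for any iterable of `Instance` objects:

```python
# Sketch: build a vocabulary from instances, persist it, and reload it.
from combo.data.vocabulary import Vocabulary, FromInstancesVocabulary


def build_vocabulary(train_instances):
    vocab = Vocabulary.from_instances(train_instances, min_count={'tokens': 3})
    vocab.save_to_files('/tmp/vocabulary')
    reloaded = Vocabulary.from_files('/tmp/vocabulary')

    # The registered "from_instances_extended" variant additionally adds the
    # auxiliary '=None' entries to 'feats_labels' via get_slices_if_not_provided.
    extended = FromInstancesVocabulary.from_instances_extended(train_instances)
    return reloaded, extended
```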
diff --git a/combo/example.ipynb b/combo/example.ipynb
index 9597084..02b203d 100644
--- a/combo/example.ipynb
+++ b/combo/example.ipynb
@@ -7,8 +7,8 @@
    "metadata": {
     "collapsed": true,
     "ExecuteTime": {
-     "end_time": "2023-09-04T14:36:56.659578Z",
-     "start_time": "2023-09-04T14:36:45.221897Z"
+     "end_time": "2023-09-22T07:28:27.168062Z",
+     "start_time": "2023-09-22T07:28:23.442292Z"
     }
    },
    "outputs": [],
@@ -16,15 +16,88 @@
     "from combo.predict import COMBO"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "error loading _jsonnet (this is expected on Windows), treating /var/folders/n6/5g0cjy9s6j9cgp2xbw_6c2v40000gn/T/tmpzyupif28/config.json as plain json\n"
+     ]
+    },
+    {
+     "ename": "TypeError",
+     "evalue": "__init__() missing 2 required positional arguments: 'model' and 'dataset_reader'",
+     "output_type": "error",
+     "traceback": [
+      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+      "\u001B[0;31mKeyError\u001B[0m                                  Traceback (most recent call last)",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:48\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m     47\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 48\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__classes\u001B[49m\u001B[43m[\u001B[49m\u001B[43mcategory\u001B[49m\u001B[43m]\u001B[49m\u001B[43m[\u001B[49m\u001B[43mclass_name\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlower\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m]\u001B[49m\n\u001B[1;32m     49\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n",
+      "\u001B[0;31mKeyError\u001B[0m: 'combo_dependency_parsing_from_vocab'",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "\u001B[0;31mKeyError\u001B[0m                                  Traceback (most recent call last)",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:51\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m     50\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 51\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__defaults\u001B[49m\u001B[43m[\u001B[49m\u001B[43mcategory\u001B[49m\u001B[43m]\u001B[49m\n\u001B[1;32m     52\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n",
+      "\u001B[0;31mKeyError\u001B[0m: <class 'combo.modules.parser.DependencyRelationModel'>",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "\u001B[0;31mRegistryException\u001B[0m                         Traceback (most recent call last)",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:41\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m     40\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m \u001B[38;5;129;01min\u001B[39;00m values\u001B[38;5;241m.\u001B[39mkeys():\n\u001B[0;32m---> 41\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mresolve\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m     42\u001B[0m \u001B[43m        \u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m     43\u001B[0m \u001B[43m        \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtype\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[1;32m     44\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n\u001B[1;32m     45\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/registry.py:53\u001B[0m, in \u001B[0;36mRegistry.resolve\u001B[0;34m(cls, category, class_name)\u001B[0m\n\u001B[1;32m     52\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mKeyError\u001B[39;00m:\n\u001B[0;32m---> 53\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m RegistryException(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNo \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mstr\u001B[39m(category)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m with registered name \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mclass_name\u001B[38;5;241m.\u001B[39mlower()\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, no default value either\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
+      "\u001B[0;31mRegistryException\u001B[0m: No <class 'combo.modules.parser.DependencyRelationModel'> with registered name combo_dependency_parsing_from_vocab, no default value either",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
+      "Cell \u001B[0;32mIn[2], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m nlp \u001B[38;5;241m=\u001B[39m \u001B[43mCOMBO\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mpolish-pdb-ud29\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/predict.py:249\u001B[0m, in \u001B[0;36mCOMBO.from_pretrained\u001B[0;34m(cls, path, tokenizer, batch_size, cuda_device)\u001B[0m\n\u001B[1;32m    246\u001B[0m         logger\u001B[38;5;241m.\u001B[39merror(e)\n\u001B[1;32m    247\u001B[0m         \u001B[38;5;28;01mraise\u001B[39;00m e\n\u001B[0;32m--> 249\u001B[0m archive \u001B[38;5;241m=\u001B[39m \u001B[43mmodels\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_archive\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    250\u001B[0m model \u001B[38;5;241m=\u001B[39m archive\u001B[38;5;241m.\u001B[39mmodel\n\u001B[1;32m    251\u001B[0m dataset_reader_class \u001B[38;5;241m=\u001B[39m archive\u001B[38;5;241m.\u001B[39mconfig[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdataset_reader\u001B[39m\u001B[38;5;124m\"\u001B[39m]\u001B[38;5;241m.\u001B[39mget(\n\u001B[1;32m    252\u001B[0m     \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mconllu\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/models/archival.py:199\u001B[0m, in \u001B[0;36mload_archive\u001B[0;34m(archive_file, cuda_device, overrides, weights_file)\u001B[0m\n\u001B[1;32m    195\u001B[0m     \u001B[38;5;66;03m# Instantiate model and dataset readers. Use a duplicate of the config, as it will get consumed.\u001B[39;00m\n\u001B[1;32m    196\u001B[0m     dataset_reader, validation_dataset_reader \u001B[38;5;241m=\u001B[39m _load_dataset_readers(\n\u001B[1;32m    197\u001B[0m         config\u001B[38;5;241m.\u001B[39mduplicate(), serialization_dir\n\u001B[1;32m    198\u001B[0m     )\n\u001B[0;32m--> 199\u001B[0m     model \u001B[38;5;241m=\u001B[39m \u001B[43m_load_model\u001B[49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mduplicate\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweights_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    201\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m    202\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m tempdir \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/models/archival.py:234\u001B[0m, in \u001B[0;36m_load_model\u001B[0;34m(config, weights_path, serialization_dir, cuda_device)\u001B[0m\n\u001B[1;32m    233\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_load_model\u001B[39m(config, weights_path, serialization_dir, cuda_device):\n\u001B[0;32m--> 234\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mModel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    235\u001B[0m \u001B[43m        \u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    236\u001B[0m \u001B[43m        \u001B[49m\u001B[43mweights_file\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mweights_path\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    237\u001B[0m \u001B[43m        \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    238\u001B[0m \u001B[43m        \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcuda_device\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    239\u001B[0m \u001B[43m    \u001B[49m\u001B[43m)\u001B[49m\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/model.py:407\u001B[0m, in \u001B[0;36mModel.load\u001B[0;34m(cls, config, serialization_dir, weights_file, cuda_device)\u001B[0m\n\u001B[1;32m    401\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(model_class, \u001B[38;5;28mtype\u001B[39m):\n\u001B[1;32m    402\u001B[0m     \u001B[38;5;66;03m# If you're using from_archive to specify your model (e.g., for fine tuning), then you\u001B[39;00m\n\u001B[1;32m    403\u001B[0m     \u001B[38;5;66;03m# can't currently override the behavior of _load; we just use the default Model._load.\u001B[39;00m\n\u001B[1;32m    404\u001B[0m     \u001B[38;5;66;03m# If we really need to change this, we would need to implement a recursive\u001B[39;00m\n\u001B[1;32m    405\u001B[0m     \u001B[38;5;66;03m# get_model_class method, that recurses whenever it finds a from_archive model type.\u001B[39;00m\n\u001B[1;32m    406\u001B[0m     model_class \u001B[38;5;241m=\u001B[39m Model\n\u001B[0;32m--> 407\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mmodel_class\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_load\u001B[49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweights_file\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcuda_device\u001B[49m\u001B[43m)\u001B[49m\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/model.py:305\u001B[0m, in \u001B[0;36mModel._load\u001B[0;34m(cls, config, serialization_dir, weights_file, cuda_device)\u001B[0m\n\u001B[1;32m    303\u001B[0m remove_keys_from_params(model_params)\n\u001B[1;32m    304\u001B[0m model_type \u001B[38;5;241m=\u001B[39m Registry\u001B[38;5;241m.\u001B[39mresolve(Model, model_params\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124msemantic_multitask\u001B[39m\u001B[38;5;124m'\u001B[39m))\n\u001B[0;32m--> 305\u001B[0m model \u001B[38;5;241m=\u001B[39m \u001B[43mmodel_type\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    306\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mdict\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;28;43mdict\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mmodel_params\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvocabulary\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvocab\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mserialization_dir\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mserialization_dir\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mvocabulary\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mvocab\u001B[49m\u001B[43m}\u001B[49m\n\u001B[1;32m    307\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    309\u001B[0m \u001B[38;5;66;03m# Force model to cpu or gpu, as appropriate, to make sure that the embeddings are\u001B[39;00m\n\u001B[1;32m    310\u001B[0m \u001B[38;5;66;03m# in sync with the weights\u001B[39;00m\n\u001B[1;32m    311\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m cuda_device \u001B[38;5;241m>\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m:\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:128\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m    126\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m typing\u001B[38;5;241m.\u001B[39mget_origin(type_annotation) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m type_annotation\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[1;32m    127\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 128\u001B[0m         parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtype_annotation\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    129\u001B[0m \u001B[43m                                         \u001B[49m\u001B[43mv\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    130\u001B[0m     \u001B[38;5;28;01mexcept\u001B[39;00m ConfigurationError:\n\u001B[1;32m    131\u001B[0m         warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mAn object of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtype_annotation\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m is not in the registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:51\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m     49\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mnn\u001B[38;5;241m.\u001B[39mModule(\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues)\n\u001B[1;32m     50\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[0;32m---> 51\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m__base__\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     52\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mNo classes of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mtype\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m (values: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mvalues\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m) in Registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m     53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m values\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:41\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m     39\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m     40\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m \u001B[38;5;129;01min\u001B[39;00m values\u001B[38;5;241m.\u001B[39mkeys():\n\u001B[0;32m---> 41\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mresolve\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m     42\u001B[0m \u001B[43m            \u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m     43\u001B[0m \u001B[43m            \u001B[49m\u001B[43mvalues\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mtype\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\n\u001B[1;32m     44\u001B[0m \u001B[43m        \u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m}\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     45\u001B[0m     \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m     46\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m Registry\u001B[38;5;241m.\u001B[39mget_default(\u001B[38;5;28mtype\u001B[39m)\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:128\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m    126\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m typing\u001B[38;5;241m.\u001B[39mget_origin(type_annotation) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m type_annotation\u001B[38;5;241m.\u001B[39m__base__ \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mobject\u001B[39m:\n\u001B[1;32m    127\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 128\u001B[0m         parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m \u001B[43m_resolve\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtype_annotation\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    129\u001B[0m \u001B[43m                                         \u001B[49m\u001B[43mv\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdefault_name\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    130\u001B[0m     \u001B[38;5;28;01mexcept\u001B[39;00m ConfigurationError:\n\u001B[1;32m    131\u001B[0m         warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mAn object of type \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtype_annotation\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m is not in the registry!\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:46\u001B[0m, in \u001B[0;36m_resolve\u001B[0;34m(type, values, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m     41\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m Registry\u001B[38;5;241m.\u001B[39mresolve(\n\u001B[1;32m     42\u001B[0m             \u001B[38;5;28mtype\u001B[39m,\n\u001B[1;32m     43\u001B[0m             values[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[1;32m     44\u001B[0m         )\u001B[38;5;241m.\u001B[39mfrom_parameters({\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mvalues, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mpass_to_subclasses}, pass_to_subclasses)\n\u001B[1;32m     45\u001B[0m     \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 46\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mRegistry\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_default\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mtype\u001B[39;49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_parameters\u001B[49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mvalues\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m}\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpass_to_subclasses\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     47\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m RegistryException:\n\u001B[1;32m     48\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m \u001B[38;5;129;01mis\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mnn\u001B[38;5;241m.\u001B[39mModule:\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/config/from_parameters.py:136\u001B[0m, in \u001B[0;36mFromParameters.from_parameters\u001B[0;34m(cls, parameters, pass_to_subclasses, default_name)\u001B[0m\n\u001B[1;32m    133\u001B[0m     \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m    134\u001B[0m         parameters_to_pass[k] \u001B[38;5;241m=\u001B[39m v\n\u001B[0;32m--> 136\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mparameters_to_pass\u001B[49m\u001B[43m)\u001B[49m\n",
+      "File \u001B[0;32m~/PycharmProjects/combo-lightning/combo/modules/parser.py:26\u001B[0m, in \u001B[0;36mHeadPredictionModel.__init__\u001B[0;34m(self, head_projection_layer, dependency_projection_layer, cycle_loss_n)\u001B[0m\n\u001B[1;32m     22\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m,\n\u001B[1;32m     23\u001B[0m              head_projection_layer: base\u001B[38;5;241m.\u001B[39mLinear,\n\u001B[1;32m     24\u001B[0m              dependency_projection_layer: base\u001B[38;5;241m.\u001B[39mLinear,\n\u001B[1;32m     25\u001B[0m              cycle_loss_n: \u001B[38;5;28mint\u001B[39m \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m):\n\u001B[0;32m---> 26\u001B[0m     \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     27\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mhead_projection_layer \u001B[38;5;241m=\u001B[39m head_projection_layer\n\u001B[1;32m     28\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdependency_projection_layer \u001B[38;5;241m=\u001B[39m dependency_projection_layer\n",
+      "\u001B[0;31mTypeError\u001B[0m: __init__() missing 2 required positional arguments: 'model' and 'dataset_reader'"
+     ]
+    }
+   ],
+   "source": [
+    "nlp = COMBO.from_pretrained(\"polish-pdb-ud29\")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-09-22T07:28:45.877954Z",
+     "start_time": "2023-09-22T07:28:27.167654Z"
+    }
+   },
+   "id": "f1d81f8aba12630b"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "outputs": [],
+   "source": [
+    "nlp.predict('Cześć')"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "start_time": "2023-09-22T07:26:16.004782Z"
+    }
+   },
+   "id": "7302b7d49ac2fc38"
+  },
+  {
+   "cell_type": "markdown",
    "source": [],
    "metadata": {
     "collapsed": false
    },
-   "id": "f1d81f8aba12630b"
+   "id": "fa5abf02a486db75"
   }
  ],
  "metadata": {
diff --git a/combo/main.py b/combo/main.py
index 10f2648..207b9fd 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -1,15 +1,11 @@
 import logging
-import os
 import pathlib
-import tempfile
 from typing import Dict
 
-import torch
 from absl import app
 from absl import flags
 
-from combo import models
-from combo.models.base import Predictor
+from combo.nn.base import Predictor
 from combo.utils import checks
 
 logger = logging.getLogger(__name__)
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 8cc3df5..83092c4 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1,11 +1,3 @@
-from .base import FeedForwardPredictor
-from .graph_parser import GraphDependencyRelationModel
-from .parser import DependencyRelationModel
-from .embeddings import (GloVe6BEmbedder, GloVe840BEmbedder, GloVeTwitter27BEmbedder,
-                         GloVe42BEmbedder, FastTextEmbedder, CharNGramEmbedder)
-from .encoder import ComboEncoder
-from .lemma import LemmatizerModel
+from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder
 from .combo_model import ComboModel
-from .morpho import MorphologicalFeatures
-from .model import Model
-from .archival import *
\ No newline at end of file
+from .archival import *
diff --git a/combo/models/archival.py b/combo/models/archival.py
index 11e7314..ee43a3a 100644
--- a/combo/models/archival.py
+++ b/combo/models/archival.py
@@ -15,8 +15,9 @@ import glob
 from torch.nn import Module
 
 from combo.common.params import Params
+from combo.config import Registry
 from combo.data.dataset_readers import DatasetReader
-from combo.models.model import Model
+from combo.modules.model import Model
 from combo.utils import ConfigurationError
 from combo.utils.file_utils import cached_path
 
@@ -219,11 +220,11 @@ def _load_dataset_readers(config, serialization_dir):
         "validation_dataset_reader", dataset_reader_params.duplicate()
     )
 
-    dataset_reader = DatasetReader.from_params(
-        dataset_reader_params, serialization_dir=serialization_dir
+    dataset_reader = Registry.resolve(DatasetReader, dataset_reader_params.get('type', 'base')).from_parameters(
+        dict(**dataset_reader_params, serialization_dir=serialization_dir)
     )
-    validation_dataset_reader = DatasetReader.from_params(
-        validation_dataset_reader_params, serialization_dir=serialization_dir
+    validation_dataset_reader = Registry.resolve(DatasetReader, validation_dataset_reader_params.get('type', 'base')).from_parameters(
+        dict(**validation_dataset_reader_params, serialization_dir=serialization_dir)
     )
 
     return dataset_reader, validation_dataset_reader
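The two calls above replace AllenNLP-style `DatasetReader.from_params` with a registry lookup followed by `from_parameters`. A minimal sketch of that pattern, assuming a reader registered under the illustrative type name "conllu" and an illustrative serialization path:

    from combo.config import Registry
    from combo.data.dataset_readers import DatasetReader

    # "conllu" and the serialization path are invented for illustration.
    params = {"type": "conllu", "serialization_dir": "/tmp/serialization"}
    reader_cls = Registry.resolve(DatasetReader, params.get("type", "base"))
    reader = reader_cls.from_parameters(dict(**params))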
diff --git a/combo/models/combo_model.py b/combo/models/combo_model.py
index dc94077..cdda10c 100644
--- a/combo/models/combo_model.py
+++ b/combo/models/combo_model.py
@@ -5,33 +5,44 @@ import torch
 from overrides import overrides
 
 from combo import data
-from combo.models import base
-from combo.models.embeddings import TokenEmbedder
-from combo.models.model import Model
-from combo.modules.seq2seq_encoder import Seq2SeqEncoder
-from combo.nn import RegularizerApplicator
+from combo.config import FromParameters, Registry
+from combo.modules.morpho import MorphologicalFeatures
+from combo.modules.lemma import LemmatizerModel
+from combo.modules.parser import DependencyRelationModel
+from combo.modules import TextFieldEmbedder
+from combo.modules.model import Model
+from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
+from combo.nn import RegularizerApplicator, base
 from combo.nn.util import get_text_field_mask
 from combo.utils import metrics
-
-
-class ComboModel(Model):
+#
+# lemmatizer: Optional[LemmatizerModel] = None,
+# upos_tagger: Optional[MorphologicalFeatures] = None,
+# xpos_tagger: Optional[MorphologicalFeatures] = None,
+# semantic_relation: Optional[base.Predictor] = None,
+# morphological_feat: Optional[MorphologicalFeatures] = None,
+# dependency_relation: Optional[DependencyRelationModel] = None,
+# enhanced_dependency_relation: Optional[DependencyRelationModel] = None,
+
+@Registry.register(Model, "semantic_multitask")
+class ComboModel(Model, FromParameters):
     """Main COMBO model."""
 
     def __init__(self,
-                 vocab: data.Vocabulary,
+                 vocabulary: data.Vocabulary,
                  loss_weights: Dict[str, float],
-                 text_field_embedder: TokenEmbedder,
+                 text_field_embedder: TextFieldEmbedder,
                  seq_encoder: Seq2SeqEncoder,
                  use_sample_weight: bool = True,
-                 lemmatizer: Optional[base.Predictor] = None,
-                 upos_tagger: Optional[base.Predictor] = None,
-                 xpos_tagger: Optional[base.Predictor] = None,
-                 semantic_relation: Optional[base.Predictor] = None,
-                 morphological_feat: Optional[base.Predictor] = None,
-                 dependency_relation: Optional[base.Predictor] = None,
-                 enhanced_dependency_relation: Optional[base.Predictor] = None,
+                 lemmatizer: LemmatizerModel = None,
+                 upos_tagger: MorphologicalFeatures = None,
+                 xpos_tagger: MorphologicalFeatures = None,
+                 semantic_relation: base.Predictor = None,
+                 morphological_feat: MorphologicalFeatures = None,
+                 dependency_relation: DependencyRelationModel = None,
+                 enhanced_dependency_relation: DependencyRelationModel = None,
                  regularizer: RegularizerApplicator = None) -> None:
-        super().__init__(vocab, regularizer)
+        super().__init__(vocabulary, regularizer)
         self.text_field_embedder = text_field_embedder
         self.loss_weights = loss_weights
         self.use_sample_weight = use_sample_weight
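Registering `ComboModel` against the `Model` base class is what lets the archival code above recover the class from the "semantic_multitask" type string alone. A short sketch of the lookup; constructing the model still requires a full parameter dictionary, which is not shown here:

    from combo.config import Registry
    from combo.modules.model import Model

    # Resolving the type name returns the registered class itself.
    model_cls = Registry.resolve(Model, "semantic_multitask")
    assert model_cls.__name__ == "ComboModel"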
diff --git a/combo/models/combo_nn.py b/combo/models/combo_nn.py
deleted file mode 100644
index 822c1cd..0000000
--- a/combo/models/combo_nn.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import torch
-import torch.nn as nn
-from overrides import overrides
-
-
-class Activation(nn.Module):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class LinearActivation(Activation):
-    @overrides
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 1904cb8..12932ad 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -9,15 +9,19 @@ import torch.nn.utils.rnn as rnn
 from overrides import overrides
 from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
 
+from combo.config import FromParameters, Registry
 from combo.modules import input_variational_dropout
 from combo.modules.augmented_lstm import AugmentedLstm
 from combo.modules.input_variational_dropout import InputVariationalDropout
+from combo.modules.module import Module
+from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.utils import ConfigurationError
 
 TensorPair = Tuple[torch.Tensor, torch.Tensor]
 
 
-class StackedBidirectionalLstm(torch.nn.Module):
+@Registry.register_base_class('stacked_bilstm', default=True)
+class StackedBidirectionalLstm(torch.nn.Module, FromParameters):
     """
     A standard stacked Bidirectional LSTM where the LSTM layers
     are concatenated between each layer. The only difference between
@@ -53,26 +57,25 @@ class StackedBidirectionalLstm(torch.nn.Module):
     """
 
     def __init__(
-        self,
-        input_size: int,
-        hidden_size: int,
-        num_layers: int,
-        recurrent_dropout_probability: float = 0.0,
-        layer_dropout_probability: float = 0.0,
-        use_highway: bool = True,
+            self,
+            input_size: int,
+            hidden_size: int,
+            num_layers: int,
+            recurrent_dropout_probability: float = 0.0,
+            layer_dropout_probability: float = 0.0,
+            use_highway: bool = True,
     ) -> None:
         super().__init__()
 
         # Required to be wrapped with a `PytorchSeq2SeqWrapper`.
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.bidirectional = True
+        self.__input_size = input_size
+        self.__hidden_size = hidden_size
+        self.__num_layers = num_layers
+        self.__bidirectional = True
 
         layers = []
         lstm_input_size = input_size
         for layer_index in range(num_layers):
-
             forward_layer = AugmentedLstm(
                 lstm_input_size,
                 hidden_size,
@@ -97,8 +100,24 @@ class StackedBidirectionalLstm(torch.nn.Module):
         self.lstm_layers = layers
         self.layer_dropout = InputVariationalDropout(layer_dropout_probability)
 
+    @property
+    def input_size(self) -> int:
+        return self.__input_size
+
+    @property
+    def hidden_size(self) -> int:
+        return self.__hidden_size
+
+    @property
+    def num_layers(self) -> int:
+        return self.__num_layers
+
+    @property
+    def bidirectional(self) -> bool:
+        return self.__bidirectional
+
     def forward(
-        self, inputs: PackedSequence, initial_state: Optional[TensorPair] = None
+            self, inputs: PackedSequence, initial_state: Optional[TensorPair] = None
     ) -> Tuple[PackedSequence, TensorPair]:
         """
         # Parameters
@@ -156,8 +175,11 @@ class StackedBidirectionalLstm(torch.nn.Module):
         return output_sequence, final_state_tuple
 
 
+
+
 # TODO: merge into one
-class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm):
+@Registry.register_base_class('stacked_bilstm', default=True)
+class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm, FromParameters):
 
     def __init__(self, input_size: int, hidden_size: int, num_layers: int, recurrent_dropout_probability: float,
                  layer_dropout_probability: float, use_highway: bool = False):
@@ -168,7 +190,7 @@ class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm):
                          layer_dropout_probability=layer_dropout_probability,
                          use_highway=use_highway)
 
-    @overrides
+    # @overrides
     def forward(self,
                 inputs: rnn.PackedSequence,
                 initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
@@ -203,7 +225,8 @@ class ComboStackedBidirectionalLSTM(StackedBidirectionalLstm):
         return output_sequence, (state_fwd, state_bwd)
 
 
-class ComboEncoder:
+@Registry.register(Seq2SeqEncoder, 'combo_encoder', default=True)
+class ComboEncoder(Seq2SeqEncoder, FromParameters):
     """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
 
     This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
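The private attributes plus read-only properties keep `input_size`, `hidden_size`, `num_layers` and `bidirectional` available to the `Seq2SeqEncoder` wrapper, which reads them in `get_input_dim()`/`get_output_dim()`. A sketch with made-up sizes:

    from combo.models.encoder import ComboStackedBidirectionalLSTM

    # Sizes and dropout values are illustrative.
    lstm = ComboStackedBidirectionalLSTM(
        input_size=128, hidden_size=64, num_layers=2,
        recurrent_dropout_probability=0.0, layer_dropout_probability=0.0)
    assert lstm.input_size == 128
    assert lstm.bidirectional is True  # always True for this stacked BiLSTM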
diff --git a/combo/modules/__init__.py b/combo/modules/__init__.py
index e69de29..77b73d5 100644
--- a/combo/modules/__init__.py
+++ b/combo/modules/__init__.py
@@ -0,0 +1,2 @@
+from .text_field_embedders import TextFieldEmbedder
+from .token_embedders import *
diff --git a/combo/modules/augmented_lstm.py b/combo/modules/augmented_lstm.py
index 9e33d86..2131c95 100644
--- a/combo/modules/augmented_lstm.py
+++ b/combo/modules/augmented_lstm.py
@@ -7,13 +7,16 @@ from typing import Optional, Tuple
 import torch
 from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
 
+from combo.config import FromParameters, Registry
+from combo.modules.module import Module
 from combo.nn.util import get_dropout_mask
 
 from combo.nn.initializers import block_orthogonal
 from combo.utils import ConfigurationError
 
 
-class AugmentedLSTMCell(torch.nn.Module):
+@Registry.register(torch.nn.Module, 'augmented_lstm_cell')
+class AugmentedLSTMCell(Module, FromParameters):
     """
     `AugmentedLSTMCell` implements a AugmentedLSTM cell.
 
@@ -144,7 +147,8 @@ class AugmentedLSTMCell(torch.nn.Module):
         return timestep_output, memory
 
 
-class AugmentedLstm(torch.nn.Module):
+@Registry.register(torch.nn.Module, 'augmented_lstm')
+class AugmentedLstm(Module, FromParameters):
     """
     `AugmentedLstm` implements a one-layer single directional
     AugmentedLSTM layer. AugmentedLSTM is an LSTM which optionally
diff --git a/combo/models/dilated_cnn.py b/combo/modules/dilated_cnn.py
similarity index 77%
rename from combo/models/dilated_cnn.py
rename to combo/modules/dilated_cnn.py
index 79ca6d9..2e89942 100644
--- a/combo/models/dilated_cnn.py
+++ b/combo/modules/dilated_cnn.py
@@ -6,12 +6,13 @@ Author: Mateusz Klimaszewski
 from typing import List
 
 import torch
-import torch.nn as nn
 
-from combo.models.combo_nn import Activation
+from combo.config import FromParameters, Registry
+from combo.nn.activations import Activation
 
 
-class DilatedCnnEncoder(nn.Module):
+@Registry.register_base_class('dilated_cnn', default=True)
+class DilatedCnnEncoder(torch.nn.Module, FromParameters):
 
     def __init__(self,
                  input_dim: int,
@@ -26,14 +27,14 @@ class DilatedCnnEncoder(nn.Module):
         input_dims = [input_dim] + filters[:-1]
         output_dims = filters
         for idx in range(len(activations)):
-            conv1d_layers.append(nn.Conv1d(
+            conv1d_layers.append(torch.nn.Conv1d(
                 in_channels=input_dims[idx],
                 out_channels=output_dims[idx],
                 kernel_size=(kernel_size[idx],),
                 stride=(stride[idx],),
                 padding=padding[idx],
                 dilation=(dilation[idx],)))
-        self.conv1d_layers = nn.ModuleList(conv1d_layers)
+        self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)
         self.activations = activations
         assert len(self.activations) == len(self.conv1d_layers)
 
diff --git a/combo/modules/encoder.py b/combo/modules/encoder.py
index bf25856..4d053a8 100644
--- a/combo/modules/encoder.py
+++ b/combo/modules/encoder.py
@@ -7,6 +7,7 @@ from typing import Tuple, Union, Optional, Callable, Any
 import torch
 from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
 
+from combo.modules.module import Module
 from combo.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length
 
 # We have two types here for the state, because storing the state in something
diff --git a/combo/models/graph_parser.py b/combo/modules/graph_parser.py
similarity index 89%
rename from combo/models/graph_parser.py
rename to combo/modules/graph_parser.py
index 6dffef5..49aefee 100644
--- a/combo/models/graph_parser.py
+++ b/combo/modules/graph_parser.py
@@ -6,8 +6,9 @@ Author: Mateusz Klimaszewski
 from typing import List, Optional, Union, Tuple, Dict
 
 from combo import data
-from combo.models import base
-from combo.models.base import Predictor
+from combo.config import Registry
+from combo.nn import base
+from combo.nn.base import Predictor
 
 import torch
 import torch.nn.functional as F
@@ -102,15 +103,24 @@ class GraphHeadPredictionModel(Predictor):
         return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
 
 
+@Registry.register(Predictor, 'combo_graph_dependency_parsing_from_vocab')
 class GraphDependencyRelationModel(Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
+                 vocab: data.Vocabulary,
+                 vocab_namespace: str,
                  head_predictor: GraphHeadPredictionModel,
                  head_projection_layer: base.Linear,
-                 dependency_projection_layer: base.Linear,
-                 relation_prediction_layer: base.Linear):
+                 dependency_projection_layer: base.Linear):
+        """Creates parser combining model configuration and vocabulary data."""
         super().__init__()
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -167,24 +177,3 @@ class GraphDependencyRelationModel(Predictor):
         pred = pred[correct_heads_mask]
         loss = F.cross_entropy(pred, true.long())
         return loss.sum() / pred.size(0)
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   vocab_namespace: str,
-                   head_predictor: GraphHeadPredictionModel,
-                   head_projection_layer: base.Linear,
-                   dependency_projection_layer: base.Linear
-                   ):
-        """Creates parser combining model configuration and vocabulary data."""
-        assert vocab_namespace in vocab.get_namespaces()
-        relation_prediction_layer = base.Linear(
-            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
-            out_features=vocab.get_vocab_size(vocab_namespace)
-        )
-        return cls(
-            head_predictor=head_predictor,
-            head_projection_layer=head_projection_layer,
-            dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer
-        )
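With `from_vocab` folded into `__init__`, the relation prediction layer is now sized from the vocabulary inside the constructor. A hedged usage sketch; the namespace name, projection sizes, and the `vocab` and `head_predictor` objects are placeholders, not values from a real configuration:

    from combo.modules.graph_parser import GraphDependencyRelationModel
    from combo.nn import base

    # vocab: an existing combo.data.Vocabulary; head_predictor: a configured
    # GraphHeadPredictionModel; "deprel_labels" is an invented namespace name.
    parser = GraphDependencyRelationModel(
        vocab=vocab,
        vocab_namespace="deprel_labels",
        head_predictor=head_predictor,
        head_projection_layer=base.Linear(in_features=512, out_features=128),
        dependency_projection_layer=base.Linear(in_features=512, out_features=128),
    )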
diff --git a/combo/modules/input_variational_dropout.py b/combo/modules/input_variational_dropout.py
index 5744642..4433c1e 100644
--- a/combo/modules/input_variational_dropout.py
+++ b/combo/modules/input_variational_dropout.py
@@ -5,8 +5,11 @@ https://github.com/allenai/allennlp/blob/main/allennlp/modules/input_variational
 
 import torch
 
+from combo.config import FromParameters, Registry
 
-class InputVariationalDropout(torch.nn.Dropout):
+
+@Registry.register(torch.nn.Module, 'input_variational_dropout')
+class InputVariationalDropout(torch.nn.Dropout, FromParameters):
     """
     Apply the dropout technique in Gal and Ghahramani, [Dropout as a Bayesian Approximation:
     Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) to a
diff --git a/combo/models/lemma.py b/combo/modules/lemma.py
similarity index 72%
rename from combo/models/lemma.py
rename to combo/modules/lemma.py
index d724a1e..05e0cbf 100644
--- a/combo/models/lemma.py
+++ b/combo/modules/lemma.py
@@ -4,23 +4,54 @@ import torch
 import torch.nn as nn
 
 from combo import data
-from combo.models import dilated_cnn, base, utils
-from combo.models.base import Predictor, TimeDistributed
-from combo.models.combo_nn import Activation
+from combo.config import Registry
+from combo.models import utils
+from combo.modules import dilated_cnn
+from combo.nn import base
+from combo.nn.base import Predictor
+from combo.modules.time_distributed import TimeDistributed
+from combo.nn.activations import Activation
 from combo.utils import ConfigurationError
 
 
+@Registry.register(Predictor, 'combo_lemma_predictor_from_vocab')
 class LemmatizerModel(Predictor):
     """Lemmatizer model."""
 
     def __init__(self,
-                 num_embeddings: int,
+                 vocab: data.Vocabulary,
+                 char_vocab_namespace: str,
+                 lemma_vocab_namespace: str,
                  embedding_dim: int,
-                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
-                 input_projection_layer: base.Linear):
+                 input_projection_layer: base.Linear,
+                 filters: List[int],
+                 kernel_size: List[int],
+                 stride: List[int],
+                 padding: List[int],
+                 dilation: List[int],
+                 activations: List[Activation],
+                 ):
+        assert char_vocab_namespace in vocab.get_namespaces()
+        assert lemma_vocab_namespace in vocab.get_namespaces()
         super().__init__()
+
+        if len(filters) + 1 != len(kernel_size):
+            raise ConfigurationError(
+                f"len(filters) ({len(filters):d}) + 1 != kernel_size ({len(kernel_size):d})"
+            )
+        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]
+
+        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
+            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
+            filters=filters,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            activations=activations,
+        )
         self.char_embed = nn.Embedding(
-            num_embeddings=num_embeddings,
+            num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
             embedding_dim=embedding_dim,
         )
         self.dilated_cnn_encoder = TimeDistributed(dilated_cnn_encoder)
@@ -68,40 +99,3 @@ class LemmatizerModel(Predictor):
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         valid_positions = mask.sum()
         return loss.sum() / valid_positions
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   char_vocab_namespace: str,
-                   lemma_vocab_namespace: str,
-                   embedding_dim: int,
-                   input_projection_layer: base.Linear,
-                   filters: List[int],
-                   kernel_size: List[int],
-                   stride: List[int],
-                   padding: List[int],
-                   dilation: List[int],
-                   activations: List[Activation],
-                   ):
-        assert char_vocab_namespace in vocab.get_namespaces()
-        assert lemma_vocab_namespace in vocab.get_namespaces()
-
-        if len(filters) + 1 != len(kernel_size):
-            raise ConfigurationError(
-                f"len(filters) ({len(filters):d}) + 1 != kernel_size ({len(kernel_size):d})"
-            )
-        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]
-
-        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
-            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
-            filters=filters,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            activations=activations,
-        )
-        return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
-                   embedding_dim=embedding_dim,
-                   dilated_cnn_encoder=dilated_cnn_encoder,
-                   input_projection_layer=input_projection_layer)
diff --git a/combo/models/model.py b/combo/modules/model.py
similarity index 93%
rename from combo/models/model.py
rename to combo/modules/model.py
index 2dbc00a..1bdf52c 100644
--- a/combo/models/model.py
+++ b/combo/modules/model.py
@@ -13,8 +13,10 @@ import numpy
 import torch
 
 from combo.common.params import remove_keys_from_params, Params
+from combo.config import FromParameters, Registry
 from combo.data import Vocabulary, Instance
 from combo.data.batch import Batch
+from combo.modules.module import Module
 from combo.nn import util, RegularizerApplicator
 from combo.utils import ConfigurationError
 
@@ -25,7 +27,7 @@ logger = logging.getLogger(__name__)
 _DEFAULT_WEIGHTS = "best.th"
 
 
-class Model(torch.nn.Module):
+class Model(Module, FromParameters):
     """
     This abstract class represents a model to be trained. Rather than relying completely
     on the Pytorch Module, we modify the output spec of `forward` to be a dictionary.
@@ -53,7 +55,7 @@ class Model(torch.nn.Module):
 
     # Parameters
 
-    vocab: `Vocabulary`
+    vocabulary: `Vocabulary`
         There are two typical use-cases for the `Vocabulary` in a `Model`: getting vocabulary sizes
         when constructing embedding matrices or output classifiers (as the vocabulary holds the
         number of classes in your output, also), and translating model output into human-readable
@@ -73,12 +75,12 @@ class Model(torch.nn.Module):
 
     def __init__(
         self,
-        vocab: Vocabulary,
+        vocabulary: Vocabulary,
         regularizer: RegularizerApplicator = None,
         serialization_dir: Optional[str] = None,
     ) -> None:
-        super().__init__()
-        self.vocab = vocab
+        super(Model, self).__init__()
+        self.vocab = vocabulary
         self._regularizer = regularizer
         self.serialization_dir = serialization_dir
 
@@ -286,8 +288,8 @@ class Model(torch.nn.Module):
         vocab_dir = os.path.join(serialization_dir, "vocabulary")
         # If the config specifies a vocabulary subclass, we need to use it.
         vocab_params = config.get("vocabulary", Params({}))
-        vocab_choice = vocab_params.pop_choice("type", Vocabulary.list_available(), True)
-        vocab_class, _ = Vocabulary.resolve_class_name(vocab_choice)
+        vocab_choice = vocab_params.pop_choice("type", list(Registry.classes()[Vocabulary].keys()), True)
+        vocab_class = Registry.resolve(Vocabulary, vocab_choice)
         vocab = vocab_class.from_files(
             vocab_dir, vocab_params.get("padding_token"), vocab_params.get("oov_token")
         )
@@ -299,8 +301,9 @@ class Model(torch.nn.Module):
         # stored in our model. We don't need any pretrained weight file or initializers anymore,
         # and we don't want the code to look for it, so we remove it from the parameters here.
         remove_keys_from_params(model_params)
-        model = Model.from_params(
-            vocab=vocab, params=model_params, serialization_dir=serialization_dir
+        model_type = Registry.resolve(Model, model_params.get('type', 'semantic_multitask'))
+        model = model_type.from_parameters(
+            dict(**dict(model_params), vocabulary=vocab, serialization_dir=serialization_dir), {'vocabulary': vocab}
         )
 
         # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
@@ -310,12 +313,12 @@ class Model(torch.nn.Module):
         else:
             model.cpu()
 
-        # If vocab+embedding extension was done, the model initialized from from_params
+        # If vocabulary+embedding extension was done, the model initialized from from_params
         # and one defined by state dict in weights_file might not have same embedding shapes.
-        # Eg. when model embedder module was transferred along with vocab extension, the
+        # Eg. when model embedder module was transferred along with vocabulary extension, the
         # initialized embedding weight shape would be smaller than one in the state_dict.
         # So calling model embedding extension is required before load_state_dict.
-        # If vocab and model embeddings are in sync, following would be just a no-op.
+        # If vocabulary and model embeddings are in sync, following would be just a no-op.
         model.extend_embedder_vocab()
 
         # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
@@ -342,12 +345,12 @@ class Model(torch.nn.Module):
 
         filter_out_authorized_missing_keys(model)
 
-        if unexpected_keys or missing_keys:
-            raise RuntimeError(
-                f"Error loading state dict for {model.__class__.__name__}\n\t"
-                f"Missing keys: {missing_keys}\n\t"
-                f"Unexpected keys: {unexpected_keys}"
-            )
+        # if unexpected_keys or missing_keys:
+        #     raise RuntimeError(
+        #         f"Error loading state dict for {model.__class__.__name__}\n\t"
+        #         f"Missing keys: {missing_keys}\n\t"
+        #         f"Unexpected keys: {unexpected_keys}"
+        #     )
 
         return model
 
@@ -394,7 +397,7 @@ class Model(torch.nn.Module):
         # Load using an overridable _load method.
         # This allows subclasses of Model to override _load.
 
-        model_class: Type[Model] = cls.by_name(model_type)  # type: ignore
+        model_class: Type[Model] = Registry.resolve(Model, model_type)  # type: ignore
         if not isinstance(model_class, type):
             # If you're using from_archive to specify your model (e.g., for fine tuning), then you
             # can't currently override the behavior of _load; we just use the default Model._load.
@@ -406,7 +409,7 @@ class Model(torch.nn.Module):
     def extend_embedder_vocab(self, embedding_sources_mapping: Dict[str, str] = None) -> None:
         """
         Iterates through all embedding modules in the model and assures it can embed
-        with the extended vocab. This is required in fine-tuning or transfer learning
+        with the extended vocabulary. This is required in fine-tuning or transfer learning
         scenarios where model was trained with original vocabulary but during
         fine-tuning/transfer-learning, it will have it work with extended vocabulary
         (original + new-data vocabulary).
@@ -440,7 +443,7 @@ class Model(torch.nn.Module):
         convenience, and so that we can register it for easy use for fine tuning an existing model
         from a config file.
 
-        If `vocab` is given, we will extend the loaded model's vocabulary using the passed vocab
+        If `vocabulary` is given, we will extend the loaded model's vocabulary using the passed vocabulary
         object (including calling `extend_embedder_vocab`, which extends embedding layers).
         """
         from combo.models.archival import load_archive # here to avoid circular imports
diff --git a/combo/modules/module.py b/combo/modules/module.py
new file mode 100644
index 0000000..6681257
--- /dev/null
+++ b/combo/modules/module.py
@@ -0,0 +1,49 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/nn/module.py#L14
+"""
+
+from typing import List, Optional, Tuple
+
+
+import torch
+
+from combo.nn.util import (
+    _check_incompatible_keys,
+    _IncompatibleKeys,
+    StateDictType,
+)
+
+
+class Module(torch.nn.Module):
+    """
+    This is just `torch.nn.Module` with some extra functionality.
+    """
+
+    def _post_load_state_dict(
+        self, missing_keys: List[str], unexpected_keys: List[str]
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Subclasses can override this and potentially modify `missing_keys` or `unexpected_keys`.
+        """
+        return missing_keys, unexpected_keys
+
+    def load_state_dict(self, state_dict: StateDictType, strict: bool = True) -> _IncompatibleKeys:
+        """
+        Same as [`torch.nn.Module.load_state_dict()`]
+        (https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict)
+        except we also run the [`_post_load_state_dict`](#_post_load_state_dict) method before returning,
+        which can be implemented by subclasses to customize the behavior.
+        """
+        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)  # type: ignore[arg-type]
+        missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys)
+        _check_incompatible_keys(self, missing_keys, unexpected_keys, strict)
+        return _IncompatibleKeys(missing_keys, unexpected_keys)
+
+    # def load_state_dict_distributed(
+    #     self, state_dict: Optional[StateDictType], strict: bool = True
+    # ) -> _IncompatibleKeys:
+    #     missing_keys, unexpected_keys = load_state_dict_distributed(self, state_dict, strict=strict)
+    #     missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys)
+    #     _check_incompatible_keys(self, missing_keys, unexpected_keys, strict)
+    #     return _IncompatibleKeys(missing_keys, unexpected_keys)
\ No newline at end of file
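The hook added above lets subclasses adjust the key lists before strictness is enforced. A minimal sketch, assuming an invented legacy key name ("old_proj.weight") that newer checkpoints no longer carry:

    from typing import List, Tuple
    from combo.modules.module import Module

    class RenamingModule(Module):
        def _post_load_state_dict(
            self, missing_keys: List[str], unexpected_keys: List[str]
        ) -> Tuple[List[str], List[str]]:
            # Drop a key an older checkpoint used to carry so strict loading still passes.
            unexpected_keys = [k for k in unexpected_keys if k != "old_proj.weight"]
            return missing_keys, unexpected_keys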
diff --git a/combo/models/morpho.py b/combo/modules/morpho.py
similarity index 74%
rename from combo/models/morpho.py
rename to combo/modules/morpho.py
index 5fb9545..0ab7897 100644
--- a/combo/models/morpho.py
+++ b/combo/modules/morpho.py
@@ -6,18 +6,44 @@ from typing import Dict, List, Optional, Union
 import torch
 
 from combo import data
+from combo.config import Registry
 from combo.data import dataset
-from combo.models import base, utils
-from combo.models.combo_nn import Activation
+from combo.models import utils
+from combo.nn import base
+from combo.nn.activations import Activation
+from combo.predictors.predictor import Predictor
 from combo.utils import ConfigurationError
 
 
-class MorphologicalFeatures(base.Predictor):
+@Registry.register(Predictor, 'combo_morpho_from_vocab')
+class MorphologicalFeatures(Predictor):
     """Morphological features predicting model."""
 
-    def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]):
+    def __init__(self,
+                 vocab: data.Vocabulary,
+                 vocab_namespace: str,
+                 input_dim: int,
+                 num_layers: int,
+                 hidden_dims: List[int],
+                 activations: List[Activation], # change to Union[Activation, List[Activation]]
+                 dropout: Union[float, List[float]] = 0.0,
+                 ):
         super().__init__()
-        self.feedforward_network = feedforward_network
+        if len(hidden_dims) + 1 != num_layers:
+            raise ConfigurationError(
+                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
+            )
+
+        assert vocab_namespace in vocab.get_namespaces()
+        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
+
+        slices = dataset.get_slices_if_not_provided(vocab)
+        self.feedforward_network = base.FeedForward(
+            input_dim=input_dim,
+            num_layers=num_layers,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            dropout=dropout)
         self.slices = slices
 
     def forward(self,
@@ -71,33 +97,3 @@ class MorphologicalFeatures(base.Predictor):
                                       mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
-
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   vocab_namespace: str,
-                   input_dim: int,
-                   num_layers: int,
-                   hidden_dims: List[int],
-                   activations: Union[Activation, List[Activation]],
-                   dropout: Union[float, List[float]] = 0.0,
-                   ):
-        if len(hidden_dims) + 1 != num_layers:
-            raise ConfigurationError(
-                f"len(hidden_dims) ({len(hidden_dims):d}) + 1 != num_layers ({num_layers:d})"
-            )
-
-        assert vocab_namespace in vocab.get_namespaces()
-        hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)]
-
-        slices = dataset.get_slices_if_not_provided(vocab)
-
-        return cls(
-            feedforward_network=base.FeedForward(
-                input_dim=input_dim,
-                num_layers=num_layers,
-                hidden_dims=hidden_dims,
-                activations=activations,
-                dropout=dropout),
-            slices=slices
-        )
diff --git a/combo/models/parser.py b/combo/modules/parser.py
similarity index 89%
rename from combo/models/parser.py
rename to combo/modules/parser.py
index 42d2efb..2d8b470 100644
--- a/combo/models/parser.py
+++ b/combo/modules/parser.py
@@ -9,11 +9,14 @@ import torch
 import torch.nn.functional as F
 
 from combo import data
-from combo.models import base, utils
-from combo.nn import chu_liu_edmonds
+from combo.config import Registry
+from combo.models import utils
+from combo.nn import chu_liu_edmonds, base
+from combo.predictors.predictor import Predictor
 
 
-class HeadPredictionModel(base.Predictor):
+@Registry.register_base_class('head_prediction', default=True)
+class HeadPredictionModel(Predictor):
     """Head prediction model."""
 
     def __init__(self,
@@ -113,18 +116,26 @@ class HeadPredictionModel(base.Predictor):
         return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
 
 
-
-class DependencyRelationModel(base.Predictor):
+@Registry.register(Predictor, 'combo_dependency_parsing_from_vocab')
+class DependencyRelationModel(Predictor):
     """Dependency relation parsing model."""
 
     def __init__(self,
-                 root_idx: int,
+                 vocab: data.Vocabulary,
+                 vocab_namespace: str,
                  head_predictor: HeadPredictionModel,
                  head_projection_layer: base.Linear,
-                 dependency_projection_layer: base.Linear,
-                 relation_prediction_layer: base.Linear):
+                 dependency_projection_layer: base.Linear
+                 ):
+        """Creates parser combining model configuration and vocabulary data."""
         super().__init__()
-        self.root_idx = root_idx
+        assert vocab_namespace in vocab.get_namespaces()
+        relation_prediction_layer = base.Linear(
+            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
+            out_features=vocab.get_vocab_size(vocab_namespace)
+        )
+
+        self.root_idx = vocab.get_token_index("root", vocab_namespace)
         self.head_predictor = head_predictor
         self.head_projection_layer = head_projection_layer
         self.dependency_projection_layer = dependency_projection_layer
@@ -200,24 +211,6 @@ class DependencyRelationModel(base.Predictor):
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
 
-    @classmethod
-    def from_vocab(cls,
-                   vocab: data.Vocabulary,
-                   vocab_namespace: str,
-                   head_predictor: HeadPredictionModel,
-                   head_projection_layer: base.Linear,
-                   dependency_projection_layer: base.Linear
-                   ):
-        """Creates parser combining model configuration and vocabulary data."""
-        assert vocab_namespace in vocab.get_namespaces()
-        relation_prediction_layer = base.Linear(
-            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
-            out_features=vocab.get_vocab_size(vocab_namespace)
-        )
-        return cls(
-            head_predictor=head_predictor,
-            head_projection_layer=head_projection_layer,
-            dependency_projection_layer=dependency_projection_layer,
-            relation_prediction_layer=relation_prediction_layer,
-            root_idx=vocab.get_token_index("root", vocab_namespace)
-        )
+
+
+# combo_lemma_predictor_from_vocab
\ No newline at end of file
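`DependencyRelationModel` now derives both the relation-layer size and `root_idx` from the vocabulary inside `__init__`, and is reachable through the registry. A sketch of the lookup; the namespace name in the commented call is illustrative:

    from combo.config import Registry
    from combo.predictors.predictor import Predictor

    parser_cls = Registry.resolve(Predictor, "combo_dependency_parsing_from_vocab")
    # parser_cls(vocab=..., vocab_namespace="deprel_labels", head_predictor=...,
    #            head_projection_layer=..., dependency_projection_layer=...)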
diff --git a/combo/modules/scalar_mix.py b/combo/modules/scalar_mix.py
new file mode 100644
index 0000000..b4e894e
--- /dev/null
+++ b/combo/modules/scalar_mix.py
@@ -0,0 +1,101 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/scalar_mix.py
+"""
+
+from typing import List
+
+import torch
+from torch.nn import ParameterList, Parameter
+
+from combo.nn import util
+from combo.utils import ConfigurationError
+
+
+class ScalarMix(torch.nn.Module):
+    """
+    Computes a parameterised scalar mixture of N tensors, `mixture = gamma * sum(s_k * tensor_k)`
+    where `s = softmax(w)`, with `w` and `gamma` scalar parameters.
+
+    In addition, if `do_layer_norm=True` then apply layer normalization to each tensor
+    before weighting.
+    """
+
+    def __init__(
+        self,
+        mixture_size: int,
+        do_layer_norm: bool = False,
+        initial_scalar_parameters: List[float] = None,
+        trainable: bool = True,
+    ) -> None:
+        super().__init__()
+        self.mixture_size = mixture_size
+        self.do_layer_norm = do_layer_norm
+
+        if initial_scalar_parameters is None:
+            initial_scalar_parameters = [0.0] * mixture_size
+        elif len(initial_scalar_parameters) != mixture_size:
+            raise ConfigurationError(
+                "Length of initial_scalar_parameters {} differs "
+                "from mixture_size {}".format(initial_scalar_parameters, mixture_size)
+            )
+
+        self.scalar_parameters = ParameterList(
+            [
+                Parameter(
+                    torch.FloatTensor([initial_scalar_parameters[i]]), requires_grad=trainable
+                )
+                for i in range(mixture_size)
+            ]
+        )
+        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)
+
+    def forward(self, tensors: List[torch.Tensor], mask: torch.BoolTensor = None) -> torch.Tensor:
+        """
+        Compute a weighted average of the `tensors`.  The input tensors can be any shape
+        with at least two dimensions, but must all be the same shape.
+
+        When `do_layer_norm=True`, the `mask` is required input.  If the `tensors` are
+        dimensioned  `(dim_0, ..., dim_{n-1}, dim_n)`, then the `mask` is dimensioned
+        `(dim_0, ..., dim_{n-1})`, as in the typical case with `tensors` of shape
+        `(batch_size, timesteps, dim)` and `mask` of shape `(batch_size, timesteps)`.
+
+        When `do_layer_norm=False` the `mask` is ignored.
+        """
+        if len(tensors) != self.mixture_size:
+            raise ConfigurationError(
+                "{} tensors were passed, but the module was initialized to "
+                "mix {} tensors.".format(len(tensors), self.mixture_size)
+            )
+
+        def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
+            tensor_masked = tensor * broadcast_mask
+            mean = torch.sum(tensor_masked) / num_elements_not_masked
+            variance = (
+                torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) / num_elements_not_masked
+            )
+            return (tensor - mean) / torch.sqrt(variance + util.tiny_value_of_dtype(variance.dtype))
+
+        normed_weights = torch.nn.functional.softmax(
+            torch.cat([parameter for parameter in self.scalar_parameters]), dim=0
+        )
+        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
+
+        if not self.do_layer_norm:
+            pieces = []
+            for weight, tensor in zip(normed_weights, tensors):
+                pieces.append(weight * tensor)
+            return self.gamma * sum(pieces)
+
+        else:
+            assert mask is not None
+            broadcast_mask = mask.unsqueeze(-1)
+            input_dim = tensors[0].size(-1)
+            num_elements_not_masked = torch.sum(mask) * input_dim
+
+            pieces = []
+            for weight, tensor in zip(normed_weights, tensors):
+                pieces.append(
+                    weight * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked)
+                )
+            return self.gamma * sum(pieces)
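A small, self-contained usage sketch of the mixture above; the batch, sequence and feature sizes are illustrative:

    import torch
    from combo.modules.scalar_mix import ScalarMix

    mix = ScalarMix(mixture_size=3)
    layers = [torch.randn(2, 5, 8) for _ in range(3)]   # three equally shaped layer outputs
    mask = torch.ones(2, 5, dtype=torch.bool)           # ignored while do_layer_norm=False
    mixed = mix(layers, mask)                           # shape: (2, 5, 8)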
diff --git a/combo/modules/seq2seq_encoder.py b/combo/modules/seq2seq_encoder.py
deleted file mode 100644
index 71413f3..0000000
--- a/combo/modules/seq2seq_encoder.py
+++ /dev/null
@@ -1,33 +0,0 @@
-class Seq2SeqEncoder:
-    """
-    A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a
-    modified sequence of vectors.  Input shape : `(batch_size, sequence_length, input_dim)`; output
-    shape : `(batch_size, sequence_length, output_dim)`.
-
-    We add two methods to the basic `Module` API: `get_input_dim()` and `get_output_dim()`.
-    You might need this if you want to construct a `Linear` layer using the output of this encoder,
-    or to raise sensible errors for mis-matching input dimensions.
-    """
-
-    def get_input_dim(self) -> int:
-        """
-        Returns the dimension of the vector input for each element in the sequence input
-        to a `Seq2SeqEncoder`. This is `not` the shape of the input tensor, but the
-        last element of that shape.
-        """
-        raise NotImplementedError
-
-    def get_output_dim(self) -> int:
-        """
-        Returns the dimension of each vector in the sequence output by this `Seq2SeqEncoder`.
-        This is `not` the shape of the returned tensor, but the last element of that shape.
-        """
-        raise NotImplementedError
-
-    def is_bidirectional(self) -> bool:
-        """
-        Returns `True` if this encoder is bidirectional.  If so, we assume the forward direction
-        of the encoder is the first half of the final dimension, and the backward direction is the
-        second half.
-        """
-        raise NotImplementedError
diff --git a/combo/modules/seq2seq_encoders/__init__.py b/combo/modules/seq2seq_encoders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/modules/seq2seq_encoders/seq2seq_encoder.py b/combo/modules/seq2seq_encoders/seq2seq_encoder.py
new file mode 100644
index 0000000..78cc648
--- /dev/null
+++ b/combo/modules/seq2seq_encoders/seq2seq_encoder.py
@@ -0,0 +1,103 @@
+import torch
+from torch.nn.utils.rnn import pad_packed_sequence
+
+from combo.modules.encoder import _EncoderBase
+from combo.config.from_parameters import FromParameters
+from combo.utils import ConfigurationError
+
+
+class Seq2SeqEncoder(_EncoderBase, FromParameters):
+    """
+    A `Seq2SeqEncoder` is a `Module` that takes as input a sequence of vectors and returns a
+    modified sequence of vectors.  Input shape : `(batch_size, sequence_length, input_dim)`; output
+    shape : `(batch_size, sequence_length, output_dim)`.
+
+    We add two methods to the basic `Module` API: `get_input_dim()` and `get_output_dim()`.
+    You might need this if you want to construct a `Linear` layer using the output of this encoder,
+    or to raise sensible errors for mis-matching input dimensions.
+    """
+
+    def __init__(self, module: torch.nn.Module, stateful: bool = False) -> None:
+        super().__init__(stateful)
+        self._module = module
+        try:
+            if not self._module.batch_first:
+                raise ConfigurationError("Our encoder semantics assumes batch is always first!")
+        except AttributeError:
+            pass
+
+        try:
+            self._is_bidirectional = self._module.bidirectional
+        except AttributeError:
+            self._is_bidirectional = False
+        if self._is_bidirectional:
+            self._num_directions = 2
+        else:
+            self._num_directions = 1
+
+    def get_input_dim(self) -> int:
+        return self._module.input_size
+
+    def get_output_dim(self) -> int:
+        return self._module.hidden_size * self._num_directions
+
+    def is_bidirectional(self) -> bool:
+        return self._is_bidirectional
+
+    def forward(
+        self, inputs: torch.Tensor, mask: torch.BoolTensor, hidden_state: torch.Tensor = None
+    ) -> torch.Tensor:
+
+        if self.stateful and mask is None:
+            raise ValueError("Always pass a mask with stateful RNNs.")
+        if self.stateful and hidden_state is not None:
+            raise ValueError("Stateful RNNs provide their own initial hidden_state.")
+
+        if mask is None:
+            return self._module(inputs, hidden_state)[0]
+
+        batch_size, total_sequence_length = mask.size()
+
+        packed_sequence_output, final_states, restoration_indices = self.sort_and_run_forward(
+            self._module, inputs, mask, hidden_state
+        )
+
+        unpacked_sequence_tensor, _ = pad_packed_sequence(packed_sequence_output, batch_first=True)
+
+        num_valid = unpacked_sequence_tensor.size(0)
+        # Some RNNs (GRUs) only return one state as a Tensor.  Others (LSTMs) return two.
+        # If one state, use a single element list to handle in a consistent manner below.
+        if not isinstance(final_states, (list, tuple)) and self.stateful:
+            final_states = [final_states]
+
+        # Add back invalid rows.
+        if num_valid < batch_size:
+            _, length, output_dim = unpacked_sequence_tensor.size()
+            zeros = unpacked_sequence_tensor.new_zeros(batch_size - num_valid, length, output_dim)
+            unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 0)
+
+            # The states also need to have invalid rows added back.
+            if self.stateful:
+                new_states = []
+                for state in final_states:
+                    num_layers, _, state_dim = state.size()
+                    zeros = state.new_zeros(num_layers, batch_size - num_valid, state_dim)
+                    new_states.append(torch.cat([state, zeros], 1))
+                final_states = new_states
+
+        # It's possible to need to pass sequences which are padded to longer than the
+        # max length of the sequence to a Seq2SeqEncoder. However, packing and unpacking
+        # the sequences mean that the returned tensor won't include these dimensions, because
+        # the RNN did not need to process them. We add them back on in the form of zeros here.
+        sequence_length_difference = total_sequence_length - unpacked_sequence_tensor.size(1)
+        if sequence_length_difference > 0:
+            zeros = unpacked_sequence_tensor.new_zeros(
+                batch_size, sequence_length_difference, unpacked_sequence_tensor.size(-1)
+            )
+            unpacked_sequence_tensor = torch.cat([unpacked_sequence_tensor, zeros], 1)
+
+        if self.stateful:
+            self._update_states(final_states, restoration_indices)
+
+        # Restore the original indices and return the sequence.
+        return unpacked_sequence_tensor.index_select(0, restoration_indices)
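A hedged sketch of the wrapper in use: wrap a batch-first `torch.nn.LSTM` and run it over a padded batch with a boolean mask. It relies on the `_EncoderBase` machinery imported above; sizes are illustrative:

    import torch
    from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder

    lstm = torch.nn.LSTM(input_size=16, hidden_size=32, num_layers=1,
                         batch_first=True, bidirectional=True)
    encoder = Seq2SeqEncoder(lstm)
    inputs = torch.randn(4, 10, 16)
    mask = torch.ones(4, 10, dtype=torch.bool)
    output = encoder(inputs, mask)          # (4, 10, 64) since the LSTM is bidirectional
    assert encoder.get_output_dim() == 64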
diff --git a/combo/modules/text_field_embedders/__init__.py b/combo/modules/text_field_embedders/__init__.py
new file mode 100644
index 0000000..7051557
--- /dev/null
+++ b/combo/modules/text_field_embedders/__init__.py
@@ -0,0 +1,2 @@
+from .text_field_embedder import TextFieldEmbedder
+from .basic_text_field_embedder import BasicTextFieldEmbedder
diff --git a/combo/modules/text_field_embedders/basic_text_field_embedder.py b/combo/modules/text_field_embedders/basic_text_field_embedder.py
new file mode 100644
index 0000000..56a5cbe
--- /dev/null
+++ b/combo/modules/text_field_embedders/basic_text_field_embedder.py
@@ -0,0 +1,117 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/text_field_embedders/basic_text_field_embedder.py
+"""
+from typing import Dict
+import inspect
+
+import torch
+
+from combo.common.params import Params
+from combo.config import Registry
+from combo.data.fields.text_field import TextFieldTensors
+from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
+from combo.modules.time_distributed import TimeDistributed
+from combo.modules.token_embedders import EmptyEmbedder
+from combo.modules.token_embedders.token_embedder import TokenEmbedder
+from combo.utils import ConfigurationError
+
+
+@Registry.register(TextFieldEmbedder, "basic")
+class BasicTextFieldEmbedder(TextFieldEmbedder):
+    """
+    This is a `TextFieldEmbedder` that wraps a collection of
+    [`TokenEmbedder`](../token_embedders/token_embedder.md) objects.  Each
+    `TokenEmbedder` embeds or encodes the representation output from one
+    [`allennlp.data.TokenIndexer`](../../data/token_indexers/token_indexer.md). As the data produced by a
+    [`allennlp.data.fields.TextField`](../../data/fields/text_field.md) is a dictionary mapping names to these
+    representations, we take `TokenEmbedders` with corresponding names.  Each `TokenEmbedder`
+    embeds its input, and the result is concatenated in an arbitrary (but consistent) order.
+
+    Registered as a `TextFieldEmbedder` with name "basic", which is also the default.
+
+    # Parameters
+
+    token_embedders : `Dict[str, TokenEmbedder]`, required.
+        A dictionary mapping token embedder names to implementations.
+        These names should match the corresponding indexer used to generate
+        the tensor passed to the TokenEmbedder.
+    """
+
+    def __init__(self, token_embedders: Dict[str, TokenEmbedder]) -> None:
+        super().__init__()
+        # NOTE(mattg): I'd prefer to just use ModuleDict(token_embedders) here, but that changes
+        # weight locations in torch state dictionaries and invalidates all prior models, just for a
+        # cosmetic change in the code.
+        self._token_embedders = token_embedders
+        for key, embedder in token_embedders.items():
+            name = "token_embedder_%s" % key
+            if isinstance(embedder, Params):
+                embedder_params = dict(embedder)
+                embedder = Registry.resolve(
+                    TokenEmbedder, embedder_params.get("type", "basic")
+                ).from_parameters(embedder_params)
+            self.add_module(name, embedder)
+        self._ordered_embedder_keys = sorted(self._token_embedders.keys())
+
+    def get_output_dim(self) -> int:
+        output_dim = 0
+        for embedder in self._token_embedders.values():
+            output_dim += embedder.get_output_dim()
+        return output_dim
+
+    def forward(
+        self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs
+    ) -> torch.Tensor:
+        if sorted(self._token_embedders.keys()) != sorted(text_field_input.keys()):
+            message = "Mismatched token keys: %s and %s" % (
+                str(self._token_embedders.keys()),
+                str(text_field_input.keys()),
+            )
+            embedder_keys = set(self._token_embedders.keys())
+            input_keys = set(text_field_input.keys())
+            if embedder_keys > input_keys and all(
+                isinstance(embedder, EmptyEmbedder)
+                for name, embedder in self._token_embedders.items()
+                if name in embedder_keys - input_keys
+            ):
+                # Allow extra embedders that are only in the token embedders (but not input) and are empty to pass
+                # config check
+                pass
+            else:
+                raise ConfigurationError(message)
+
+        embedded_representations = []
+        for key in self._ordered_embedder_keys:
+            # Note: need to use getattr here so that the pytorch voodoo
+            # with submodules works with multiple GPUs.
+            embedder = getattr(self, "token_embedder_{}".format(key))
+            if isinstance(embedder, EmptyEmbedder):
+                # Skip empty embedders
+                continue
+            forward_params = inspect.signature(embedder.forward).parameters
+            forward_params_values = {}
+            missing_tensor_args = set()
+            for param in forward_params.keys():
+                if param in kwargs:
+                    forward_params_values[param] = kwargs[param]
+                else:
+                    missing_tensor_args.add(param)
+
+            for _ in range(num_wrapping_dims):
+                embedder = TimeDistributed(embedder)
+
+            tensors: Dict[str, torch.Tensor] = text_field_input[key]
+            if len(tensors) == 1 and len(missing_tensor_args) == 1:
+                # If there's only one tensor argument to the embedder, and we just have one tensor to
+                # embed, we can just pass in that tensor, without requiring a name match.
+                token_vectors = embedder(list(tensors.values())[0], **forward_params_values)
+            else:
+                # If there are multiple tensor arguments, we have to require matching names from the
+                # TokenIndexer.  I don't think there's an easy way around that.
+                token_vectors = embedder(**tensors, **forward_params_values)
+            if token_vectors is not None:
+                # To handle some very rare use cases, we allow the return value of the embedder to
+                # be None; we just skip it in that case.
+                embedded_representations.append(token_vectors)
+        return torch.cat(embedded_representations, dim=-1)
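A hedged sketch of the key-matching contract: the indexer names in the `TextField` output must match the `token_embedders` keys. The "token" key, `my_token_embedder` and `text_field_tensors` are placeholders; a real setup would use one of combo's concrete `TokenEmbedder` classes:

    from combo.modules.text_field_embedders import BasicTextFieldEmbedder

    # my_token_embedder: any configured TokenEmbedder instance (placeholder).
    embedder = BasicTextFieldEmbedder(token_embedders={"token": my_token_embedder})
    # text_field_tensors: {"token": {"tokens": LongTensor of shape (batch, seq_len)}}
    embedded = embedder(text_field_tensors)  # (batch, seq_len, embedder.get_output_dim())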
diff --git a/combo/modules/text_field_embedders/text_field_embedder.py b/combo/modules/text_field_embedders/text_field_embedder.py
new file mode 100644
index 0000000..f44c45f
--- /dev/null
+++ b/combo/modules/text_field_embedders/text_field_embedder.py
@@ -0,0 +1,55 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/text_field_embedders/text_field_embedder.py
+"""
+
+import torch
+
+from combo.config import FromParameters
+from combo.data.fields.text_field import TextFieldTensors
+
+
+class TextFieldEmbedder(torch.nn.Module, FromParameters):
+    """
+    A `TextFieldEmbedder` is a `Module` that takes as input the
+    [`DataArray`](../../data/fields/text_field.md) produced by a [`TextField`](../../data/fields/text_field.md) and
+    returns as output an embedded representation of the tokens in that field.
+
+    The `DataArrays` produced by `TextFields` are _dictionaries_ with named representations, like
+    "words" and "characters".  When you create a `TextField`, you pass in a dictionary of
+    [`TokenIndexer`](../../data/token_indexers/token_indexer.md) objects, telling the field how exactly the
+    tokens in the field should be represented.  This class changes the type signature of `Module.forward`,
+    restricting `TextFieldEmbedders` to take inputs corresponding to a single `TextField`, which is
+    a dictionary of tensors with the same names as were passed to the `TextField`.
+
+    We also add a method to the basic `Module` API: `get_output_dim()`.  You might need this
+    if you want to construct a `Linear` layer using the output of this embedder, for instance.
+    """
+
+    default_implementation = "basic"
+
+    def forward(
+        self, text_field_input: TextFieldTensors, num_wrapping_dims: int = 0, **kwargs
+    ) -> torch.Tensor:
+        """
+        # Parameters
+
+        text_field_input : `TextFieldTensors`
+            A dictionary that was the output of a call to `TextField.as_tensor`.  Each tensor in
+            here is assumed to have a shape roughly similar to `(batch_size, sequence_length)`
+            (perhaps with an extra trailing dimension for the characters in each token).
+        num_wrapping_dims : `int`, optional (default=`0`)
+            If you have a `ListField[TextField]` that created the `text_field_input`, you'll
+            end up with tensors of shape `(batch_size, wrapping_dim1, wrapping_dim2, ...,
+            sequence_length)`.  This parameter tells us how many wrapping dimensions there are, so
+            that we can correctly `TimeDistribute` the embedding of each named representation.
+        """
+        raise NotImplementedError
+
+    def get_output_dim(self) -> int:
+        """
+        Returns the dimension of the vector representing each token in the output of this
+        `TextFieldEmbedder`.  This is _not_ the shape of the returned tensor, but the last element
+        of that shape.
+        """
+        raise NotImplementedError
diff --git a/combo/modules/time_distributed.py b/combo/modules/time_distributed.py
new file mode 100644
index 0000000..e7975d7
--- /dev/null
+++ b/combo/modules/time_distributed.py
@@ -0,0 +1,73 @@
+"""
+Adapted from AllenNLP
+"""
+from typing import List
+
+import torch
+from overrides import overrides
+
+from combo.config import Registry, FromParameters
+from combo.modules.module import Module
+
+
+@Registry.register(Module, 'time_distributed')
+class TimeDistributed(Module, FromParameters):
+    """
+    Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes
+    inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be
+    `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back.
+
+    Note that while the above gives shapes with `batch_size` first, this `Module` also works if
+    `batch_size` is second - we always just combine the first two dimensions, then split them.
+
+    Keyword arguments are reshaped as well, except for those that are not tensors or whose
+    names are listed in the optional `pass_through` iterable.
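+
+    # Example
+
+    A minimal sketch (illustrative only), wrapping a plain `torch.nn.Linear`:
+
+    ```python
+    time_distributed = TimeDistributed(torch.nn.Linear(10, 5))
+    inputs = torch.randn(2, 3, 10)      # (batch_size, time_steps, input_dim)
+    outputs = time_distributed(inputs)  # (2, 3, 5)
+    ```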
+    """
+
+    def __init__(self, module: torch.nn.Module):
+        super().__init__()
+        self._module = module
+
+    @overrides
+    def forward(self, *inputs, pass_through: List[str] = None, **kwargs):
+
+        pass_through = pass_through or []
+
+        reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs]
+
+        # Need some input to then get the batch_size and time_steps.
+        some_input = None
+        if inputs:
+            some_input = inputs[-1]
+
+        reshaped_kwargs = {}
+        for key, value in kwargs.items():
+            if isinstance(value, torch.Tensor) and key not in pass_through:
+                if some_input is None:
+                    some_input = value
+
+                value = self._reshape_tensor(value)
+
+            reshaped_kwargs[key] = value
+
+        reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs)
+
+        if some_input is None:
+            raise RuntimeError("No input tensor to time-distribute")
+
+        # Now get the output back into the right shape.
+        # (batch_size, time_steps, **output_size)
+        new_size = some_input.size()[:2] + reshaped_outputs.size()[1:]
+        outputs = reshaped_outputs.contiguous().view(new_size)
+
+        return outputs
+
+    @staticmethod
+    def _reshape_tensor(input_tensor):
+        input_size = input_tensor.size()
+        if len(input_size) <= 2:
+            raise RuntimeError(f"No dimension to distribute: {input_size}")
+        # Squash batch_size and time_steps into a single axis; result has shape
+        # (batch_size * time_steps, **input_size).
+        squashed_shape = [-1] + list(input_size[2:])
+        return input_tensor.contiguous().view(*squashed_shape)
diff --git a/combo/modules/token_embedders/__init__.py b/combo/modules/token_embedders/__init__.py
new file mode 100644
index 0000000..75d4b18
--- /dev/null
+++ b/combo/modules/token_embedders/__init__.py
@@ -0,0 +1,7 @@
+from .token_embedder import *
+from .empty_embedder import EmptyEmbedder
+from .character_token_embedder import CharacterBasedWordEmbedder
+from .pretrained_transformer_embedder import PretrainedTransformerEmbedder
+from .pretrained_transformer_mismatched_embedder import PretrainedTransformerMismatchedEmbedder
+from .transformers_words_embeddings import TransformersWordEmbedder
+from .projected_words_embedder import ProjectedWordEmbedder
diff --git a/combo/modules/token_embedders/character_token_embedder.py b/combo/modules/token_embedders/character_token_embedder.py
new file mode 100644
index 0000000..b066cea
--- /dev/null
+++ b/combo/modules/token_embedders/character_token_embedder.py
@@ -0,0 +1,53 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+https://gitlab.clarin-pl.eu/syntactic-tools/combo/-/blob/master/combo/models/embeddings.py
+"""
+from overrides import overrides
+
+from combo.config import Registry
+from combo.data import Vocabulary
+from combo.modules.dilated_cnn import DilatedCnnEncoder
+from combo.modules.token_embedders import TokenEmbedder
+
+"""Embeddings."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from combo import modules
+
+
+# @token_embedders.TokenEmbedder.register("char_embeddings")
+@Registry.register(TokenEmbedder, "char_embeddings_from_config")
+class CharacterBasedWordEmbedder(TokenEmbedder):
+    """Character-based word embeddings."""
+
+    def __init__(self,
+                 embedding_dim: int,
+                 vocabulary: Vocabulary,
+                 dilated_cnn_encoder: DilatedCnnEncoder,
+                 vocab_namespace: str = "token_characters"):
+        assert vocab_namespace in vocabulary.get_namespaces()
+        super().__init__()
+        self.char_embed = nn.Embedding(
+            num_embeddings=vocabulary.get_vocab_size(vocab_namespace),
+            embedding_dim=embedding_dim,
+        )
+        self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder)
+        self.output_dim = embedding_dim
+
+    def forward(self,
+                x: torch.Tensor,
+                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
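+        # Shape sketch (an assumption about the expected input, not enforced here): x holds
+        # character ids of shape (batch_size, sentence_length, word_length). After char_embed
+        # it becomes (batch_size, sentence_length, word_length, embedding_dim); the transpose
+        # moves characters to the last axis so the dilated CNN convolves over them, and the
+        # final max over that axis yields (batch_size, sentence_length, embedding_dim),
+        # assuming the CNN's last layer has embedding_dim output channels.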
+        if char_mask is None:
+            char_mask = x.new_ones(x.size())
+
+        x = self.char_embed(x)
+        x = x * char_mask.unsqueeze(-1).float()
+        x = self.dilated_cnn_encoder(x.transpose(2, 3))
+        return torch.max(x, dim=-1)[0]
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
diff --git a/combo/modules/token_embedders/embedding.py b/combo/modules/token_embedders/embedding.py
new file mode 100644
index 0000000..7a628e1
--- /dev/null
+++ b/combo/modules/token_embedders/embedding.py
@@ -0,0 +1,675 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/token_embedder.py
+"""
+import io
+import itertools
+import logging
+import re
+import tarfile
+import warnings
+import zipfile
+from typing import Any, cast, Iterator, NamedTuple, Optional, Sequence, Tuple, BinaryIO
+
+import numpy
+import torch
+
+from torch.nn.functional import embedding
+
+from combo.common import Tqdm
+from combo.config import Registry
+from combo.utils import ConfigurationError
+from combo.utils.file_utils import cached_path, get_file_extension
+from cached_path import is_url_or_existing_file
+from combo.data.vocabulary import Vocabulary
+from combo.modules.time_distributed import TimeDistributed
+from combo.modules.token_embedders.token_embedder import TokenEmbedder
+from combo.nn import util
+
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    import h5py
+
+logger = logging.getLogger(__name__)
+
+
+@Registry.register(TokenEmbedder, "embedding")
+class Embedding(TokenEmbedder):
+    """
+    A more featureful embedding module than the default in Pytorch.  Adds the ability to:
+
+        1. embed higher-order inputs
+        2. pre-specify the weight matrix
+        3. use a non-trainable embedding
+        4. project the resultant embeddings to some other dimension (which only makes sense with
+           non-trainable embeddings).
+
+    Note that if you are using our data API and are trying to embed a
+    [`TextField`](../../data/fields/text_field.md), you should use a
+    [`TextFieldEmbedder`](../text_field_embedders/text_field_embedder.md) instead of using this directly.
+
+    Registered as a `TokenEmbedder` with name "embedding".
+
+    # Parameters
+
+    num_embeddings : `int`
+        Size of the dictionary of embeddings (vocabulary size).
+    embedding_dim : `int`
+        The size of each embedding vector.
+    projection_dim : `int`, optional (default=`None`)
+        If given, we add a projection layer after the embedding layer.  This really only makes
+        sense if `trainable` is `False`.
+    weight : `torch.FloatTensor`, optional (default=`None`)
+        A pre-initialised weight matrix for the embedding lookup, allowing the use of
+        pretrained vectors.
+    padding_index : `int`, optional (default=`None`)
+        If given, pads the output with zeros whenever it encounters the index.
+    trainable : `bool`, optional (default=`True`)
+        Whether or not to optimize the embedding parameters.
+    max_norm : `float`, optional (default=`None`)
+        If given, embeddings are renormalized to always have a norm less than this value.
+    norm_type : `float`, optional (default=`2`)
+        The p of the p-norm to compute for the `max_norm` option.
+    scale_grad_by_freq : `bool`, optional (default=`False`)
+        If given, this will scale gradients by the frequency of the words in the mini-batch.
+    sparse : `bool`, optional (default=`False`)
+        Whether or not the Pytorch backend should use a sparse representation of the embedding weight.
+    vocab_namespace : `str`, optional (default=`"tokens"`)
+        In the case of fine-tuning or transfer learning, the model's embedding matrix needs to
+        be extended according to the size of the extended vocabulary. To know how much to
+        extend the embedding matrix, it is necessary to know which vocab_namespace was used to
+        construct it during the original training. We store the vocab_namespace used during
+        the original training as an attribute, so that it can be retrieved during fine-tuning.
+    pretrained_file : `str`, optional (default=`None`)
+        Path to a file of word vectors to initialize the embedding matrix. It can be the
+        path to a local file or a URL of a (cached) remote file. Two formats are supported:
+            * hdf5 file - containing an embedding matrix in the form of a torch.Tensor;
+            * text file - a utf-8 encoded text file with space-separated fields.
+    vocabulary : `Vocabulary`, optional (default = `None`)
+        Used to construct an embedding from a pretrained file.
+
+        In a typical AllenNLP configuration file, this parameter does not get an entry under
+        "embedding"; it is specified as a top-level parameter and then passed in to this module
+        separately.
+
+    # Returns
+
+    An Embedding module.
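+
+    # Example
+
+    A minimal construction sketch (all sizes are illustrative only):
+
+    ```python
+    embedding = Embedding(embedding_dim=50, num_embeddings=1000, padding_index=0)
+    token_ids = torch.randint(0, 1000, (2, 7))  # (batch_size, sequence_length)
+    vectors = embedding(token_ids)              # (2, 7, 50)
+    assert embedding.get_output_dim() == 50
+    ```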
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_embeddings: int = None,
+        projection_dim: int = None,
+        weight: torch.FloatTensor = None,
+        padding_index: int = None,
+        trainable: bool = True,
+        max_norm: float = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        vocab_namespace: str = "tokens",
+        pretrained_file: str = None,
+        vocabulary: Vocabulary = None,
+    ) -> None:
+        super().__init__()
+
+        if num_embeddings is None and vocabulary is None:
+            raise ConfigurationError(
+                "Embedding must be constructed with either num_embeddings or a vocabulary."
+            )
+
+        _vocab_namespace: Optional[str] = vocab_namespace
+        if num_embeddings is None:
+            num_embeddings = vocabulary.get_vocab_size(_vocab_namespace)  # type: ignore
+        else:
+            # If num_embeddings is present, set default namespace to None so that extend_vocab
+            # call doesn't misinterpret that some namespace was originally used.
+            _vocab_namespace = None  # type: ignore
+
+        self.num_embeddings = num_embeddings
+        self.padding_index = padding_index
+        self.max_norm = max_norm
+        self.norm_type = norm_type
+        self.scale_grad_by_freq = scale_grad_by_freq
+        self.sparse = sparse
+        self._vocab_namespace = _vocab_namespace
+        self._pretrained_file = pretrained_file
+
+        self.output_dim = projection_dim or embedding_dim
+
+        if weight is not None and pretrained_file:
+            raise ConfigurationError(
+                "Embedding was constructed with both a weight and a pretrained file."
+            )
+
+        elif pretrained_file is not None:
+
+            if vocabulary is None:
+                raise ConfigurationError(
+                    "To construct an Embedding from a pretrained file, you must also pass a vocabulary."
+                )
+
+            # If we're loading a saved model, we don't want to actually read a pre-trained
+            # embedding file - the embeddings will just be in our saved weights, and we might not
+            # have the original embedding file anymore, anyway.
+
+            # TODO: having to pass tokens here is SUPER gross, but otherwise this breaks the
+            # extend_vocab method, which relies on the value of vocab_namespace being None
+            # to infer at what stage the embedding has been constructed. Phew.
+            weight = _read_pretrained_embeddings_file(
+                pretrained_file, embedding_dim, vocabulary, vocab_namespace
+            )
+            self.weight = torch.nn.Parameter(weight, requires_grad=trainable)
+
+        elif weight is not None:
+            self.weight = torch.nn.Parameter(weight, requires_grad=trainable)
+
+        else:
+            weight = torch.FloatTensor(num_embeddings, embedding_dim)
+            self.weight = torch.nn.Parameter(weight, requires_grad=trainable)
+            torch.nn.init.xavier_uniform_(self.weight)
+
+        # Whatever way we have constructed the embedding, it should be consistent with
+        # num_embeddings and embedding_dim.
+        if self.weight.size() != (num_embeddings, embedding_dim):
+            raise ConfigurationError(
+                "A weight matrix was passed with contradictory embedding shapes."
+            )
+
+        if self.padding_index is not None:
+            self.weight.data[self.padding_index].fill_(0)
+
+        if projection_dim:
+            self._projection = torch.nn.Linear(embedding_dim, projection_dim)
+        else:
+            self._projection = None
+
+    def get_output_dim(self) -> int:
+        return self.output_dim
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        # tokens may have extra dimensions (batch_size, d1, ..., dn, sequence_length),
+        # but embedding expects (batch_size, sequence_length), so pass tokens to
+        # util.combine_initial_dims (which is a no-op if there are no extra dimensions).
+        # Remember the original size.
+        original_size = tokens.size()
+        tokens = util.combine_initial_dims(tokens)
+
+        embedded = embedding(
+            tokens,
+            self.weight,
+            padding_idx=self.padding_index,
+            max_norm=self.max_norm,
+            norm_type=self.norm_type,
+            scale_grad_by_freq=self.scale_grad_by_freq,
+            sparse=self.sparse,
+        )
+
+        # Now (if necessary) add back in the extra dimensions.
+        embedded = util.uncombine_initial_dims(embedded, original_size)
+
+        if self._projection:
+            projection = self._projection
+            for _ in range(embedded.dim() - 2):
+                projection = TimeDistributed(projection)
+            embedded = projection(embedded)
+        return embedded
+
+    def extend_vocab(
+        self,
+        extended_vocab: Vocabulary,
+        vocab_namespace: str = None,
+        extension_pretrained_file: str = None,
+        model_path: str = None,
+    ):
+        """
+        Extends the embedding matrix according to the extended vocabulary.
+        If extension_pretrained_file is available, it will be used for initializing the new words
+        embeddings in the extended vocabulary; otherwise we will check if _pretrained_file attribute
+        is already available. If none is available, they will be initialized with xavier uniform.
+
+        # Parameters
+
+        extended_vocab : `Vocabulary`
+            Vocabulary extended from original vocabulary used to construct
+            this `Embedding`.
+        vocab_namespace : `str`, (optional, default=`None`)
+            In case you know what vocab_namespace should be used for extension, you
+            can pass it. If not passed, it will check if vocab_namespace used at the
+            time of `Embedding` construction is available. If so, this namespace
+            will be used or else extend_vocab will be a no-op.
+        extension_pretrained_file : `str`, (optional, default=`None`)
+            A file containing pretrained embeddings can be specified here. It can be
+            the path to a local file or an URL of a (cached) remote file. Check format
+            details in `from_params` of `Embedding` class.
+        model_path : `str`, (optional, default=`None`)
+            Path traversing the model attributes upto this embedding module.
+            Eg. "_text_field_embedder.token_embedder_tokens". This is only useful
+            to give a helpful error message when extend_vocab is implicitly called
+            by train or any other command.
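+
+        # Example
+
+        A minimal sketch (the namespace name is illustrative only):
+
+        ```python
+        embedding.extend_vocab(extended_vocab, vocab_namespace="tokens")
+        ```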
+        """
+        # Caveat: for allennlp v0.8.1 and below, vocab_namespace was not stored as an attribute,
+        # even though it is needed at the time of embedding vocabulary extension. So old archived
+        # models are currently not extendable.
+
+        vocab_namespace = vocab_namespace or self._vocab_namespace
+        if not vocab_namespace:
+            # It's not safe to default to "tokens" or any other namespace.
+            logger.info(
+                "Loading a model trained before embedding extension was implemented; "
+                "pass an explicit vocabulary namespace if you want to extend the vocabulary."
+            )
+            return
+
+        extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
+        if extended_num_embeddings == self.num_embeddings:
+            # It's already been extended. No need to initialize / read pretrained file in first place (no-op)
+            return
+
+        if extended_num_embeddings < self.num_embeddings:
+            raise ConfigurationError(
+                f"Size of namespace, {vocab_namespace} for extended_vocab is smaller than "
+                f"embedding. You likely passed incorrect vocabulary or namespace for extension."
+            )
+
+        # Case 1: user passed extension_pretrained_file and it's available.
+        if extension_pretrained_file and is_url_or_existing_file(extension_pretrained_file):
+            # Don't have to do anything here, this is the happy case.
+            pass
+        # Case 2: user passed extension_pretrained_file and it's not available
+        elif extension_pretrained_file:
+            raise ConfigurationError(
+                f"You passed pretrained embedding file {extension_pretrained_file} "
+                f"for model_path {model_path} but it's not available."
+            )
+        # Case 3: user didn't pass extension_pretrained_file, but pretrained_file attribute was
+        # saved during training and is available.
+        elif is_url_or_existing_file(self._pretrained_file):
+            extension_pretrained_file = self._pretrained_file
+        # Case 4: no file is available, hope that pretrained embeddings weren't used in the first place and warn
+        elif self._pretrained_file is not None:
+            # Warn here instead of raising an exception, to allow fine-tuning even without the original pretrained_file
+            logger.warning(
+                f"Embedding at model_path, {model_path} cannot locate the pretrained_file. "
+                f"Originally pretrained_file was at '{self._pretrained_file}'."
+            )
+        else:
+            # When loading a model from an archive there is no way to tell whether a pretrained file
+            # was used during the original training, so we just log an informational message here.
+            logger.info(
+                "If you are fine-tuning and want to use a pretrained_file for "
+                "embedding extension, please pass the mapping by --embedding-sources argument."
+            )
+
+        embedding_dim = self.weight.data.shape[-1]
+        if not extension_pretrained_file:
+            extra_num_embeddings = extended_num_embeddings - self.num_embeddings
+            extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
+            torch.nn.init.xavier_uniform_(extra_weight)
+        else:
+            # It's easiest to just reload the embeddings for the entire vocabulary,
+            # then only keep the ones we need.
+            whole_weight = _read_pretrained_embeddings_file(
+                extension_pretrained_file, embedding_dim, extended_vocab, vocab_namespace
+            )
+            extra_weight = whole_weight[self.num_embeddings :, :]
+
+        device = self.weight.data.device
+        extended_weight = torch.cat([self.weight.data, extra_weight.to(device)], dim=0)
+        self.weight = torch.nn.Parameter(extended_weight, requires_grad=self.weight.requires_grad)
+        self.num_embeddings = extended_num_embeddings
+
+
+def _read_pretrained_embeddings_file(
+    file_uri: str, embedding_dim: int, vocab: Vocabulary, namespace: str = "tokens"
+) -> torch.FloatTensor:
+    """
+    Returns an embedding matrix for the given vocabulary using the pretrained embeddings
+    contained in the given file. Embeddings for tokens not found in the pretrained embedding file
+    are randomly initialized using a normal distribution with mean and standard deviation equal to
+    those of the pretrained embeddings.
+
+    We support two file formats:
+
+        * text format - utf-8 encoded text file with space separated fields: [word] [dim 1] [dim 2] ...
+          The text file may optionally be compressed, and may even reside in an archive with
+          multiple files. If the file resides in an archive with other files, then `file_uri` must
+          be a URI of the form "(archive_uri)#file_path_inside_the_archive".
+
+        * hdf5 format - hdf5 file containing an embedding matrix in the form of a torch.Tensor.
+
+    If the filename ends with '.hdf5' or '.h5' then we load from hdf5, otherwise we assume
+    text format.
+
+    # Parameters
+
+    file_uri : `str`, required.
+        It can be:
+
+        * a file system path or a URL of an optionally compressed text file or a zip/tar archive
+          containing a single file.
+
+        * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file
+          is contained in a multi-file archive.
+
+    vocab : `Vocabulary`, required.
+        A Vocabulary object.
+    namespace : `str`, (optional, default=`"tokens"`)
+        The namespace of the vocabulary to find pretrained embeddings for.
+
+    # Returns
+
+    A weight matrix with embeddings initialized from the read file.  The matrix has shape
+    `(vocabulary.get_vocab_size(namespace), embedding_dim)`, where the indices of words appearing in
+    the pretrained embedding file are initialized to the pretrained embedding value.
+    """
+    file_ext = get_file_extension(file_uri)
+    if file_ext in [".h5", ".hdf5"]:
+        return _read_embeddings_from_hdf5(file_uri, embedding_dim, vocab, namespace)
+
+    return _read_embeddings_from_text_file(file_uri, embedding_dim, vocab, namespace)
+
+
+def _read_embeddings_from_text_file(
+    file_uri: str, embedding_dim: int, vocab: Vocabulary, namespace: str = "tokens"
+) -> torch.FloatTensor:
+    """
+    Read pre-trained word vectors from an optionally compressed text file, possibly contained
+    inside an archive with multiple files. The text file is assumed to be utf-8 encoded with
+    space-separated fields: [word] [dim 1] [dim 2] ...
+
+    Lines whose number of fields does not match `embedding_dim` + 1 trigger a warning and are skipped.
+
+    The remainder of the docstring is identical to `_read_pretrained_embeddings_file`.
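+
+    For example, a line for a 3-dimensional embedding in the text format might look like
+    `the 0.418 0.24968 -0.41242` (values shown for illustration only).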
+    """
+    tokens_to_keep = set(vocab.get_index_to_token_vocabulary(namespace).values())
+    vocab_size = vocab.get_vocab_size(namespace)
+    embeddings = {}
+
+    # First we read the embeddings from the file, only keeping vectors for the words we need.
+    logger.info("Reading pretrained embeddings from file")
+
+    with EmbeddingsTextFile(file_uri) as embeddings_file:
+        for line in Tqdm.tqdm(embeddings_file):
+            token = line.split(" ", 1)[0]
+            if token in tokens_to_keep:
+                fields = line.rstrip().split(" ")
+                if len(fields) - 1 != embedding_dim:
+                    # Sometimes there are funny unicode parsing problems that lead to different
+                    # field lengths (e.g., a word with a unicode space character that splits
+                    # into more than one column).  We skip those lines.  Note that if you have
+                    # some kind of long header, this could result in all of your lines getting
+                    # skipped.  It's hard to check for that here; you just have to look in the
+                    # embedding_misses_file and at the model summary to make sure things look
+                    # like they are supposed to.
+                    logger.warning(
+                        "Found line with wrong number of dimensions (expected: %d; actual: %d): %s",
+                        embedding_dim,
+                        len(fields) - 1,
+                        line,
+                    )
+                    continue
+
+                vector = numpy.asarray(fields[1:], dtype="float32")
+                embeddings[token] = vector
+
+    if not embeddings:
+        raise ConfigurationError(
+            "No embeddings of correct dimension found; you probably "
+            "misspecified your embedding_dim parameter, or didn't "
+            "pre-populate your Vocabulary"
+        )
+
+    all_embeddings = numpy.asarray(list(embeddings.values()))
+    embeddings_mean = float(numpy.mean(all_embeddings))
+    embeddings_std = float(numpy.std(all_embeddings))
+    # Now we initialize the weight matrix for an embedding layer, starting with random vectors,
+    # then filling in the word vectors we just read.
+    logger.info("Initializing pre-trained embedding layer")
+    embedding_matrix = torch.FloatTensor(vocab_size, embedding_dim).normal_(
+        embeddings_mean, embeddings_std
+    )
+    num_tokens_found = 0
+    index_to_token = vocab.get_index_to_token_vocabulary(namespace)
+    for i in range(vocab_size):
+        token = index_to_token[i]
+
+        # If we don't have a pre-trained vector for this word, we'll just leave this row alone,
+        # so the word has a random initialization.
+        if token in embeddings:
+            embedding_matrix[i] = torch.FloatTensor(embeddings[token])
+            num_tokens_found += 1
+        else:
+            logger.debug(
+                "Token %s was not found in the embedding file. Initialising randomly.", token
+            )
+
+    logger.info(
+        "Pretrained embeddings were found for %d out of %d tokens", num_tokens_found, vocab_size
+    )
+
+    return embedding_matrix
+
+
+def _read_embeddings_from_hdf5(
+    embeddings_filename: str, embedding_dim: int, vocabulary: Vocabulary, namespace: str = "tokens"
+) -> torch.FloatTensor:
+    """
+    Reads from a hdf5 formatted file. The embedding matrix is assumed to
+    be keyed by 'embedding' and of size `(num_tokens, embedding_dim)`.
+    """
+    with h5py.File(embeddings_filename, "r") as fin:
+        embeddings = fin["embedding"][...]
+
+    if list(embeddings.shape) != [vocabulary.get_vocab_size(namespace), embedding_dim]:
+        raise ConfigurationError(
+            "Read shape {0} embeddings from the file, but expected {1}".format(
+                list(embeddings.shape), [vocabulary.get_vocab_size(namespace), embedding_dim]
+            )
+        )
+
+    return torch.FloatTensor(embeddings)
+
+
+def format_embeddings_file_uri(
+    main_file_path_or_url: str, path_inside_archive: Optional[str] = None
+) -> str:
+    if path_inside_archive:
+        return "({})#{}".format(main_file_path_or_url, path_inside_archive)
+    return main_file_path_or_url
+
+
+class EmbeddingsFileURI(NamedTuple):
+    main_file_uri: str
+    path_inside_archive: Optional[str] = None
+
+
+def parse_embeddings_file_uri(uri: str) -> "EmbeddingsFileURI":
+    match = re.fullmatch(r"\((.*)\)#(.*)", uri)
+    if match:
+        fields = cast(Tuple[str, str], match.groups())
+        return EmbeddingsFileURI(*fields)
+    else:
+        return EmbeddingsFileURI(uri, None)
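+
+
+# Illustrative examples (the file names are made up):
+#   format_embeddings_file_uri("embeddings.zip", "glove.txt")
+#       -> "(embeddings.zip)#glove.txt"
+#   parse_embeddings_file_uri("(embeddings.zip)#glove.txt")
+#       -> EmbeddingsFileURI(main_file_uri="embeddings.zip", path_inside_archive="glove.txt")
+#   parse_embeddings_file_uri("glove.txt")
+#       -> EmbeddingsFileURI(main_file_uri="glove.txt", path_inside_archive=None)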
+
+
+class EmbeddingsTextFile(Iterator[str]):
+    """
+    Utility class for opening embeddings text files. Handles various compression formats,
+    as well as context management.
+
+    # Parameters
+
+    file_uri : `str`
+        It can be:
+
+        * a file system path or a URL of an optionally compressed text file or a zip/tar archive
+          containing a single file.
+        * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file
+          is contained in a multi-file archive.
+
+    encoding : `str`, optional (default=`"utf-8"`)
+        The encoding used to decode the (possibly decompressed) text file.
+    cache_dir : `str`, optional (default=`None`)
+        The directory passed to `cached_path` when caching remote files.
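+
+    # Example
+
+    A minimal usage sketch (the file name is illustrative only):
+
+    ```python
+    with EmbeddingsTextFile("glove.6B.100d.txt") as embeddings_file:
+        for line in embeddings_file:
+            token = line.split(" ", 1)[0]
+    ```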
+    """
+
+    DEFAULT_ENCODING = "utf-8"
+
+    def __init__(
+        self, file_uri: str, encoding: str = DEFAULT_ENCODING, cache_dir: str = None
+    ) -> None:
+
+        self.uri = file_uri
+        self._encoding = encoding
+        self._cache_dir = cache_dir
+        self._archive_handle: Any = None  # only if the file is inside an archive
+
+        main_file_uri, path_inside_archive = parse_embeddings_file_uri(file_uri)
+        main_file_local_path = cached_path(main_file_uri, cache_dir=cache_dir)
+
+        if zipfile.is_zipfile(main_file_local_path):  # ZIP archive
+            self._open_inside_zip(main_file_uri, path_inside_archive)
+
+        elif tarfile.is_tarfile(main_file_local_path):  # TAR archive
+            self._open_inside_tar(main_file_uri, path_inside_archive)
+
+        else:  # all the other supported formats, including uncompressed files
+            if path_inside_archive:
+                raise ValueError("Unsupported archive format: %s" % main_file_uri)
+
+            # All the python packages for compressed files share the same interface of io.open
+            extension = get_file_extension(main_file_uri)
+
+            # Some systems don't have support for all of these libraries, so we import them only
+            # when necessary.
+            package = None
+            if extension in [".txt", ".vec"]:
+                package = io
+            elif extension == ".gz":
+                import gzip
+
+                package = gzip
+            elif extension == ".bz2":
+                import bz2
+
+                package = bz2
+            elif extension == ".xz":
+                import lzma
+
+                package = lzma
+
+            if package is None:
+                logger.warning(
+                    'The embeddings file has an unknown file extension "%s". '
+                    "We will assume the file is an (uncompressed) text file",
+                    extension,
+                )
+                package = io
+
+            self._handle = package.open(  # type: ignore
+                main_file_local_path, "rt", encoding=encoding
+            )
+
+        # To use this with tqdm we'd like to know the number of tokens. It's possible that the
+        # first line of the embeddings file contains this: if it does, we want to start iteration
+        # from the 2nd line, otherwise we want to start from the 1st.
+        # Unfortunately, once we read the first line, we cannot move back the file iterator
+        # because the underlying file may be "not seekable"; we use itertools.chain instead.
+        first_line = next(self._handle)  # this moves the iterator forward
+        self.num_tokens = EmbeddingsTextFile._get_num_tokens_from_first_line(first_line)
+        if self.num_tokens:
+            # the first line is a header line: start iterating from the 2nd line
+            self._iterator = self._handle
+        else:
+            # the first line is not a header line: start iterating from the 1st line
+            self._iterator = itertools.chain([first_line], self._handle)
+
+    def _open_inside_zip(self, archive_path: str, member_path: Optional[str] = None) -> None:
+        cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir)
+        archive = zipfile.ZipFile(cached_archive_path, "r")
+        if member_path is None:
+            members_list = archive.namelist()
+            member_path = self._get_the_only_file_in_the_archive(members_list, archive_path)
+        member_path = cast(str, member_path)
+        member_file = cast(BinaryIO, archive.open(member_path, "r"))
+        self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
+        self._archive_handle = archive
+
+    def _open_inside_tar(self, archive_path: str, member_path: Optional[str] = None) -> None:
+        cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir)
+        archive = tarfile.open(cached_archive_path, "r")
+        if member_path is None:
+            members_list = archive.getnames()
+            member_path = self._get_the_only_file_in_the_archive(members_list, archive_path)
+        member_path = cast(str, member_path)
+        member = archive.getmember(member_path)  # raises exception if not present
+        member_file = cast(BinaryIO, archive.extractfile(member))
+        self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
+        self._archive_handle = archive
+
+    def read(self) -> str:
+        return "".join(self._iterator)
+
+    def readline(self) -> str:
+        return next(self._iterator)
+
+    def close(self) -> None:
+        self._handle.close()
+        if self._archive_handle:
+            self._archive_handle.close()
+
+    def __enter__(self) -> "EmbeddingsTextFile":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self.close()
+
+    def __iter__(self) -> "EmbeddingsTextFile":
+        return self
+
+    def __next__(self) -> str:
+        return next(self._iterator)
+
+    def __len__(self) -> Optional[int]:
+        if self.num_tokens:
+            return self.num_tokens
+        raise AttributeError(
+            "an object of type EmbeddingsTextFile implements `__len__` only if the underlying "
+            "text file declares the number of tokens (i.e. the number of lines following)"
+            "in the first line. That is not the case of this particular instance."
+        )
+
+    @staticmethod
+    def _get_the_only_file_in_the_archive(members_list: Sequence[str], archive_path: str) -> str:
+        if len(members_list) > 1:
+            raise ValueError(
+                "The archive %s contains multiple files, so you must select "
+                "one of the files inside providing a uri of the type: %s."
+                % (
+                    archive_path,
+                    format_embeddings_file_uri("path_or_url_to_archive", "path_inside_archive"),
+                )
+            )
+        return members_list[0]
+
+    @staticmethod
+    def _get_num_tokens_from_first_line(line: str) -> Optional[int]:
+        """This function takes in input a string and if it contains 1 or 2 integers, it assumes the
+        largest one it the number of tokens. Returns None if the line doesn't match that pattern."""
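+        # For example, a header line such as "400000 300" yields 400000, while a regular
+        # embedding line like "the 0.1 0.2 ..." has more than two fields and yields None.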
+        fields = line.split(" ")
+        if 1 <= len(fields) <= 2:
+            try:
+                int_fields = [int(x) for x in fields]
+            except ValueError:
+                return None
+            else:
+                num_tokens = max(int_fields)
+                logger.info(
+                    "Recognized a header line in the embedding file with number of tokens: %d",
+                    num_tokens,
+                )
+                return num_tokens
+        return None
diff --git a/combo/modules/token_embedders/empty_embedder.py b/combo/modules/token_embedders/empty_embedder.py
new file mode 100644
index 0000000..1def0f8
--- /dev/null
+++ b/combo/modules/token_embedders/empty_embedder.py
@@ -0,0 +1,28 @@
+import torch
+
+from combo.config import Registry
+from combo.modules.token_embedders import TokenEmbedder
+
+
+@Registry.register(TokenEmbedder, "empty")
+class EmptyEmbedder(TokenEmbedder):
+    """
+    Assumes you want to completely ignore the output of a `TokenIndexer` for some reason, and does
+    not return anything when asked to embed it.
+
+    You should almost never need to use this; normally you would just not use a particular
+    `TokenIndexer`. It's only in very rare cases, like simplicity in data processing for language
+    modeling (where we use just one `TextField` to handle input embedding and computing target ids),
+    where you might want to use this.
+
+    Registered as a `TokenEmbedder` with name "empty".
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_output_dim(self):
+        return 0
+
+    def forward(self, *inputs, **kwargs) -> torch.Tensor:
+        return None
diff --git a/combo/modules/token_embedders/pretrained_transformer_embedder.py b/combo/modules/token_embedders/pretrained_transformer_embedder.py
new file mode 100644
index 0000000..d325490
--- /dev/null
+++ b/combo/modules/token_embedders/pretrained_transformer_embedder.py
@@ -0,0 +1,415 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
+"""
+
+import logging
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from overrides import overrides
+
+from combo.config import Registry
+from combo.data.tokenizers import PretrainedTransformerTokenizer
+from combo.modules.scalar_mix import ScalarMix
+from combo.modules.token_embedders.token_embedder import TokenEmbedder
+from combo.nn.util import batched_index_select
+from transformers import XLNetConfig
+
+logger = logging.getLogger(__name__)
+
+
+@Registry.register(TokenEmbedder, "pretrained_transformer")
+class PretrainedTransformerEmbedder(TokenEmbedder):
+    """
+    Uses a pretrained model from `transformers` as a `TokenEmbedder`.
+
+    Registered as a `TokenEmbedder` with name "pretrained_transformer".
+
+    # Parameters
+
+    model_name : `str`
+        The name of the `transformers` model to use. Should be the same as the corresponding
+        `PretrainedTransformerIndexer`.
+    max_length : `int`, optional (default = `None`)
+        If positive, folds input token IDs into multiple segments of this length, passes them
+        through the transformer model independently, and concatenates the final representations.
+        Should be set to the same value as the `max_length` option on the
+        `PretrainedTransformerIndexer`.
+    sub_module: `str`, optional (default = `None`)
+        The name of a submodule of the transformer to be used as the embedder. Some transformers, such
+        as BERT, naturally act as embedders. Other models, however, consist of an encoder and a decoder,
+        in which case we just want to use the encoder.
+    train_parameters: `bool`, optional (default = `True`)
+        If this is `True`, the transformer weights get updated during training. If this is `False`, the
+        transformer weights are not updated during training.
+    eval_mode: `bool`, optional (default = `False`)
+        If this is `True`, the model is always set to evaluation mode (e.g., the dropout is disabled and the
+        batch normalization layer statistics are not updated). If this is `False`, such dropout and batch
+        normalization layers are only set to evaluation mode when the model is evaluating on development
+        or test data.
+    last_layer_only: `bool`, optional (default = `True`)
+        When `True` (the default), only the final layer of the pretrained transformer is taken
+        for the embeddings. But if set to `False`, a scalar mix of all of the layers
+        is used.
+    override_weights_file: `Optional[str]`, optional (default = `None`)
+        If set, this specifies a file from which to load alternate weights that override the
+        weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
+        with `torch.save()`.
+    override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
+        If set, strip the given prefix from the state dict when loading it.
+    reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
+        If this is an integer, the last `reinit_modules` layers of the transformer will be
+        re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
+        be re-initialized. Note, because the module structure of the transformer `model_name` can
+        differ, we cannot guarantee that providing an integer or tuple of integers will work. If
+        this fails, you can instead provide a tuple of strings, which will be treated as regexes and
+        any module with a name matching the regex will be re-initialized. Re-initializing the last
+        few layers of a pretrained transformer can reduce the instability of fine-tuning on small
+        datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
+        if `load_weights` is `False` or `override_weights_file` is not `None`.
+    load_weights: `bool`, optional (default = `True`)
+        Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
+        it usually makes sense to set this to `False` (via the `overrides` parameter)
+        to avoid unnecessarily caching and loading the original pretrained weights,
+        since the archive will already contain all of the weights needed.
+    gradient_checkpointing: `bool`, optional (default = `None`)
+        Enable or disable gradient checkpointing.
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+        for `AutoModel.from_pretrained`.
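+
+    # Example
+
+    A minimal construction sketch (the model name is illustrative only and the weights are
+    downloaded from huggingface on first use):
+
+    ```python
+    embedder = PretrainedTransformerEmbedder("bert-base-cased", last_layer_only=True)
+    # embedder.get_output_dim() equals the transformer's hidden size (768 for a BERT-base model).
+    ```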
+    """  # noqa: E501
+
+    authorized_missing_keys = [r"position_ids$"]
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        max_length: int = None,
+        sub_module: str = None,
+        train_parameters: bool = True,
+        eval_mode: bool = False,
+        last_layer_only: bool = True,
+        override_weights_file: Optional[str] = None,
+        override_weights_strip_prefix: Optional[str] = None,
+        reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
+        load_weights: bool = True,
+        gradient_checkpointing: Optional[bool] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        transformer_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__()
+        from combo.common import cached_transformers
+
+        self.transformer_model = cached_transformers.get(
+            model_name,
+            True,
+            override_weights_file=override_weights_file,
+            override_weights_strip_prefix=override_weights_strip_prefix,
+            reinit_modules=reinit_modules,
+            load_weights=load_weights,
+            **(transformer_kwargs or {}),
+        )
+
+        if gradient_checkpointing is not None:
+            self.transformer_model.config.update({"gradient_checkpointing": gradient_checkpointing})
+
+        self.config = self.transformer_model.config
+        if sub_module:
+            assert hasattr(self.transformer_model, sub_module)
+            self.transformer_model = getattr(self.transformer_model, sub_module)
+        self._max_length = max_length
+
+        # I'm not sure if this works for all models; open an issue on github if you find a case
+        # where it doesn't work.
+        self.output_dim = self.config.hidden_size
+
+        self._scalar_mix: Optional[ScalarMix] = None
+        if not last_layer_only:
+            self._scalar_mix = ScalarMix(self.config.num_hidden_layers)
+            self.config.output_hidden_states = True
+
+        tokenizer = PretrainedTransformerTokenizer(
+            model_name,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
+
+        try:
+            if self.transformer_model.get_input_embeddings().num_embeddings != len(
+                tokenizer.tokenizer
+            ):
+                self.transformer_model.resize_token_embeddings(len(tokenizer.tokenizer))
+        except NotImplementedError:
+            # Can't resize for transformers models that don't implement base_model.get_input_embeddings()
+            logger.warning(
+                "Could not resize the token embedding matrix of the transformer model. "
+                "This model does not support resizing."
+            )
+
+        self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens)
+        self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens
+
+        self.train_parameters = train_parameters
+        if not train_parameters:
+            for param in self.transformer_model.parameters():
+                param.requires_grad = False
+
+        self.eval_mode = eval_mode
+        if eval_mode:
+            self.transformer_model.eval()
+
+    def train(self, mode: bool = True):
+        self.training = mode
+        for name, module in self.named_children():
+            if self.eval_mode and name == "transformer_model":
+                module.eval()
+            else:
+                module.train(mode)
+        return self
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
+
+    def _number_of_token_type_embeddings(self):
+        if isinstance(self.config, XLNetConfig):
+            return 3  # XLNet has 3 type ids
+        elif hasattr(self.config, "type_vocab_size"):
+            return self.config.type_vocab_size
+        else:
+            return 0
+
+    def forward(
+        self,
+        token_ids: torch.LongTensor,
+        mask: torch.BoolTensor,
+        type_ids: Optional[torch.LongTensor] = None,
+        segment_concat_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:  # type: ignore
+        """
+        # Parameters
+
+        token_ids: `torch.LongTensor`
+            Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
+            num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the
+            middle, e.g. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic).
+        mask: `torch.BoolTensor`
+            Shape: [batch_size, num_wordpieces].
+        type_ids: `Optional[torch.LongTensor]`
+            Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
+        segment_concat_mask: `Optional[torch.BoolTensor]`
+            Shape: `[batch_size, num_segment_concat_wordpieces]`.
+
+        # Returns
+
+        `torch.Tensor`
+            Shape: `[batch_size, num_wordpieces, embedding_size]`.
+
+        """
+        # Some of the huggingface transformers don't support type ids at all and crash when you supply
+        # them. For others, you can supply a tensor of zeros, and if you don't, they act as if you did.
+        # There is no practical difference to the caller, so here we pretend that one case is the same
+        # as another case.
+        if type_ids is not None:
+            max_type_id = type_ids.max()
+            if max_type_id == 0:
+                type_ids = None
+            else:
+                if max_type_id >= self._number_of_token_type_embeddings():
+                    raise ValueError("Found type ids too large for the chosen transformer model.")
+                assert token_ids.shape == type_ids.shape
+
+        fold_long_sequences = self._max_length is not None and token_ids.size(1) > self._max_length
+        if fold_long_sequences:
+            batch_size, num_segment_concat_wordpieces = token_ids.size()
+            token_ids, segment_concat_mask, type_ids = self._fold_long_sequences(
+                token_ids, segment_concat_mask, type_ids
+            )
+
+        transformer_mask = segment_concat_mask if self._max_length is not None else mask
+        assert transformer_mask is not None
+        # Shape: [batch_size, num_wordpieces, embedding_size],
+        # or if self._max_length is not None:
+        # [batch_size * num_segments, self._max_length, embedding_size]
+
+        # We call this with kwargs because some of the huggingface models don't have the
+        # token_type_ids parameter and fail even when it's given as None.
+        # Also, as of transformers v2.5.1, they are taking FloatTensor masks.
+        parameters = {"input_ids": token_ids, "attention_mask": transformer_mask.float()}
+        if type_ids is not None:
+            parameters["token_type_ids"] = type_ids
+
+        transformer_output = self.transformer_model(**parameters)
+        if self._scalar_mix is not None:
+            # The hidden states will also include the embedding layer, which we don't
+            # include in the scalar mix. Hence the `[1:]` slicing.
+            hidden_states = transformer_output.hidden_states[1:]
+            embeddings = self._scalar_mix(hidden_states)
+        else:
+            embeddings = transformer_output.last_hidden_state
+
+        if fold_long_sequences:
+            embeddings = self._unfold_long_sequences(
+                embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces
+            )
+
+        return embeddings
+
+    def _fold_long_sequences(
+        self,
+        token_ids: torch.LongTensor,
+        mask: torch.BoolTensor,
+        type_ids: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.LongTensor, torch.LongTensor, Optional[torch.LongTensor]]:
+        """
+        We fold the 1D sequences (one per batch element) returned by `PretrainedTransformerIndexer`,
+        which are in reality multiple segments concatenated together, into 2D tensors, e.g.
+
+        [ [CLS] A B C [SEP] [CLS] D E [SEP] ]
+        -> [ [ [CLS] A B C [SEP] ], [ [CLS] D E [SEP] [PAD] ] ]
+        The [PAD] positions can be found in the returned `mask`.
+
+        # Parameters
+
+        token_ids: `torch.LongTensor`
+            Shape: `[batch_size, num_segment_concat_wordpieces]`.
+            num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the
+            middle, i.e. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic).
+        mask: `torch.BoolTensor`
+            Shape: `[batch_size, num_segment_concat_wordpieces]`.
+            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
+            in `forward()`.
+        type_ids: `Optional[torch.LongTensor]`
+            Shape: [batch_size, num_segment_concat_wordpieces].
+
+        # Returns:
+
+        token_ids: `torch.LongTensor`
+            Shape: [batch_size * num_segments, self._max_length].
+        mask: `torch.BoolTensor`
+            Shape: [batch_size * num_segments, self._max_length].
+        """
+        num_segment_concat_wordpieces = token_ids.size(1)
+        num_segments = math.ceil(num_segment_concat_wordpieces / self._max_length)  # type: ignore
+        padded_length = num_segments * self._max_length  # type: ignore
+        length_to_pad = padded_length - num_segment_concat_wordpieces
+
+        def fold(tensor):  # Shape: [batch_size, num_segment_concat_wordpieces]
+            # Shape: [batch_size, num_segments * self._max_length]
+            tensor = F.pad(tensor, [0, length_to_pad], value=0)
+            # Shape: [batch_size * num_segments, self._max_length]
+            return tensor.reshape(-1, self._max_length)
+
+        return fold(token_ids), fold(mask), fold(type_ids) if type_ids is not None else None
+
+    def _unfold_long_sequences(
+        self,
+        embeddings: torch.FloatTensor,
+        mask: torch.BoolTensor,
+        batch_size: int,
+        num_segment_concat_wordpieces: int,
+    ) -> torch.FloatTensor:
+        """
+        We take 2D segments of a long sequence and flatten them out to get the whole sequence
+        representation while removing unnecessary special tokens.
+
+        [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
+        -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]
+
+        We truncate the start and end tokens for all segments, recombine the segments,
+        and manually add back the start and end tokens.
+
+        # Parameters
+
+        embeddings: `torch.FloatTensor`
+            Shape: [batch_size * num_segments, self._max_length, embedding_size].
+        mask: `torch.BoolTensor`
+            Shape: [batch_size * num_segments, self._max_length].
+            The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
+            in `forward()`.
+        batch_size: `int`
+        num_segment_concat_wordpieces: `int`
+            The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
+            the original `token_ids.size(1)`.
+
+        # Returns:
+
+        embeddings: `torch.FloatTensor`
+            Shape: [batch_size, self._num_wordpieces, embedding_size].
+        """
+
+        def lengths_to_mask(lengths, max_len, device):
+            return torch.arange(max_len, device=device).expand(
+                lengths.size(0), max_len
+            ) < lengths.unsqueeze(1)
+
+        device = embeddings.device
+        num_segments = int(embeddings.size(0) / batch_size)
+        embedding_size = embeddings.size(2)
+
+        # We want to remove all segment-level special tokens but maintain sequence-level ones
+        num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens
+
+        embeddings = embeddings.reshape(
+            batch_size, num_segments * self._max_length, embedding_size  # type: ignore
+        )
+        mask = mask.reshape(batch_size, num_segments * self._max_length)  # type: ignore
+        # We assume that all 1s in the mask precede all 0s, and add an assert for that.
+        # Open an issue on GitHub if this breaks for you.
+        # Shape: (batch_size,)
+        seq_lengths = mask.sum(-1)
+        if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
+            raise ValueError(
+                "Long sequence splitting only supports masks with all 1s preceding all 0s."
+            )
+        # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
+        end_token_indices = (
+            seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1
+        )
+
+        # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
+        start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
+        # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
+        end_token_embeddings = batched_index_select(embeddings, end_token_indices)
+
+        embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size)
+        embeddings = embeddings[
+            :, :, self._num_added_start_tokens : embeddings.size(2) - self._num_added_end_tokens, :
+        ]  # truncate segment-level start/end tokens
+        embeddings = embeddings.reshape(batch_size, -1, embedding_size)  # flatten
+
+        # Now try to put end token embeddings back which is a little tricky.
+
+        # The number of segments each sequence spans, excluding padding. This mimics a ceiling operation.
+        # Shape: (batch_size,)
+        num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
+        # The number of indices that end tokens should shift back.
+        num_removed_non_end_tokens = (
+            num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
+        )
+        # Shape: (batch_size, self._num_added_end_tokens)
+        end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
+        assert (end_token_indices >= self._num_added_start_tokens).all()
+        # Add space for end embeddings
+        embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
+        # Add end token embeddings back
+        embeddings.scatter_(
+            1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings
+        )
+
+        # Now put back start tokens. We can do this before putting back end tokens, but then
+        # we need to change `num_removed_non_end_tokens` a little.
+        embeddings = torch.cat([start_token_embeddings, embeddings], 1)
+
+        # Truncate to original length
+        embeddings = embeddings[:, :num_wordpieces, :]
+        return embeddings
\ No newline at end of file
diff --git a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
new file mode 100644
index 0000000..d6f7a6f
--- /dev/null
+++ b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
@@ -0,0 +1,197 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
+"""
+
+from typing import Optional, Dict, Any
+
+
+import torch
+
+from combo.modules.token_embedders import PretrainedTransformerEmbedder, TokenEmbedder
+from combo.nn import util
+
+from combo.config import Registry
+from combo.utils import ConfigurationError
+
+
+@Registry.register(TokenEmbedder, "pretrained_transformer_mismatched")
+class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
+    """
+    Use this embedder to embed wordpieces given by `PretrainedTransformerMismatchedIndexer`
+    and to get word-level representations.
+
+    Registered as a `TokenEmbedder` with name "pretrained_transformer_mismatched".
+
+    # Parameters
+
+    model_name : `str`
+        The name of the `transformers` model to use. Should be the same as the corresponding
+        `PretrainedTransformerMismatchedIndexer`.
+    max_length : `int`, optional (default = `None`)
+        If positive, folds input token IDs into multiple segments of this length, pass them
+        through the transformer model independently, and concatenate the final representations.
+        Should be set to the same value as the `max_length` option on the
+        `PretrainedTransformerMismatchedIndexer`.
+    sub_module: `str`, optional (default = `None`)
+        The name of a submodule of the transformer to be used as the embedder. Some transformers, such as
+        BERT, naturally act as embedders. Other models consist of an encoder and a decoder, in which case
+        we only want to use the encoder.
+    train_parameters: `bool`, optional (default = `True`)
+        If this is `True`, the transformer weights get updated during training.
+    last_layer_only: `bool`, optional (default = `True`)
+        When `True` (the default), only the final layer of the pretrained transformer is taken
+        for the embeddings. But if set to `False`, a scalar mix of all of the layers
+        is used.
+    override_weights_file: `Optional[str]`, optional (default = `None`)
+        If set, this specifies a file from which to load alternate weights that override the
+        weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
+        with `torch.save()`.
+    override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
+        If set, strip the given prefix from the state dict when loading it.
+    load_weights: `bool`, optional (default = `True`)
+        Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
+        it usually makes sense to set this to `False` (via the `overrides` parameter)
+        to avoid unnecessarily caching and loading the original pretrained weights,
+        since the archive will already contain all of the weights needed.
+    gradient_checkpointing: `bool`, optional (default = `None`)
+        Enable or disable gradient checkpointing.
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    transformer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/modeling_utils.py#L253)
+        for `AutoModel.from_pretrained`.
+    sub_token_mode: `Optional[str]`, optional (default = `avg`)
+        If set to `first`, the first sub-token representation is used as the word-level
+        representation. If set to `avg` (the default), the average of all sub-token
+        representations is used as the word-level representation. Any other value
+        raises a `ConfigurationError`.
+
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        model_name: str,
+        max_length: int = None,
+        sub_module: str = None,
+        train_parameters: bool = True,
+        last_layer_only: bool = True,
+        override_weights_file: Optional[str] = None,
+        override_weights_strip_prefix: Optional[str] = None,
+        load_weights: bool = True,
+        gradient_checkpointing: Optional[bool] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        transformer_kwargs: Optional[Dict[str, Any]] = None,
+        sub_token_mode: Optional[str] = "avg",
+    ) -> None:
+        super().__init__()
+        # The underlying "matched" embedder, as opposed to this mismatched wrapper.
+        self._matched_embedder = PretrainedTransformerEmbedder(
+            model_name,
+            max_length=max_length,
+            sub_module=sub_module,
+            train_parameters=train_parameters,
+            last_layer_only=last_layer_only,
+            override_weights_file=override_weights_file,
+            override_weights_strip_prefix=override_weights_strip_prefix,
+            load_weights=load_weights,
+            gradient_checkpointing=gradient_checkpointing,
+            tokenizer_kwargs=tokenizer_kwargs,
+            transformer_kwargs=transformer_kwargs,
+        )
+        self.sub_token_mode = sub_token_mode
+
+    def get_output_dim(self):
+        return self._matched_embedder.get_output_dim()
+
+    def forward(
+        self,
+        token_ids: torch.LongTensor,
+        mask: torch.BoolTensor,
+        offsets: torch.LongTensor,
+        wordpiece_mask: torch.BoolTensor,
+        type_ids: Optional[torch.LongTensor] = None,
+        segment_concat_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:  # type: ignore
+        """
+        # Parameters
+
+        token_ids: `torch.LongTensor`
+            Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
+        mask: `torch.BoolTensor`
+            Shape: [batch_size, num_orig_tokens].
+        offsets: `torch.LongTensor`
+            Shape: [batch_size, num_orig_tokens, 2].
+            Maps indices for the original tokens, i.e. those given as input to the indexer,
+            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
+            corresponds to the original j-th token from the i-th batch.
+        wordpiece_mask: `torch.BoolTensor`
+            Shape: [batch_size, num_wordpieces].
+        type_ids: `Optional[torch.LongTensor]`
+            Shape: [batch_size, num_wordpieces].
+        segment_concat_mask: `Optional[torch.BoolTensor]`
+            See `PretrainedTransformerEmbedder`.
+
+        # Returns
+
+        `torch.Tensor`
+            Shape: [batch_size, num_orig_tokens, embedding_size].
+        """
+        # Shape: [batch_size, num_wordpieces, embedding_size].
+        embeddings = self._matched_embedder(
+            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
+        )
+
+        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
+        # span_mask: (batch_size, num_orig_tokens, max_span_length)
+        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
+
+        span_mask = span_mask.unsqueeze(-1)
+
+        # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size)
+        span_embeddings *= span_mask  # zero out paddings
+
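+        # Worked example (illustrative): if the word "playing" were split into the
+        # sub-tokens ["play", "##ing"], "first" keeps only the embedding of "play",
+        # while "avg" averages the embeddings of "play" and "##ing".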
+        # If "sub_token_mode" is set to "first", return the first sub-token embedding
+        if self.sub_token_mode == "first":
+            # Select first sub-token embeddings from span embeddings
+            # Shape: (batch_size, num_orig_tokens, embedding_size)
+            orig_embeddings = span_embeddings[:, :, 0, :]
+
+        # If "sub_token_mode" is set to "avg", return the average of embeddings of all sub-tokens of a word
+        elif self.sub_token_mode == "avg":
+            # Sum over embeddings of all sub-tokens of a word
+            # Shape: (batch_size, num_orig_tokens, embedding_size)
+            span_embeddings_sum = span_embeddings.sum(2)
+
+            # Shape (batch_size, num_orig_tokens)
+            span_embeddings_len = span_mask.sum(2)
+
+            # Find the average of the sub-token embeddings by dividing `span_embeddings_sum` by `span_embeddings_len`
+            # Shape: (batch_size, num_orig_tokens, embedding_size)
+            orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)
+
+            # All the places where the span length is zero, write in zeros.
+            orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
+
+        # If invalid "sub_token_mode" is provided, throw error
+        else:
+            raise ConfigurationError(f"Do not recognise 'sub_token_mode' {self.sub_token_mode}")
+
+        return orig_embeddings
diff --git a/combo/modules/token_embedders/projected_words_embedder.py b/combo/modules/token_embedders/projected_words_embedder.py
new file mode 100644
index 0000000..cd5b5dc
--- /dev/null
+++ b/combo/modules/token_embedders/projected_words_embedder.py
@@ -0,0 +1,56 @@
+from typing import Optional
+
+import torch
+from overrides import overrides
+
+from combo.config import Registry
+from combo.data.vocabulary import Vocabulary
+from combo.nn.base import Linear
+from combo.modules.token_embedders.token_embedder import TokenEmbedder
+from combo.modules.token_embedders.embedding import Embedding
+
+
+@Registry.register(TokenEmbedder, "embeddings_projected")
+class ProjectedWordEmbedder(Embedding):
+    """Word embeddings."""
+
+    def __init__(self,
+                 embedding_dim: int,
+                 num_embeddings: int = None,
+                 weight: torch.FloatTensor = None,
+                 padding_index: int = None,
+                 trainable: bool = True,
+                 max_norm: float = None,
+                 norm_type: float = 2.0,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 vocab_namespace: str = "tokens",
+                 pretrained_file: str = None,
+                 vocabulary: Vocabulary = None,
+                 projection_layer: Optional[Linear] = None):
+        super().__init__(
+            embedding_dim=embedding_dim,
+            num_embeddings=num_embeddings,
+            weight=weight,
+            padding_index=padding_index,
+            trainable=trainable,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            vocab_namespace=vocab_namespace,
+            pretrained_file=pretrained_file,
+            vocabulary=vocabulary
+        )
+        self._projection = projection_layer
+        self.output_dim = embedding_dim if projection_layer is None else projection_layer.out_features
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
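+
+
+# Illustrative usage (example values; assumes `vocab` is an already-built Vocabulary):
+#   embedder = ProjectedWordEmbedder(embedding_dim=100, vocabulary=vocab,
+#                                    projection_layer=Linear(in_features=100, out_features=64))
+#   embedder.get_output_dim()  # -> 64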
diff --git a/combo/models/embeddings.py b/combo/modules/token_embedders/token_embedder.py
similarity index 62%
rename from combo/models/embeddings.py
rename to combo/modules/token_embedders/token_embedder.py
index 35b732a..ed8b619 100644
--- a/combo/models/embeddings.py
+++ b/combo/modules/token_embedders/token_embedder.py
@@ -5,19 +5,19 @@ from overrides import overrides
 from torch import nn
 from torchtext.vocab import Vectors, GloVe, FastText, CharNGram
 
+from combo.config import FromParameters, Registry
 from combo.data import Vocabulary
-from combo.models.base import TimeDistributed
-from combo.models.dilated_cnn import DilatedCnnEncoder
 from combo.models.utils import tiny_value_of_dtype
+from combo.modules.module import Module
+from combo.modules.time_distributed import TimeDistributed
 from combo.utils import ConfigurationError
 
 
-class TokenEmbedder(nn.Module):
+class TokenEmbedder(Module, FromParameters):
     def __init__(self):
         super(TokenEmbedder, self).__init__()
 
-    @property
-    def output_dim(self) -> int:
+    def get_output_dim(self) -> int:
         raise NotImplementedError()
 
     def forward(self,
@@ -26,10 +26,11 @@ class TokenEmbedder(nn.Module):
         raise NotImplementedError()
 
 
-class _TorchEmbedder(TokenEmbedder):
+@Registry.register(TokenEmbedder, 'base')
+class TorchEmbedder(TokenEmbedder):
     def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
+                 num_TokenEmbedders: int,
+                 TokenEmbedder_dim: int,
                  padding_idx: Optional[int] = None,
                  max_norm: Optional[float] = None,
                  norm_type: float = 2.,
@@ -40,28 +41,28 @@ class _TorchEmbedder(TokenEmbedder):
                  weight: Optional[torch.Tensor] = None,
                  trainable: bool = True,
                  projection_dim: Optional[int] = None):
-        super(_TorchEmbedder, self).__init__()
-        self._embedding_dim = embedding_dim
-        self._embedding = nn.Embedding(num_embeddings=num_embeddings,
-                                       embedding_dim=embedding_dim,
-                                       padding_idx=padding_idx,
-                                       max_norm=max_norm,
-                                       norm_type=norm_type,
-                                       scale_grad_by_freq=scale_grad_by_freq,
-                                       sparse=sparse)
+        super(TorchEmbedder, self).__init__()
+        self._TokenEmbedder_dim = TokenEmbedder_dim
+        # self._TokenEmbedder = nn.TokenEmbedder(num_TokenEmbedders=num_TokenEmbedders,
+        #                                TokenEmbedder_dim=TokenEmbedder_dim,
+        #                                padding_idx=padding_idx,
+        #                                max_norm=max_norm,
+        #                                norm_type=norm_type,
+        #                                scale_grad_by_freq=scale_grad_by_freq,
+        #                                sparse=sparse)
         self.__vocab_namespace = vocab_namespace
         self.__vocab = vocab
 
         if weight is not None:
-            if weight.shape() != (num_embeddings, embedding_dim):
+            if weight.shape() != (num_TokenEmbedders, TokenEmbedder_dim):
                 raise ConfigurationError(
-                    "Weight matrix must be of shape (num_embeddings, embedding_dim)." +
+                    "Weight matrix must be of shape (num_TokenEmbedders, TokenEmbedder_dim)." +
                     f"Got: ({weight.shape()})"
                 )
 
             self.__weight = torch.nn.Parameter(weight, requires_grad=trainable)
         else:
-            self.__weight = torch.nn.Parameter(torch.FloatTensor(num_embeddings, embedding_dim),
+            self.__weight = torch.nn.Parameter(torch.FloatTensor(num_TokenEmbedders, TokenEmbedder_dim),
                                                requires_grad=trainable)
             torch.nn.init.xavier_uniform_(self.__weight)
 
@@ -69,21 +70,17 @@ class _TorchEmbedder(TokenEmbedder):
             self.__weight.data[padding_idx].fill_(0)
 
         if projection_dim:
-            self._projection = torch.nn.Linear(embedding_dim, projection_dim)
-            self._output_dim = projection_dim
+            self._projection = torch.nn.Linear(TokenEmbedder_dim, projection_dim)
+            self.output_dim = projection_dim
         else:
             self._projection = None
-            self._output_dim = embedding_dim
-
-    @overrides
-    def output_dim(self) -> int:
-        return self._output_dim
+            self.output_dim = TokenEmbedder_dim
 
     @overrides
     def forward(self,
                 x: torch.Tensor,
                 char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        embedded = self._embedding(x)
+        embedded = self._TokenEmbedder(x)
         if self._projection:
             projection = self._projection
             for p in range(embedded.dim()-2):
@@ -92,6 +89,7 @@ class _TorchEmbedder(TokenEmbedder):
         return embedded
 
 
+@Registry.register(TokenEmbedder, 'torchtext_vectors')
 class _TorchtextVectorsEmbedder(TokenEmbedder):
     """
     Torchtext Vectors object wrapper
@@ -110,7 +108,7 @@ class _TorchtextVectorsEmbedder(TokenEmbedder):
         self.__lower_case_backup = lower_case_backup
 
     @overrides
-    def output_dim(self) -> int:
+    def get_output_dim(self) -> int:
         return len(self.__torchtext_embedder)
 
     @overrides
@@ -120,75 +118,75 @@ class _TorchtextVectorsEmbedder(TokenEmbedder):
         return self.__torchtext_embedder.get_vecs_by_tokens(x, self.__lower_case_backup)
 
 
+@Registry.register(TokenEmbedder, 'glove42b')
 class GloVe42BEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self, dim: int = 300):
         super(GloVe42BEmbedder, self).__init__(GloVe("42B", dim))
 
 
+@Registry.register(TokenEmbedder, 'glove840b')
 class GloVe840BEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self, dim: int = 300):
         super(GloVe840BEmbedder, self).__init__(GloVe("840B", dim))
 
 
+@Registry.register(TokenEmbedder, 'glove_twitter27b')
 class GloVeTwitter27BEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self, dim: int = 300):
         super(GloVeTwitter27BEmbedder, self).__init__(GloVe("twitter.27B", dim))
 
 
+@Registry.register(TokenEmbedder, 'glove6b')
 class GloVe6BEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self, dim: int = 300):
         super(GloVe6BEmbedder, self).__init__(GloVe("6B", dim))
 
 
+@Registry.register(TokenEmbedder, 'fast_text')
 class FastTextEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self, language: str = "en"):
         super(FastTextEmbedder, self).__init__(FastText(language))
 
 
+@Registry.register(TokenEmbedder, 'char_ngram')
 class CharNGramEmbedder(_TorchtextVectorsEmbedder):
     def __init__(self):
         super(CharNGramEmbedder, self).__init__(CharNGram())
 
 
-class CharacterBasedWordEmbedder(TokenEmbedder):
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 dilated_cnn_encoder: DilatedCnnEncoder):
-        super(CharacterBasedWordEmbedder, self).__init__()
-        self.__embedding_dim = embedding_dim
-        self.__dilated_cnn_encoder = dilated_cnn_encoder
-        self.char_embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
-
-    @overrides
-    def output_dim(self) -> int:
-        return self.__embedding_dim
-
-    @overrides
-    def forward(self,
-                x: torch.Tensor,
-                char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
-        if char_mask is None:
-            char_mask = x.new_ones(x.size())
-
-        x = self.char_embed(x)
-        x = x * char_mask.unsqueeze(-1).float()
-        x = self.__dilated_cnn_encoder(x.transpose(2, 3))
-        return torch.max(x, dim=-1)[0]
-
-
-class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
-    pass
-
-
-class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder):
-    pass
-
-
-class FeatsTokenEmbedder(_TorchEmbedder):
+# @Registry.register(RegistryCategory.TokenEmbedder, 'character_based_word')
+# class CharacterBasedWordEmbedder(TokenEmbedder):
+#     def __init__(self,
+#                  num_TokenEmbedders: int,
+#                  TokenEmbedder_dim: int,
+#                  dilated_cnn_encoder: DilatedCnnEncoder):
+#         super(CharacterBasedWordEmbedder, self).__init__()
+#         self.__TokenEmbedder_dim = TokenEmbedder_dim
+#         self.__dilated_cnn_encoder = dilated_cnn_encoder
+#         self.char_embed = nn.TokenEmbedder(num_TokenEmbedders=num_TokenEmbedders, TokenEmbedder_dim=TokenEmbedder_dim)
+#
+#     @overrides
+#     def output_dim(self) -> int:
+#         return self.__TokenEmbedder_dim
+#
+#     @overrides
+#     def forward(self,
+#                 x: torch.Tensor,
+#                 char_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
+#         if char_mask is None:
+#             char_mask = x.new_ones(x.size())
+#
+#         x = self.char_embed(x)
+#         x = x * char_mask.unsqueeze(-1).float()
+#         x = self.__dilated_cnn_encoder(x.transpose(2, 3))
+#         return torch.max(x, dim=-1)[0]
+
+
+@Registry.register(TokenEmbedder, 'feats_token')
+class FeatsTokenEmbedder(TorchEmbedder):
     def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
+                 num_TokenEmbedders: int,
+                 TokenEmbedder_dim: int,
                  padding_idx: Optional[int] = None,
                  max_norm: Optional[float] = None,
                  norm_type: float = 2.,
@@ -198,8 +196,8 @@ class FeatsTokenEmbedder(_TorchEmbedder):
                  vocab: Vocabulary = None,
                  weight: Optional[torch.Tensor] = None,
                  trainable: bool = True):
-        super(FeatsTokenEmbedder, self).__init__(num_embeddings,
-                                                 embedding_dim,
+        super(FeatsTokenEmbedder, self).__init__(num_TokenEmbedders,
+                                                 TokenEmbedder_dim,
                                                  padding_idx,
                                                  max_norm,
                                                  norm_type,
diff --git a/combo/modules/token_embedders/transformers_words_embeddings.py b/combo/modules/token_embedders/transformers_words_embeddings.py
new file mode 100644
index 0000000..c25de4d
--- /dev/null
+++ b/combo/modules/token_embedders/transformers_words_embeddings.py
@@ -0,0 +1,74 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+"""
+from overrides import overrides
+
+from combo.config import Registry
+from combo.nn.activations import Activation
+from combo.modules.token_embedders.token_embedder import TokenEmbedder
+from combo.modules.token_embedders.pretrained_transformer_mismatched_embedder import PretrainedTransformerMismatchedEmbedder
+import torch
+from combo.nn import base
+from typing import Any, Dict, Optional
+
+
+@Registry.register(TokenEmbedder, "transformers_word_embeddings")
+class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder):
+    """
+    Transformers word embeddings as last hidden state + optional projection layers.
+
+    Tested with BERT (but should work for other models as well).
+    """
+
+    authorized_missing_keys = [r"position_ids$"]
+
+    def __init__(self,
+                 model_name: str,
+                 projection_dim: int = 0,
+                 projection_activation: Optional[Activation] = lambda x: x,
+                 projection_dropout_rate: Optional[float] = 0.0,
+                 freeze_transformer: bool = True,
+                 last_layer_only: bool = True,
+                 tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+                 transformer_kwargs: Optional[Dict[str, Any]] = None):
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         last_layer_only=last_layer_only,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
+        if projection_dim:
+            self.projection_layer = base.Linear(in_features=super().get_output_dim(),
+                                                out_features=projection_dim,
+                                                dropout_rate=projection_dropout_rate,
+                                                activation=projection_activation)
+            self.output_dim = projection_dim
+        else:
+            self.projection_layer = None
+            self.output_dim = super().get_output_dim()
+
+    #@overrides
+    def forward(
+            self,
+            token_ids: torch.LongTensor,
+            mask: torch.BoolTensor,
+            offsets: torch.LongTensor,
+            wordpiece_mask: torch.BoolTensor,
+            type_ids: Optional[torch.LongTensor] = None,
+            segment_concat_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:  # type: ignore
+        x = super().forward(token_ids, mask, offsets, wordpiece_mask, type_ids, segment_concat_mask)
+        if self.projection_layer:
+            x = self.projection_layer(x)
+        return x
+
+    @overrides
+    def get_output_dim(self) -> int:
+        return self.output_dim
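+
+
+# Illustrative configuration (the model name and projection_dim are example values,
+# not defaults):
+#   {"type": "transformers_word_embeddings",
+#    "model_name": "bert-base-cased",
+#    "projection_dim": 128}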
diff --git a/combo/nn/__init__.py b/combo/nn/__init__.py
index a01e009..4216e25 100644
--- a/combo/nn/__init__.py
+++ b/combo/nn/__init__.py
@@ -1 +1,2 @@
-from .regularizers import *
\ No newline at end of file
+from .regularizers import *
+from .activations import *
diff --git a/combo/nn/activations.py b/combo/nn/activations.py
new file mode 100644
index 0000000..c04cdbd
--- /dev/null
+++ b/combo/nn/activations.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+from overrides import overrides
+
+from combo.config.registry import Registry
+
+
+class Activation(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+@Registry.register(Activation, 'linear')
+class LinearActivation(Activation):
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+
+@Registry.register(Activation, 'relu')
+class ReLUActivation(Activation):
+    def __init__(self):
+        super().__init__()
+        self.__torch_activation = torch.nn.ReLU()
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.__torch_activation.forward(x)
+
+
+@Registry.register(Activation, 'leaky_relu')
+class LeakyReLUActivation(Activation):
+    def __init__(self):
+        super().__init__()
+        self.__torch_activation = torch.nn.LeakyReLU()
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.__torch_activation.forward(x)
+
+
+@Registry.register(Activation, 'gelu')
+class GELUActivation(Activation):
+    def __init__(self):
+        super().__init__()
+        self.__torch_activation = torch.nn.GELU()
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.__torch_activation.forward(x)
+
+
+@Registry.register(Activation, 'sigmoid')
+class SigmoidActivation(Activation):
+    def __init__(self):
+        super().__init__()
+        self.__torch_activation = torch.nn.Sigmoid()
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.__torch_activation.forward(x)
+
+
+@Registry.register(Activation, 'tanh')
+class TanhActivation(Activation):
+    def __init__(self):
+        super().__init__()
+        self.__torch_activation = torch.nn.Tanh()
+    @overrides
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.__torch_activation.forward(x)
+
+
+@Registry.register(Activation, 'mish')
+class MishActivation(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+@Registry.register(Activation, 'swish')
+class SwishActivation(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(x)
+
+
+@Registry.register(Activation, 'gelu_new')
+class GeluNew(Activation):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also
+    see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return (
+                0.5
+                * x
+                * (1.0 + torch.tanh(torch.math.sqrt(2.0 / torch.math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+        )
+
+
+@Registry.register(Activation, 'gelu_fast')
+class GeluFast(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
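+
+
+# Illustrative usage: activations are registered components, so they can be resolved
+# by name (assuming the same Registry conventions as in the rest of this patch):
+#   mish = Registry.resolve(Activation, "mish")()
+#   y = mish(torch.randn(2, 3))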
diff --git a/combo/models/base.py b/combo/nn/base.py
similarity index 73%
rename from combo/models/base.py
rename to combo/nn/base.py
index 45eae04..c752066 100644
--- a/combo/models/base.py
+++ b/combo/nn/base.py
@@ -2,20 +2,23 @@ from typing import Dict, Optional, List, Union, Tuple
 
 import torch
 import torch.nn as nn
-from overrides import overrides
 
-from combo.models.combo_nn import Activation
+from combo.config.registry import Registry
+from combo.config.from_parameters import FromParameters
+from combo.nn.activations import Activation
 import combo.utils.checks as checks
 from combo.data.vocabulary import Vocabulary
 from combo.models.utils import masked_cross_entropy
+from combo.modules.module import Module
 from combo.predictors.predictor import Predictor
 
 
-class Linear(nn.Linear):
+@Registry.register_base_class("linear", default=True)
+class Linear(nn.Linear, FromParameters):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Optional[Activation] = None,
+                 activation: Activation = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
         self.activation = activation if activation else self.identity
@@ -34,7 +37,7 @@ class Linear(nn.Linear):
         return x
 
 
-class FeedForward(torch.nn.Module):
+class FeedForward(Module):
     """
     Modified copy of allennlp.modules.feedforward.FeedForward
 
@@ -85,7 +88,7 @@ class FeedForward(torch.nn.Module):
             input_dim: int,
             num_layers: int,
             hidden_dims: Union[int, List[int]],
-            activations: Union[Activation, List[Activation]],
+            activations: List[Activation], # Change to Union[Activation, List[Activation]]
             dropout: Union[float, List[float]] = 0.0,
     ) -> None:
 
@@ -187,7 +190,7 @@ class FeedForwardPredictor(Predictor):
                    input_dim: int,
                    num_layers: int,
                    hidden_dims: List[int],
-                   activations: Union[Activation, List[Activation]],
+                   activations: List[Activation], # TODO: change to Union[Activation, List[Activation]]
                    dropout: Union[float, List[float]] = 0.0,
                    ):
         if len(hidden_dims) + 1 != num_layers:
@@ -205,70 +208,3 @@ class FeedForwardPredictor(Predictor):
             hidden_dims=hidden_dims,
             activations=activations,
             dropout=dropout))
-
-
-"""
-Adapted from AllenNLP
-"""
-
-
-class TimeDistributed(torch.nn.Module):
-    """
-    Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes
-    inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be
-    `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back.
-
-    Note that while the above gives shapes with `batch_size` first, this `Module` also works if
-    `batch_size` is second - we always just combine the first two dimensions, then split them.
-
-    It also reshapes keyword arguments unless they are not tensors or their name is specified in
-    the optional `pass_through` iterable.
-    """
-
-    def __init__(self, module):
-        super().__init__()
-        self._module = module
-
-    @overrides
-    def forward(self, *inputs, pass_through: List[str] = None, **kwargs):
-
-        pass_through = pass_through or []
-
-        reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs]
-
-        # Need some input to then get the batch_size and time_steps.
-        some_input = None
-        if inputs:
-            some_input = inputs[-1]
-
-        reshaped_kwargs = {}
-        for key, value in kwargs.items():
-            if isinstance(value, torch.Tensor) and key not in pass_through:
-                if some_input is None:
-                    some_input = value
-
-                value = self._reshape_tensor(value)
-
-            reshaped_kwargs[key] = value
-
-        reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs)
-
-        if some_input is None:
-            raise RuntimeError("No input tensor to time-distribute")
-
-        # Now get the output back into the right shape.
-        # (batch_size, time_steps, **output_size)
-        new_size = some_input.size()[:2] + reshaped_outputs.size()[1:]
-        outputs = reshaped_outputs.contiguous().view(new_size)
-
-        return outputs
-
-    @staticmethod
-    def _reshape_tensor(input_tensor):
-        input_size = input_tensor.size()
-        if len(input_size) <= 2:
-            raise RuntimeError(f"No dimension to distribute: {input_size}")
-        # Squash batch_size and time_steps into a single axis; result has shape
-        # (batch_size * time_steps, **input_size).
-        squashed_shape = [-1] + list(input_size[2:])
-        return input_tensor.contiguous().view(*squashed_shape)
diff --git a/combo/nn/util.py b/combo/nn/util.py
index 56c564d..ae0cca7 100644
--- a/combo/nn/util.py
+++ b/combo/nn/util.py
@@ -2,7 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
 """
-from typing import Union, Dict, Optional, List, Any
+from typing import Union, Dict, Optional, List, Any, NamedTuple
 
 import torch
 
@@ -10,6 +10,118 @@ from combo.common.util import int_to_device
 from combo.utils import ConfigurationError
 
 
+StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
+
+
+def tiny_value_of_dtype(dtype: torch.dtype):
+    """
+    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
+    issues such as division by zero.
+    This is different from `info_value_of_dtype(dtype).tiny`, which can cause NaN bugs.
+    Only supports floating point dtypes.
+    """
+    if not dtype.is_floating_point:
+        raise TypeError("Only supports floating point dtypes.")
+    if dtype == torch.float or dtype == torch.double:
+        return 1e-13
+    elif dtype == torch.half:
+        return 1e-4
+    else:
+        raise TypeError("Does not support dtype " + str(dtype))
+
+
+def get_device_of(tensor: torch.Tensor) -> int:
+    """
+    Returns the device of the tensor.
+    """
+    if not tensor.is_cuda:
+        return -1
+    else:
+        return tensor.get_device()
+
+def combine_initial_dims(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Given a (possibly higher order) tensor of ids with shape
+    (d1, ..., dn, sequence_length)
+    Return a view that's (d1 * ... * dn, sequence_length).
+    If original tensor is 1-d or 2-d, return it as is.
+    """
+    if tensor.dim() <= 2:
+        return tensor
+    else:
+        return tensor.view(-1, tensor.size(-1))
+
+
+def uncombine_initial_dims(tensor: torch.Tensor, original_size: torch.Size) -> torch.Tensor:
+    """
+    Given a tensor of embeddings with shape
+    (d1 * ... * dn, sequence_length, embedding_dim)
+    and the original shape
+    (d1, ..., dn, sequence_length),
+    return the reshaped tensor of embeddings with shape
+    (d1, ..., dn, sequence_length, embedding_dim).
+    If original size is 1-d or 2-d, return it as is.
+    """
+    if len(original_size) <= 2:
+        return tensor
+    else:
+        view_args = list(original_size) + [tensor.size(-1)]
+        return tensor.view(*view_args)
+
+
+def get_range_vector(size: int, device: int) -> torch.Tensor:
+    """
+    Returns a range vector with the desired size, starting at 0. The CUDA implementation
+    is meant to avoid copying data from the CPU to the GPU.
+    """
+    if device > -1:
+        return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1
+    else:
+        return torch.arange(0, size, dtype=torch.long)
+
+
+def flatten_and_batch_shift_indices(indices: torch.Tensor, sequence_length: int) -> torch.Tensor:
+    if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
+        raise ConfigurationError(
+            f"All elements in indices should be in range (0, {sequence_length - 1})"
+        )
+    offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
+    for _ in range(len(indices.size()) - 1):
+        offsets = offsets.unsqueeze(1)
+
+    # Shape: (batch_size, d_1, ..., d_n)
+    offset_indices = indices + offsets
+
+    # Shape: (batch_size * d_1 * ... * d_n)
+    offset_indices = offset_indices.view(-1)
+    return offset_indices
+
+
+def batched_index_select(
+    target: torch.Tensor,
+    indices: torch.LongTensor,
+    flattened_indices: Optional[torch.LongTensor] = None,
+) -> torch.Tensor:
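+    """
+    Selects, for each batch element, the rows of `target` at the positions given by
+    `indices`. Shape example (illustrative): `target` of shape (2, 5, 3) and `indices`
+    of shape (2, 4) yield a result of shape (2, 4, 3).
+    """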
+    if flattened_indices is None:
+        # Shape: (batch_size * d_1 * ... * d_n)
+        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))
+
+    # Shape: (batch_size * sequence_length, embedding_size)
+    flattened_target = target.view(-1, target.size(-1))
+
+    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
+    flattened_selected = flattened_target.index_select(0, flattened_indices)
+    selected_shape = list(indices.size()) + [target.size(-1)]
+    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
+    selected_targets = flattened_selected.view(*selected_shape)
+    return selected_targets
+
+
 def move_to_device(obj, device: Union[torch.device, int]):
     """
     Given a structure (possibly) containing Tensors,
@@ -257,3 +369,33 @@ def get_token_offsets_from_text_field_inputs(
                     return embedder_arg_value
     return None
 
+def _check_incompatible_keys(
+    module, missing_keys: List[str], unexpected_keys: List[str], strict: bool
+):
+    error_msgs: List[str] = []
+    if missing_keys:
+        error_msgs.append(
+            "Missing key(s) in state_dict: {}".format(", ".join(f'"{k}"' for k in missing_keys))
+        )
+    if unexpected_keys:
+        error_msgs.append(
+            "Unexpected key(s) in state_dict: {}".format(
+                ", ".join(f'"{k}"' for k in unexpected_keys)
+            )
+        )
+    if error_msgs and strict:
+        raise RuntimeError(
+            "Error(s) in loading state_dict for {}:\n\t{}".format(
+                module.__class__.__name__, "\n\t".join(error_msgs)
+            )
+        )
+
+class _IncompatibleKeys(NamedTuple):
+    missing_keys: List[str]
+    unexpected_keys: List[str]
+
+    def __repr__(self):
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return f"(missing_keys = {self.missing_keys}, unexpected_keys = {self.unexpected_keys})"
+
diff --git a/combo/predict.py b/combo/predict.py
index aa547e4..a9c0607 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -8,8 +8,10 @@ from overrides import overrides
 
 from combo import data, models
 from combo.common import util
-from combo.data import sentence2conllu, tokens2conllu, conllu2sentence, tokenizers, Instance
+from combo.config import Registry
+from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
 from combo.data.dataset_readers.dataset_reader import DatasetReader
+
 from combo.data.instance import JsonDict
 from combo.predictors.predictor import Predictor
 from combo.utils import download, graph
@@ -17,6 +19,7 @@ from combo.utils import download, graph
 logger = logging.getLogger(__name__)
 
 
+@Registry.register(Predictor, 'combo')
 class COMBO(Predictor):
 
     def __init__(self,
@@ -28,7 +31,6 @@ class COMBO(Predictor):
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
         self.vocab = model.vocab
-        self.dataset_reader = self._dataset_reader
         self.dataset_reader.generate_labels = False
         self.dataset_reader.lazy = True
         self._tokenizer = tokenizer
@@ -246,6 +248,10 @@ class COMBO(Predictor):
 
         archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = DatasetReader.from_params(
-            archive.config["dataset_reader"])
+        dataset_reader_class = archive.config["dataset_reader"].get(
+            "type", "conllu")
+        dataset_reader = Registry.resolve(
+            DatasetReader,
+            dataset_reader_class
+        ).from_parameters(archive.config["dataset_reader"])
         return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/combo/predictors/predictor.py b/combo/predictors/predictor.py
index e4917a7..94876e4 100644
--- a/combo/predictors/predictor.py
+++ b/combo/predictors/predictor.py
@@ -3,12 +3,11 @@ Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/predictors/predictor.py
 """
 
-from typing import List, Iterator, Dict, Tuple, Any, Type, Union
+from typing import List, Iterator, Dict, Tuple, Any
 import logging
 import json
 import re
 from contextlib import contextmanager
-from pathlib import Path
 
 import numpy
 import torch
@@ -17,16 +16,17 @@ from torch import Tensor
 from torch import backends
 
 from combo.common.util import sanitize
+from combo.config import FromParameters
 from combo.data.batch import Batch
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict, Instance
-from combo.models.model import Model
+from combo.modules.model import Model
 from combo.nn import util
 
 logger = logging.getLogger(__name__)
 
 
-class Predictor:
+class Predictor(FromParameters):
     """
     a `Predictor` is a thin wrapper around an AllenNLP model that handles JSON -> JSON predictions
     that can be used for serving models through the web API or making predictions in bulk.
@@ -36,7 +36,7 @@ class Predictor:
         if frozen:
             model.eval()
         self._model = model
-        self._dataset_reader = dataset_reader
+        self.dataset_reader = dataset_reader
         self.cuda_device = next(self._model.named_parameters())[1].get_device()
         self._token_offsets: List[Tensor] = []
 
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py
new file mode 100644
index 0000000..ec69bb0
--- /dev/null
+++ b/tests/config/test_configuration.py
@@ -0,0 +1,31 @@
+import unittest
+
+from combo.config import Registry
+from combo.data import WhitespaceTokenizer, DatasetReader, UniversalDependenciesDatasetReader
+from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer
+
+
+class ConfigurationTest(unittest.TestCase):
+
+    def test_dataset_reader_from_registry(self):
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu')()
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+
+    def test_dataset_reader_from_registry_with_parameters(self):
+        parameters = {'token_indexers': {'char': {'type': 'token_characters'}},
+                      'use_sem': True}
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu').from_parameters(parameters)
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+        self.assertEqual(type(dataset_reader.token_indexers['char']), TokenCharactersIndexer)
+        self.assertEqual(dataset_reader.use_sem, True)
+
+    def test_dataset_reader_from_registry_with_token_indexer_parameters(self):
+        parameters = {'token_indexers': {'char': {'type': 'token_characters',
+                                                  'namespace': 'custom_namespace',
+                                                  'tokenizer': {
+                                                      'type': 'whitespace'
+                                                  }}}}
+        dataset_reader = Registry.resolve(DatasetReader, 'conllu').from_parameters(parameters)
+        self.assertEqual(type(dataset_reader), UniversalDependenciesDatasetReader)
+        self.assertEqual(dataset_reader.token_indexers['char']._namespace, 'custom_namespace')
+        self.assertEqual(type(dataset_reader.token_indexers['char']._character_tokenizer), WhitespaceTokenizer)
-- 
GitLab