diff --git a/combo/common/__init__.py b/combo/common/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..211d5391decfe2b44406fd450e707452dbb8623a 100644
--- a/combo/common/__init__.py
+++ b/combo/common/__init__.py
@@ -0,0 +1,2 @@
+from .logging import *
+from .tqdm import *
diff --git a/combo/common/cached_transformers.py b/combo/common/cached_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..627879c9acb4391557930de212505649ab74adff
--- /dev/null
+++ b/combo/common/cached_transformers.py
@@ -0,0 +1,28 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/common/cached_transformers.py
+"""
+import logging
+import transformers
+from typing import Dict, Tuple
+
+from combo.common.util import hash_object
+
+
+logger = logging.getLogger(__name__)
+
+_tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {}
+
+
+def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer:
+    cache_key = (model_name, hash_object(kwargs))
+
+    global _tokenizer_cache
+    tokenizer = _tokenizer_cache.get(cache_key, None)
+    if tokenizer is None:
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_name,
+            **kwargs,
+        )
+        _tokenizer_cache[cache_key] = tokenizer
+    return tokenizer
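
A minimal usage sketch of the cache above, assuming the Hugging Face model name `bert-base-cased` is reachable; the keyword arguments are hashed with `hash_object`, so repeated calls with the same arguments reuse the same tokenizer instance:

    from combo.common.cached_transformers import get_tokenizer

    # First call loads the tokenizer and stores it under (model_name, hash_object(kwargs)).
    tok_a = get_tokenizer("bert-base-cased", do_lower_case=False)

    # Same model name and kwargs -> cache hit, the very same object comes back.
    tok_b = get_tokenizer("bert-base-cased", do_lower_case=False)
    assert tok_a is tok_b

    # Different kwargs hash to a different key, so a separate tokenizer is built.
    tok_c = get_tokenizer("bert-base-cased", do_lower_case=True)
    assert tok_c is not tok_a
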
diff --git a/combo/common/logging.py b/combo/common/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb652e909b1b021f630b5d9e165cf44be0aa8e8
--- /dev/null
+++ b/combo/common/logging.py
@@ -0,0 +1,55 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/common/logging.py
+"""
+
+import logging
+
+
+class AllenNlpLogger(logging.Logger):
+    """
+    A custom subclass of 'logging.Logger' that keeps a set of messages to
+    implement {debug,info,etc.}_once() methods.
+    """
+
+    def __init__(self, name):
+        super().__init__(name)
+        self._seen_msgs = set()
+
+    def debug_once(self, msg, *args, **kwargs):
+        if msg not in self._seen_msgs:
+            self.debug(msg, *args, **kwargs)
+            self._seen_msgs.add(msg)
+
+    def info_once(self, msg, *args, **kwargs):
+        if msg not in self._seen_msgs:
+            self.info(msg, *args, **kwargs)
+            self._seen_msgs.add(msg)
+
+    def warning_once(self, msg, *args, **kwargs):
+        if msg not in self._seen_msgs:
+            self.warning(msg, *args, **kwargs)
+            self._seen_msgs.add(msg)
+
+    def error_once(self, msg, *args, **kwargs):
+        if msg not in self._seen_msgs:
+            self.error(msg, *args, **kwargs)
+            self._seen_msgs.add(msg)
+
+    def critical_once(self, msg, *args, **kwargs):
+        if msg not in self._seen_msgs:
+            self.critical(msg, *args, **kwargs)
+            self._seen_msgs.add(msg)
+
+
+logging.setLoggerClass(AllenNlpLogger)
+logger = logging.getLogger(__name__)
+
+
+FILE_FRIENDLY_LOGGING: bool = False
+"""
+If this flag is set to `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
+down tqdm's output to only once every 10 seconds.
+
+By default, it is set to `False`.
+"""
diff --git a/combo/common/tqdm.py b/combo/common/tqdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb17281d8ea20add2285da781d4b811f6849781a
--- /dev/null
+++ b/combo/common/tqdm.py
@@ -0,0 +1,99 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/common/tqdm.py
+"""
+
+import logging
+import sys
+from time import time
+from typing import Optional
+import combo.common.logging as common_logging
+
+try:
+    SHELL = str(type(get_ipython()))  # type:ignore # noqa: F821
+except:  # noqa: E722
+    SHELL = ""
+
+
+if "zmqshell.ZMQInteractiveShell" in SHELL:
+    from tqdm import tqdm_notebook as _tqdm
+else:
+    from tqdm import tqdm as _tqdm
+
+# This is necessary to stop tqdm from hanging
+# when exceptions are raised inside iterators.
+# It should have been fixed in 4.2.1, but it still
+# occurs.
+# TODO(Mark): Remove this once tqdm cleans up after itself properly.
+# https://github.com/tqdm/tqdm/issues/469
+_tqdm.monitor_interval = 0
+
+
+logger = logging.getLogger("tqdm")
+logger.propagate = False
+
+
+def replace_cr_with_newline(message: str) -> str:
+    """
+    TQDM and requests use carriage returns to get the training line to update for each batch
+    without adding more lines to the terminal output. Displaying those in a file won't work
+    correctly, so we'll just make sure that each batch shows up on its own line.
+    """
+    # In addition to carriage returns, nested progress-bars will contain extra new-line
+    # characters and this special control sequence which tells the terminal to move the
+    # cursor one line up.
+    message = message.replace("\r", "").replace("\n", "").replace("\x1b[A", "")
+    if message and message[-1] != "\n":
+        message += "\n"
+    return message
+
+
+class TqdmToLogsWriter(object):
+    def __init__(self):
+        self.last_message_written_time = 0.0
+
+    def write(self, message):
+        file_friendly_message: Optional[str] = None
+        if common_logging.FILE_FRIENDLY_LOGGING:
+            file_friendly_message = replace_cr_with_newline(message)
+            if file_friendly_message.strip():
+                sys.stderr.write(file_friendly_message)
+        else:
+            sys.stderr.write(message)
+
+        # Every 10 seconds we also log the message.
+        now = time()
+        if now - self.last_message_written_time >= 10 or "100%" in message:
+            if file_friendly_message is None:
+                file_friendly_message = replace_cr_with_newline(message)
+            for message in file_friendly_message.split("\n"):
+                message = message.strip()
+                if len(message) > 0:
+                    logger.info(message)
+                    self.last_message_written_time = now
+
+    def flush(self):
+        sys.stderr.flush()
+
+
+class Tqdm:
+    @staticmethod
+    def tqdm(*args, **kwargs):
+        # Use a slower interval when FILE_FRIENDLY_LOGGING is set.
+        default_mininterval = 2.0 if common_logging.FILE_FRIENDLY_LOGGING else 0.1
+
+        new_kwargs = {
+            "file": TqdmToLogsWriter(),
+            "mininterval": default_mininterval,
+            **kwargs,
+        }
+
+        return _tqdm(*args, **new_kwargs)
+
+    @staticmethod
+    def set_lock(lock):
+        _tqdm.set_lock(lock)
+
+    @staticmethod
+    def get_lock():
+        return _tqdm.get_lock()
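
A usage sketch of the wrapper, assuming stderr is being captured to a log file so `FILE_FRIENDLY_LOGGING` is turned on; the progress bar then avoids carriage-return redraws and mirrors its status to the `tqdm` logger at most every ten seconds:

    import combo.common.logging as common_logging
    from combo.common.tqdm import Tqdm

    common_logging.FILE_FRIENDLY_LOGGING = True

    total = 0
    for value in Tqdm.tqdm(range(10_000), desc="processing"):
        total += value
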
diff --git a/combo/common/util.py b/combo/common/util.py
index 7cb5c8880228fc6e7f13707510d4d61f4fe7e08b..64d440a0ed20f90c8aeba0856a5785f7c95d2126 100644
--- a/combo/common/util.py
+++ b/combo/common/util.py
@@ -4,10 +4,24 @@ from itertools import islice
 import numpy
 import spacy
 import torch
+import hashlib
+import io
+import base58
+import dill
 
 A = TypeVar("A")
 
 
+def hash_object(o: Any) -> str:
+    """Adapted from AllenNLP"""
+    """Returns a character hash code of arbitrary Python objects."""
+    m = hashlib.blake2b()
+    with io.BytesIO() as buffer:
+        dill.dump(o, buffer)
+        m.update(buffer.getbuffer())
+        return base58.b58encode(m.digest()).decode()
+
+
 def int_to_device(device: Union[int, torch.device]) -> torch.device:
     if isinstance(device, torch.device):
         return device
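
A small sketch of what `hash_object` guarantees for the tokenizer cache: two objects that `dill`-serialize to the same bytes (here, keyword-argument dicts built in the same order) yield the same base58-encoded blake2b digest:

    from combo.common.util import hash_object

    kwargs_a = {"do_lower_case": False, "use_fast": True}
    kwargs_b = {"do_lower_case": False, "use_fast": True}

    # Identical serialized bytes -> identical digests, so both dicts map to the same cache key.
    assert hash_object(kwargs_a) == hash_object(kwargs_b)
    assert hash_object(kwargs_a) != hash_object({"do_lower_case": True})
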
diff --git a/combo/data/dataset_readers/conll.py b/combo/data/dataset_readers/conll.py
index da08e1ef89f8dd88f1ffc3f404c1c5e6e4052420..d8c2c5df947d918f9a925fcf3e4e60ff3e697de2 100644
--- a/combo/data/dataset_readers/conll.py
+++ b/combo/data/dataset_readers/conll.py
@@ -59,7 +59,7 @@ class ConllDatasetReader(DatasetReader):
     feature_labels : `Sequence[str]`, optional (default=`()`)
         These labels will be loaded as features into the corresponding instance fields:
         `pos` -> `pos_tags`, `chunk` -> `chunk_tags`, `ner` -> `ner_tags`
-        Each will have its own namespace : `pos_tags`, `chunk_tags`, `ner_tags`.
+        Each will have its own namespace : `pos_tags`, `chunk_tags`, `ner_tags`.
         If you want to use one of the tags as a `feature` in your model, it should be
         specified here.
     convert_to_coding_scheme : `Optional[str]`, optional (default=`None`)
@@ -73,7 +73,7 @@ class ConllDatasetReader(DatasetReader):
         Specifies the coding scheme of the input file.
         Valid options are `IOB1` and `IOB2`.
     label_namespace : `str`, optional (default=`labels`)
-        Specifies the namespace for the chosen `tag_label`.
+        Specifies the namespace for the chosen `tag_label`.
     """
 
     _VALID_LABELS = {"ner", "pos", "chunk"}
diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index 3a503d440ac54e388d0dbdd88852cc64de8d1e36..33bad6db815fcc5a568a753192b223afe9384a56 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -179,7 +179,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
                             indices=enhanced_heads,
                             sequence_field=text_field_deps,
                             labels=enhanced_deprels,
-                            # Label namespace matches regular tree parsing.
+                            # Label namespace matches regular tree parsing.
                             label_namespace="enhanced_deprel_labels",
                             padding_value=0,
                         )
diff --git a/combo/data/fields/adjacency_field.py b/combo/data/fields/adjacency_field.py
index e4bc4abe5d22282650d666fdf7c185d451da13df..e12a83745717e82b1079e71a4f93ab960e04cb08 100644
--- a/combo/data/fields/adjacency_field.py
+++ b/combo/data/fields/adjacency_field.py
@@ -34,7 +34,7 @@ class AdjacencyField(Field[torch.Tensor]):
     labels : `List[str]`, optional, (default = `None`)
         Optional labels for the edges of the adjacency matrix.
     label_namespace : `str`, optional (default=`'labels'`)
-        The namespace to use for converting tag strings into integers.  We convert tag strings to
+        The namespace to use for converting tag strings into integers.  We convert tag strings to
         integers for you, and this parameter tells the `Vocabulary` object which mapping from
         strings to integers to use (so that "O" as a tag doesn't get the same idx as "O" as a word).
     padding_value : `int`, optional (default = `-1`)
@@ -50,10 +50,10 @@ class AdjacencyField(Field[torch.Tensor]):
         "_indexed_labels",
     ]
 
-    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
+    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
     # This warning will be repeated for every instantiation of this class (i.e for every data
     # instance), spewing a lot of warnings so this class variable is used to only log a single
-    # warning per namespace.
+    # warning per namespace.
     _already_warned_namespaces: Set[str] = set()
 
     def __init__(
@@ -95,7 +95,7 @@ class AdjacencyField(Field[torch.Tensor]):
         if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
             if label_namespace not in self._already_warned_namespaces:
                 logger.warning(
-                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "Your label _namespace was '%s'. We recommend you use a _namespace "
                     "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                     "default to your vocabulary.  See documentation for "
                     "`non_padded_namespaces` parameter in Vocabulary.",
@@ -146,7 +146,7 @@ class AdjacencyField(Field[torch.Tensor]):
         return (
             f"AdjacencyField of length {length}\n"
             f"\t\twith indices:\n {formatted_indices}\n"
-            f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+            f"\t\tand labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'."
         )
 
     def __len__(self):
diff --git a/combo/data/fields/field.py b/combo/data/fields/field.py
index 952cf975bfc9631a27527a3f68385d6d2ef28aa8..2e890f83557fdb3e6bea73ffaffb17eeee684f37 100644
--- a/combo/data/fields/field.py
+++ b/combo/data/fields/field.py
@@ -48,7 +48,7 @@ class Field(Generic[DataArray]):
         be represented as a combination of word ids and character ids, and you don't want words and
         characters to share the same vocabulary - "a" as a word should get a different idx from "a"
         as a character, and the vocabulary sizes of words and characters are very different.
-        Because of this, the first key in the `counter` object is a `namespace`, like "tokens",
+        Because of this, the first key in the `counter` object is a `namespace`, like "tokens",
         "token_characters", "tags", or "labels", and the second key is the actual vocabulary item.
         """
         pass
diff --git a/combo/data/fields/label_field.py b/combo/data/fields/label_field.py
index fabe2428f94b60a14c40d7b5295d22962a092293..68f60880a5c93c36e5a6cf99f296b04486f90120 100644
--- a/combo/data/fields/label_field.py
+++ b/combo/data/fields/label_field.py
@@ -26,9 +26,9 @@ class LabelField(Field[torch.Tensor]):
     # Parameters
     label : `Union[str, int]`
     label_namespace : `str`, optional (default=`"labels"`)
-        The namespace to use for converting label strings into integers.  We map label strings to
+        The namespace to use for converting label strings into integers.  We map label strings to
         integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...),
-        and this namespace tells the `Vocabulary` object which mapping from strings to integers
+        and this namespace tells the `Vocabulary` object which mapping from strings to integers
         to use (so "entailment" as a label doesn't get the same integer idx as "entailment" as a
         word).  If you have multiple different label fields in your data, you should make sure you
         use different namespaces for each one, always using the suffix "labels" (e.g.,
@@ -41,9 +41,9 @@ class LabelField(Field[torch.Tensor]):
     __slots__ = ["label", "_label_namespace", "_label_id", "_skip_indexing"]
 
     # Most often, you probably don't want to have OOV/PAD tokens with a LabelField, so we warn you
-    # about it when you pick a namespace that will getting these tokens by default.  It is
+    # about it when you pick a namespace that will get these tokens by default.  It is
     # possible, however, that you _do_ actually want OOV/PAD tokens with this Field.  This class
-    # variable is used to make sure that we only log a single warning for this per namespace, and
+    # variable is used to make sure that we only log a single warning for this per namespace, and
     # not every time you create one of these Field objects.
     _already_warned_namespaces: Set[str] = set()
 
@@ -73,7 +73,7 @@ class LabelField(Field[torch.Tensor]):
         if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
             if label_namespace not in self._already_warned_namespaces:
                 logger.warning(
-                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "Your label _namespace was '%s'. We recommend you use a _namespace "
                     "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                     "default to your vocabulary.  See documentation for "
                     "`non_padded_namespaces` parameter in Vocabulary.",
@@ -105,7 +105,7 @@ class LabelField(Field[torch.Tensor]):
         return self.label
 
     def __str__(self) -> str:
-        return f"LabelField with label: {self.label} in namespace: '{self._label_namespace}'."
+        return f"LabelField with label: {self.label} in _namespace: '{self._label_namespace}'."
 
     def __len__(self):
         return 1
diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py
index 99c032ee8a8358091fc9a9549e4b19f66abc0613..60d30c6f3959eb276c949aacf0deb596853d4dbe 100644
--- a/combo/data/fields/sequence_label_field.py
+++ b/combo/data/fields/sequence_label_field.py
@@ -35,7 +35,7 @@ class SequenceLabelField(Field[torch.Tensor]):
         A field containing the sequence that this `SequenceLabelField` is labeling.  Most often, this is a
         `TextField`, for tagging individual tokens in a sentence.
     label_namespace : `str`, optional (default=`'labels'`)
-        The namespace to use for converting tag strings into integers.  We convert tag strings to
+        The namespace to use for converting tag strings into integers.  We convert tag strings to
         integers for you, and this parameter tells the `Vocabulary` object which mapping from
         strings to integers to use (so that "O" as a tag doesn't get the same idx as "O" as a word).
     """
@@ -48,10 +48,10 @@ class SequenceLabelField(Field[torch.Tensor]):
         "_skip_indexing",
     ]
 
-    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
+    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
     # This warning will be repeated for every instantiation of this class (i.e for every data
     # instance), spewing a lot of warnings so this class variable is used to only log a single
-    # warning per namespace.
+    # warning per namespace.
     _already_warned_namespaces: Set[str] = set()
 
     def __init__(
@@ -87,7 +87,7 @@ class SequenceLabelField(Field[torch.Tensor]):
         if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
             if label_namespace not in self._already_warned_namespaces:
                 logger.warning(
-                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "Your label _namespace was '%s'. We recommend you use a _namespace "
                     "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                     "default to your vocabulary.  See documentation for "
                     "`non_padded_namespaces` parameter in Vocabulary.",
@@ -144,7 +144,7 @@ class SequenceLabelField(Field[torch.Tensor]):
         )
         return (
             f"SequenceLabelField of length {length} with "
-            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+            f"labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'."
         )
 
     def human_readable_repr(self) -> Union[List[str], List[int]]:
diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py
index d0c19d8e4390d980e2aac6f5db4e19d0ac6a699d..4938d438a0bf14ea369dae000a34df8e9a855a97 100644
--- a/combo/data/fields/sequence_multilabel_field.py
+++ b/combo/data/fields/sequence_multilabel_field.py
@@ -44,9 +44,9 @@ class SequenceMultiLabelField(Field[torch.Tensor]):
         A field containing the sequence that this `SequenceMultiLabelField` is labeling.  Most often, this is a
         `TextField`, for tagging individual tokens in a sentence.
     label_namespace : `str`, optional (default="labels")
-        The namespace to use for converting label strings into integers.  We map label strings to
+        The namespace to use for converting label strings into integers.  We map label strings to
         integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...),
-        and this namespace tells the `Vocabulary` object which mapping from strings to integers
+        and this namespace tells the `Vocabulary` object which mapping from strings to integers
         to use (so "entailment" as a label doesn't get the same integer idx as "entailment" as a
         word).  If you have multiple different label fields in your data, you should make sure you
         use different namespaces for each one, always using the suffix "labels" (e.g.,
@@ -86,7 +86,7 @@ class SequenceMultiLabelField(Field[torch.Tensor]):
         if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
             if label_namespace not in self._already_warned_namespaces:
                 logger.warning(
-                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "Your label _namespace was '%s'. We recommend you use a _namespace "
                     "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                     "default to your vocabulary.  See documentation for "
                     "`non_padded_namespaces` parameter in Vocabulary.",
@@ -144,5 +144,5 @@ class SequenceMultiLabelField(Field[torch.Tensor]):
         )
         return (
             f"SequenceMultiLabelField of length {length} with "
-            f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
+            f"labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'."
         )
diff --git a/combo/data/token_embedders/__init__.py b/combo/data/token_embedders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/combo/data/token_embedders/embedding.py b/combo/data/token_embedders/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae17b56d7875f29cdd6bec1876e0f8b02dece43e
--- /dev/null
+++ b/combo/data/token_embedders/embedding.py
@@ -0,0 +1,216 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/embedding.py
+"""
+import io
+import itertools
+import logging
+import re
+import tarfile
+import zipfile
+from typing import Any, Iterator, Tuple, NamedTuple, Optional, BinaryIO, Sequence
+
+from combo.utils.typing import cast
+from combo.utils.file_utils import cached_path, get_file_extension
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingsFileURI(NamedTuple):
+    main_file_uri: str
+    path_inside_archive: Optional[str] = None
+
+
+def format_embeddings_file_uri(
+    main_file_path_or_url: str, path_inside_archive: Optional[str] = None
+) -> str:
+    if path_inside_archive:
+        return "({})#{}".format(main_file_path_or_url, path_inside_archive)
+    return main_file_path_or_url
+
+
+def parse_embeddings_file_uri(uri: str) -> "EmbeddingsFileURI":
+    match = re.fullmatch(r"\((.*)\)#(.*)", uri)
+    if match:
+        fields = cast(Tuple[str, str], match.groups())
+        return EmbeddingsFileURI(*fields)
+    else:
+        return EmbeddingsFileURI(uri, None)
+
+
+class EmbeddingsTextFile(Iterator[str]):
+    """
+    Utility class for opening embeddings text files. Handles various compression formats,
+    as well as context management.
+
+    # Parameters
+
+    file_uri : `str`
+        It can be:
+
+        * a file system path or a URL of a possibly compressed text file or a zip/tar archive
+          containing a single file.
+        * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file
+          is contained in a multi-file archive.
+
+    encoding : `str`
+    cache_dir : `str`
+    """
+
+    DEFAULT_ENCODING = "utf-8"
+
+    def __init__(
+            self, file_uri: str, encoding: str = DEFAULT_ENCODING, cache_dir: str = None
+    ) -> None:
+
+        self.uri = file_uri
+        self._encoding = encoding
+        self._cache_dir = cache_dir
+        self._archive_handle: Any = None  # only if the file is inside an archive
+
+        main_file_uri, path_inside_archive = parse_embeddings_file_uri(file_uri)
+        main_file_local_path = cached_path(main_file_uri, cache_dir=cache_dir)
+
+        if zipfile.is_zipfile(main_file_local_path):  # ZIP archive
+            self._open_inside_zip(main_file_uri, path_inside_archive)
+
+        elif tarfile.is_tarfile(main_file_local_path):  # TAR archive
+            self._open_inside_tar(main_file_uri, path_inside_archive)
+
+        else:  # all the other supported formats, including uncompressed files
+            if path_inside_archive:
+                raise ValueError("Unsupported archive format: %s" + main_file_uri)
+
+            # All the python packages for compressed files share the same interface of io.open
+            extension = get_file_extension(main_file_uri)
+
+            # Some systems don't have support for all of these libraries, so we import them only
+            # when necessary.
+            package = None
+            if extension in [".txt", ".vec"]:
+                package = io
+            elif extension == ".gz":
+                import gzip
+
+                package = gzip
+            elif extension == ".bz2":
+                import bz2
+
+                package = bz2
+            elif extension == ".xz":
+                import lzma
+
+                package = lzma
+
+            if package is None:
+                logger.warning(
+                    'The embeddings file has an unknown file extension "%s". '
+                    "We will assume the file is an (uncompressed) text file",
+                    extension,
+                )
+                package = io
+
+            self._handle = package.open(  # type: ignore
+                main_file_local_path, "rt", encoding=encoding
+            )
+
+        # To use this with tqdm we'd like to know the number of tokens. It's possible that the
+        # first line of the embeddings file contains this: if it does, we want to start iteration
+        # from the 2nd line, otherwise we want to start from the 1st.
+        # Unfortunately, once we read the first line, we cannot move back the file iterator
+        # because the underlying file may be "not seekable"; we use itertools.chain instead.
+        first_line = next(self._handle)  # this moves the iterator forward
+        self.num_tokens = EmbeddingsTextFile._get_num_tokens_from_first_line(first_line)
+        if self.num_tokens:
+            # the first line is a header line: start iterating from the 2nd line
+            self._iterator = self._handle
+        else:
+            # the first line is not a header line: start iterating from the 1st line
+            self._iterator = itertools.chain([first_line], self._handle)
+
+    def _open_inside_zip(self, archive_path: str, member_path: Optional[str] = None) -> None:
+        cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir)
+        archive = zipfile.ZipFile(cached_archive_path, "r")
+        if member_path is None:
+            members_list = archive.namelist()
+            member_path = self._get_the_only_file_in_the_archive(members_list, archive_path)
+        member_path = cast(str, member_path)
+        member_file = cast(BinaryIO, archive.open(member_path, "r"))
+        self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
+        self._archive_handle = archive
+
+    def _open_inside_tar(self, archive_path: str, member_path: Optional[str] = None) -> None:
+        cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir)
+        archive = tarfile.open(cached_archive_path, "r")
+        if member_path is None:
+            members_list = archive.getnames()
+            member_path = self._get_the_only_file_in_the_archive(members_list, archive_path)
+        member_path = cast(str, member_path)
+        member = archive.getmember(member_path)  # raises exception if not present
+        member_file = cast(BinaryIO, archive.extractfile(member))
+        self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
+        self._archive_handle = archive
+
+    def read(self) -> str:
+        return "".join(self._iterator)
+
+    def readline(self) -> str:
+        return next(self._iterator)
+
+    def close(self) -> None:
+        self._handle.close()
+        if self._archive_handle:
+            self._archive_handle.close()
+
+    def __enter__(self) -> "EmbeddingsTextFile":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self.close()
+
+    def __iter__(self) -> "EmbeddingsTextFile":
+        return self
+
+    def __next__(self) -> str:
+        return next(self._iterator)
+
+    def __len__(self) -> Optional[int]:
+        if self.num_tokens:
+            return self.num_tokens
+        raise AttributeError(
+            "an object of type EmbeddingsTextFile implements `__len__` only if the underlying "
+            "text file declares the number of tokens (i.e. the number of lines following)"
+            "in the first line. That is not the case of this particular instance."
+        )
+
+    @staticmethod
+    def _get_the_only_file_in_the_archive(members_list: Sequence[str], archive_path: str) -> str:
+        if len(members_list) > 1:
+            raise ValueError(
+                "The archive %s contains multiple files, so you must select "
+                "one of the files inside providing a uri of the type: %s."
+                % (
+                    archive_path,
+                    format_embeddings_file_uri("path_or_url_to_archive", "path_inside_archive"),
+                )
+            )
+        return members_list[0]
+
+    @staticmethod
+    def _get_num_tokens_from_first_line(line: str) -> Optional[int]:
+        """This function takes in input a string and if it contains 1 or 2 integers, it assumes the
+        largest one it the number of tokens. Returns None if the line doesn't match that pattern."""
+        fields = line.split(" ")
+        if 1 <= len(fields) <= 2:
+            try:
+                int_fields = [int(x) for x in fields]
+            except ValueError:
+                return None
+            else:
+                num_tokens = max(int_fields)
+                logger.info(
+                    "Recognized a header line in the embedding file with number of tokens: %d",
+                    num_tokens,
+                )
+                return num_tokens
+        return None
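
A sketch of how the URI helpers and `EmbeddingsTextFile` fit together, assuming a hypothetical local archive `embeddings.zip` that holds more than one file, so the member has to be named explicitly:

    from combo.data.token_embedders.embedding import (
        EmbeddingsTextFile,
        format_embeddings_file_uri,
        parse_embeddings_file_uri,
    )

    # "(archive)#member" selects a single file inside a multi-file archive.
    uri = format_embeddings_file_uri("embeddings.zip", "glove/vectors.txt")
    assert parse_embeddings_file_uri(uri) == ("embeddings.zip", "glove/vectors.txt")

    with EmbeddingsTextFile(uri) as embeddings_file:
        # A leading "<num_tokens> <dim>" header, if present, is consumed in __init__
        # and exposed as num_tokens; iteration then starts at the first vector line.
        for line in embeddings_file:
            token, *vector = line.rstrip().split(" ")
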
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 6e993b952c2145cc6c2e66bf086873ebee298dc2..75df1b396d7e25b8b71c47f4704fcfc31f3cf947 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,3 +1,6 @@
 from .token_indexer import IndexedTokenList, TokenIndexer
 from .token_features_indexer import TokenFeatsIndexer
 from .single_id_token_indexer import SingleIdTokenIndexer
+from .pretrained_transformer_indexer import PretrainedTransformerIndexer
+from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
+from .pretrained_transformer_fixed_mismatched_indexer import PretrainedTransformerFixedMismatchedIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
index 6366a1e43808df3c79b17d5a1643396f4d3c3c46..af7bb0a9afcb01a94d77853fa303c929e3571f5b 100644
--- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -8,11 +8,9 @@ from typing import Optional, Dict, Any, List, Tuple
 from overrides import overrides
 
 from combo.data import Vocabulary
-from combo.data.tokenizers import Token
 from combo.data.token_indexers import IndexedTokenList
 from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
 from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
-from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
 
 
 class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatchedIndexer):
@@ -63,64 +61,3 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche
         }
 
         return self._matched_indexer._postprocess_output(output)
-
-
-class PretrainedTransformerIndexer(PretrainedTransformerIndexer):
-
-    def __init__(
-            self,
-            model_name: str,
-            namespace: str = "tags",
-            max_length: int = None,
-            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
-            **kwargs,
-    ) -> None:
-        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
-        self._namespace = namespace
-        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
-            model_name, tokenizer_kwargs=tokenizer_kwargs
-        )
-        self._tokenizer = self._allennlp_tokenizer.tokenizer
-        self._added_to_vocabulary = False
-
-        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
-        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
-
-        self._max_length = max_length
-        if self._max_length is not None:
-            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
-            self._effective_max_length = (  # we need to take into account special tokens
-                    self._max_length - num_added_tokens
-            )
-            if self._effective_max_length <= 0:
-                raise ValueError(
-                    "max_length needs to be greater than the number of special tokens inserted."
-                )
-
-
-class PretrainedTransformerTokenizer(PretrainedTransformerTokenizer):
-
-    def _intra_word_tokenize(
-            self, string_tokens: List[str]
-    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
-        tokens: List[Token] = []
-        offsets: List[Optional[Tuple[int, int]]] = []
-        for token_string in string_tokens:
-            wordpieces = self.tokenizer.encode_plus(
-                token_string,
-                add_special_tokens=False,
-                return_tensors=None,
-                return_offsets_mapping=False,
-                return_attention_mask=False,
-            )
-            wp_ids = wordpieces["input_ids"]
-
-            if len(wp_ids) > 0:
-                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
-                tokens.extend(
-                    Token(text=wp_text, text_id=wp_id)
-                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
-                )
-            else:
-                offsets.append(None)
-        return tokens, offsets
diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py
index 1c5302c4cc5cd4b2169ace7204b51cb1adef1e51..6580f4b5bd48a11b79d9b510c05becabaac0c16a 100644
--- a/combo/data/token_indexers/pretrained_transformer_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_indexer.py
@@ -27,10 +27,10 @@ class PretrainedTransformerIndexer(TokenIndexer):
     # Parameters
     model_name : `str`
         The name of the `transformers` model to use.
-    namespace : `str`, optional (default=`tags`)
-        We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace.
+    namespace : `str`, optional (default=`tags`)
+        We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace.
         We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
-        tokens to this namespace, which would break on loading because we wouldn't find our default
+        tokens to this namespace, which would break on loading because we wouldn't find our default
         OOV token.
     max_length : `int`, optional (default = `None`)
         If not None, split the document into segments of this many tokens (including special tokens)
@@ -48,13 +48,17 @@ class PretrainedTransformerIndexer(TokenIndexer):
         model_name: str,
         namespace: str = "tags",
         max_length: int = None,
+        add_special_tokens: bool = False,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self._namespace = namespace
+        self._model_name = model_name
+        self._add_special_tokens = add_special_tokens
+        self._tokenizer_kwargs = tokenizer_kwargs
         self._allennlp_tokenizer = PretrainedTransformerTokenizer(
-            model_name, tokenizer_kwargs=tokenizer_kwargs
+            model_name, add_special_tokens=add_special_tokens, tokenizer_kwargs=tokenizer_kwargs
         )
         self._tokenizer = self._allennlp_tokenizer.tokenizer
         self._added_to_vocabulary = False
@@ -75,7 +79,7 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
     def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None:
         """
-        Copies tokens from ```transformers``` model's vocab to the specified namespace.
+        Copies tokens from ```transformers``` model's vocab to the specified namespace.
         """
         if self._added_to_vocabulary:
             return
@@ -236,6 +240,35 @@ class PretrainedTransformerIndexer(TokenIndexer):
             tensor_dict[key] = tensor
         return tensor_dict
 
+    def _intra_word_tokenize(
+            self, string_tokens: List[str]
+    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
+        """
+        Adapted from COMBO
+        Authors: Mateusz Klimaszewski, Lukasz Pszenny
+        """
+        tokens: List[Token] = []
+        offsets: List[Optional[Tuple[int, int]]] = []
+        for token_string in string_tokens:
+            wordpieces = self._tokenizer.encode_plus(
+                token_string,
+                add_special_tokens=False,
+                return_tensors=None,
+                return_offsets_mapping=False,
+                return_attention_mask=False,
+            )
+            wp_ids = wordpieces["input_ids"]
+
+            if len(wp_ids) > 0:
+                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
+                tokens.extend(
+                    Token(text=wp_text, text_id=wp_id)
+                    for wp_id, wp_text in zip(wp_ids, self._tokenizer.convert_ids_to_tokens(wp_ids))
+                )
+            else:
+                offsets.append(None)
+        return tokens, offsets
+
     def __eq__(self, other):
         if isinstance(other, PretrainedTransformerIndexer):
             for key in self.__dict__:
@@ -247,3 +280,12 @@ class PretrainedTransformerIndexer(TokenIndexer):
                     return False
             return True
         return NotImplemented
+
+    def _to_params(self) -> Dict[str, Any]:
+        return {
+            "namespace": self._namespace,
+            "model_name": self._model_name,
+            "add_special_tokens": self._add_special_tokens,
+            "max_length": self._max_length,
+            "tokenizer_kwargs": self._tokenizer_kwargs,
+        }
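
A brief sketch of the extended constructor surface, assuming the model name `bert-base-cased` is available; the new `add_special_tokens` flag is passed straight through to the underlying `PretrainedTransformerTokenizer`, and `_to_params` now records everything needed to rebuild the indexer:

    from combo.data.token_indexers import PretrainedTransformerIndexer

    indexer = PretrainedTransformerIndexer(
        "bert-base-cased",
        namespace="tags",
        add_special_tokens=False,
    )

    params = indexer._to_params()
    assert params["add_special_tokens"] is False and params["namespace"] == "tags"
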
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index f5105177e8c95d0138a3575053a600a08bb1725d..6d660e7cb40c78e98bae173641c91a28ac6a2ce4 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -29,10 +29,10 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
     # Parameters
     model_name : `str`
         The name of the `transformers` model to use.
-    namespace : `str`, optional (default=`tags`)
-        We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace.
+    namespace : `str`, optional (default=`tags`)
+        We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace.
         We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
-        tokens to this namespace, which would break on loading because we wouldn't find our default
+        tokens to this namespace, which would break on loading because we wouldn't find our default
         OOV token.
     max_length : `int`, optional (default = `None`)
         If positive, split the document into segments of this many tokens (including special tokens)
diff --git a/combo/data/token_indexers/single_id_token_indexer.py b/combo/data/token_indexers/single_id_token_indexer.py
index 576ae8a6bff933d0d441b550622425625b1a9256..143c7861fbf824eb216f1f9526a489a50f13c25d 100644
--- a/combo/data/token_indexers/single_id_token_indexer.py
+++ b/combo/data/token_indexers/single_id_token_indexer.py
@@ -12,17 +12,18 @@ from combo.data.token_indexers import TokenIndexer, IndexedTokenList
 
 _DEFAULT_VALUE = "THIS IS A REALLY UNLIKELY VALUE THAT HAS TO BE A STRING"
 
+
 class SingleIdTokenIndexer(TokenIndexer):
     """
     This :class:`TokenIndexer` represents tokens as single integers.
     Registered as a `TokenIndexer` with name "single_id".
     # Parameters
-    namespace : `Optional[str]`, optional (default=`"tokens"`)
-        We will use this namespace in the :class:`Vocabulary` to map strings to indices.  If you
+    namespace : `Optional[str]`, optional (default=`"tokens"`)
+        We will use this namespace in the :class:`Vocabulary` to map strings to indices.  If you
         explicitly pass in `None` here, we will skip indexing and vocabulary lookups.  This means
         that the `feature_name` you use must correspond to an integer value (like `text_id`, for
         instance, which gets set by some tokenizers, such as when using byte encoding).
-    lowercase_tokens : `bool`, optional (default=`False`)
+    lowercase_tokens : `bool`, optional (default=`False`)
         If `True`, we will call `token.lower()` before getting an index for the token from the
         vocabulary.
     start_tokens : `List[str]`, optional (default=`None`)
@@ -33,7 +34,7 @@ class SingleIdTokenIndexer(TokenIndexer):
         We will use the :class:`Token` attribute with this name as input.  This is potentially
         useful, e.g., for using NER tags instead of (or in addition to) surface forms as your inputs
         (passing `ent_type_` here would do that).  If you use a non-default value here, you almost
-        certainly want to also change the `namespace` parameter, and you might want to give a
+        certainly want to also change the `namespace` parameter, and you might want to give a
         `default_value`.
     default_value : `str`, optional
         When you want to use a non-default `feature_name`, you sometimes want to have a default
@@ -46,18 +47,18 @@ class SingleIdTokenIndexer(TokenIndexer):
     """
 
     def __init__(
-        self,
-        namespace: Optional[str] = "tokens",
-        lowercase_tokens: bool = False,
-        start_tokens: List[str] = None,
-        end_tokens: List[str] = None,
-        feature_name: str = "text",
-        default_value: str = _DEFAULT_VALUE,
-        token_min_padding_length: int = 0,
+            self,
+            namespace: Optional[str] = "tokens",
+            lowercase_tokens: bool = False,
+            start_tokens: List[str] = None,
+            end_tokens: List[str] = None,
+            feature_name: str = "text",
+            default_value: str = _DEFAULT_VALUE,
+            token_min_padding_length: int = 0,
     ) -> None:
         super().__init__(token_min_padding_length)
-        self.namespace = namespace
-        self.lowercase_tokens = lowercase_tokens
+        self._namespace = namespace
+        self._lowercase_tokens = lowercase_tokens
 
         self._start_tokens = [Token(st) for st in (start_tokens or [])]
         self._end_tokens = [Token(et) for et in (end_tokens or [])]
@@ -65,26 +66,26 @@ class SingleIdTokenIndexer(TokenIndexer):
         self._default_value = default_value
 
     def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
-        if self.namespace is not None:
+        if self._namespace is not None:
             text = self._get_feature_value(token)
-            if self.lowercase_tokens:
+            if self._lowercase_tokens:
                 text = text.lower()
-            counter[self.namespace][text] += 1
+            counter[self._namespace][text] += 1
 
     def tokens_to_indices(
-        self, tokens: List[Token], vocabulary: Vocabulary
+            self, tokens: List[Token], vocabulary: Vocabulary
     ) -> Dict[str, List[int]]:
         indices: List[int] = []
 
         for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
             text = self._get_feature_value(token)
-            if self.namespace is None:
+            if self._namespace is None:
                 # We could have a check here that `text` is an int; not sure it's worth it.
                 indices.append(text)  # type: ignore
             else:
-                if self.lowercase_tokens:
+                if self._lowercase_tokens:
                     text = text.lower()
-                indices.append(vocabulary.get_token_index(text, self.namespace))
+                indices.append(vocabulary.get_token_index(text, self._namespace))
 
         return {"tokens": indices}
 
@@ -106,8 +107,8 @@ class SingleIdTokenIndexer(TokenIndexer):
 
     def _to_params(self) -> Dict[str, Any]:
         return {
-            "namespace": self.namespace,
-            "lowercase_tokens": self.lowercase_tokens,
+            "namespace": self._namespace,
+            "lowercase_tokens": self._lowercase_tokens,
             "start_tokens": [t.text for t in self._start_tokens],
             "end_tokens": [t.text for t in self._end_tokens],
             "feature_name": self._feature_name,
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index 30197be5fbcd8f8ec16571fde67d86a851aab24a..1227f2e5a8b433a9a17b11c80a1d379b1a883f93 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -24,8 +24,8 @@ class TokenCharactersIndexer(TokenIndexer):
 
     # Parameters
 
-    namespace : `str`, optional (default=`token_characters`)
-        We will use this namespace in the :class:`Vocabulary` to map the characters in each token
+    namespace : `str`, optional (default=`token_characters`)
+        We will use this namespace in the :class:`Vocabulary` to map the characters in each token
         to indices.
     character_tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`)
         We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has
diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
index 093bfca16a0a571e69d3d0afb95739bd5c52bb8f..667246037178dbdcbc8fa6602e47b0139aed3624 100644
--- a/combo/data/tokenizers/pretrained_transformer_tokenizer.py
+++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -5,12 +5,8 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/pretraine
 
 import copy
 import dataclasses
-import hashlib
-import io
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Iterable
-import base58
-import dill
 
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
@@ -20,27 +16,6 @@ from combo.utils import sanitize_wordpiece
 
 logger = logging.getLogger(__name__)
 
-def hash_object(o: Any) -> str:
-    """Returns a character hash code of arbitrary Python objects."""
-    m = hashlib.blake2b()
-    with io.BytesIO() as buffer:
-        dill.dump(o, buffer)
-        m.update(buffer.getbuffer())
-        return base58.b58encode(m.digest()).decode()
-
-def get_tokenizer(model_name: str, **kwargs) -> PreTrainedTokenizer:
-    cache_key = (model_name, hash_object(kwargs))
-
-    global _tokenizer_cache
-    tokenizer = _tokenizer_cache.get(cache_key, None)
-    if tokenizer is None:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            **kwargs,
-        )
-        _tokenizer_cache[cache_key] = tokenizer
-    return tokenizer
-
 
 class PretrainedTransformerTokenizer(Tokenizer):
     """
@@ -72,13 +47,15 @@ class PretrainedTransformerTokenizer(Tokenizer):
     """  # noqa: E501
 
     def __init__(
-        self,
-        model_name: str,
-        add_special_tokens: bool = True,
-        max_length: Optional[int] = None,
-        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
-        verification_tokens: Optional[Tuple[str, str]] = None,
+            self,
+            model_name: str,
+            add_special_tokens: bool = True,
+            max_length: Optional[int] = None,
+            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+            verification_tokens: Optional[Tuple[str, str]] = None,
     ) -> None:
+        from combo.common.cached_transformers import get_tokenizer
+
         if tokenizer_kwargs is None:
             tokenizer_kwargs = {}
         else:
@@ -110,12 +87,13 @@ class PretrainedTransformerTokenizer(Tokenizer):
             self._reverse_engineer_special_tokens(token_a, token_b, model_name, tokenizer_kwargs)
 
     def _reverse_engineer_special_tokens(
-        self,
-        token_a: str,
-        token_b: str,
-        model_name: str,
-        tokenizer_kwargs: Optional[Dict[str, Any]],
+            self,
+            token_a: str,
+            token_b: str,
+            model_name: str,
+            tokenizer_kwargs: Optional[Dict[str, Any]],
     ):
+        from combo.common.cached_transformers import get_tokenizer
         # storing the special tokens
         self.sequence_pair_start_tokens = []
         self.sequence_pair_mid_tokens = []
@@ -157,15 +135,15 @@ class PretrainedTransformerTokenizer(Tokenizer):
         seen_dummy_a = False
         seen_dummy_b = False
         for token_id, token_type_id in zip(
-            dummy_output["input_ids"], dummy_output["token_type_ids"]
+                dummy_output["input_ids"], dummy_output["token_type_ids"]
         ):
             if token_id == dummy_a:
                 if seen_dummy_a or seen_dummy_b:  # seeing a twice or b before a
                     raise ValueError("Cannot auto-determine the number of special tokens added.")
                 seen_dummy_a = True
                 assert (
-                    self.sequence_pair_first_token_type_id is None
-                    or self.sequence_pair_first_token_type_id == token_type_id
+                        self.sequence_pair_first_token_type_id is None
+                        or self.sequence_pair_first_token_type_id == token_type_id
                 ), "multiple different token type ids found for the first sequence"
                 self.sequence_pair_first_token_type_id = token_type_id
                 continue
@@ -175,8 +153,8 @@ class PretrainedTransformerTokenizer(Tokenizer):
                     raise ValueError("Cannot auto-determine the number of special tokens added.")
                 seen_dummy_b = True
                 assert (
-                    self.sequence_pair_second_token_type_id is None
-                    or self.sequence_pair_second_token_type_id == token_type_id
+                        self.sequence_pair_second_token_type_id is None
+                        or self.sequence_pair_second_token_type_id == token_type_id
                 ), "multiple different token type ids found for the second sequence"
                 self.sequence_pair_second_token_type_id = token_type_id
                 continue
@@ -194,10 +172,10 @@ class PretrainedTransformerTokenizer(Tokenizer):
                 self.sequence_pair_end_tokens.append(token)
 
         assert (
-            len(self.sequence_pair_start_tokens)
-            + len(self.sequence_pair_mid_tokens)
-            + len(self.sequence_pair_end_tokens)
-        ) == self.tokenizer.num_special_tokens_to_add(pair=True)
+                       len(self.sequence_pair_start_tokens)
+                       + len(self.sequence_pair_mid_tokens)
+                       + len(self.sequence_pair_end_tokens)
+               ) == self.tokenizer.num_special_tokens_to_add(pair=True)
 
         # Reverse-engineer the tokenizer for one sequence
         dummy_output = tokenizer_with_special_tokens.encode_plus(
@@ -214,15 +192,15 @@ class PretrainedTransformerTokenizer(Tokenizer):
 
         seen_dummy_a = False
         for token_id, token_type_id in zip(
-            dummy_output["input_ids"], dummy_output["token_type_ids"]
+                dummy_output["input_ids"], dummy_output["token_type_ids"]
         ):
             if token_id == dummy_a:
                 if seen_dummy_a:
                     raise ValueError("Cannot auto-determine the number of special tokens added.")
                 seen_dummy_a = True
                 assert (
-                    self.single_sequence_token_type_id is None
-                    or self.single_sequence_token_type_id == token_type_id
+                        self.single_sequence_token_type_id is None
+                        or self.single_sequence_token_type_id == token_type_id
                 ), "multiple different token type ids found for the sequence"
                 self.single_sequence_token_type_id = token_type_id
                 continue
@@ -238,8 +216,8 @@ class PretrainedTransformerTokenizer(Tokenizer):
                 self.single_sequence_end_tokens.append(token)
 
         assert (
-            len(self.single_sequence_start_tokens) + len(self.single_sequence_end_tokens)
-        ) == self.tokenizer.num_special_tokens_to_add(pair=False)
+                       len(self.single_sequence_start_tokens) + len(self.single_sequence_end_tokens)
+               ) == self.tokenizer.num_special_tokens_to_add(pair=False)
 
     @staticmethod
     def tokenizer_lowercases(tokenizer: PreTrainedTokenizer) -> bool:
@@ -284,7 +262,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
 
         tokens = []
         for token_id, token_type_id, special_token_mask, offsets in zip(
-            token_ids, token_type_ids, special_tokens_mask, token_offsets
+                token_ids, token_type_ids, special_tokens_mask, token_offsets
         ):
             # In `special_tokens_mask`, 1s indicate special tokens and 0s indicate regular tokens.
             # NOTE: in transformers v3.4.0 (and probably older versions) the docstring
@@ -304,15 +282,14 @@ class PretrainedTransformerTokenizer(Tokenizer):
                     text=self.tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=False),
                     text_id=token_id,
                     type_id=token_type_id,
-                    idx=start,
-                    idx_end=end,
+                    idx=(start, end)
                 )
             )
 
         return tokens
 
     def _estimate_character_indices(
-        self, text: str, token_ids: List[int]
+            self, text: str, token_ids: List[int]
     ) -> List[Optional[Tuple[int, int]]]:
         """
         The huggingface tokenizers produce tokens that may or may not be slices from the
@@ -373,7 +350,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
         return token_offsets
 
     def _intra_word_tokenize(
-        self, string_tokens: List[str]
+            self, string_tokens: List[str]
     ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
         tokens: List[Token] = []
         offsets: List[Optional[Tuple[int, int]]] = []
@@ -399,7 +376,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
 
     @staticmethod
     def _increment_offsets(
-        offsets: Iterable[Optional[Tuple[int, int]]], increment: int
+            offsets: Iterable[Optional[Tuple[int, int]]], increment: int
     ) -> List[Optional[Tuple[int, int]]]:
         return [
             None if offset is None else (offset[0] + increment, offset[1] + increment)
@@ -407,7 +384,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
         ]
 
     def intra_word_tokenize(
-        self, string_tokens: List[str]
+            self, string_tokens: List[str]
     ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
         """
         Tokenizes each word into wordpieces separately and returns the wordpiece IDs.
@@ -421,7 +398,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
         return tokens, offsets
 
     def intra_word_tokenize_sentence_pair(
-        self, string_tokens_a: List[str], string_tokens_b: List[str]
+            self, string_tokens_a: List[str], string_tokens_b: List[str]
     ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]], List[Optional[Tuple[int, int]]]]:
         """
         Tokenizes each word into wordpieces separately and returns the wordpiece IDs.
@@ -434,9 +411,9 @@ class PretrainedTransformerTokenizer(Tokenizer):
         offsets_b = self._increment_offsets(
             offsets_b,
             (
-                len(self.sequence_pair_start_tokens)
-                + len(tokens_a)
-                + len(self.sequence_pair_mid_tokens)
+                    len(self.sequence_pair_start_tokens)
+                    + len(tokens_a)
+                    + len(self.sequence_pair_mid_tokens)
             ),
         )
         tokens_a = self.add_special_tokens(tokens_a, tokens_b)
@@ -445,7 +422,7 @@ class PretrainedTransformerTokenizer(Tokenizer):
         return tokens_a, offsets_a, offsets_b
 
     def add_special_tokens(
-        self, tokens1: List[Token], tokens2: Optional[List[Token]] = None
+            self, tokens1: List[Token], tokens2: Optional[List[Token]] = None
     ) -> List[Token]:
         def with_new_type_id(tokens: List[Token], type_id: int) -> List[Token]:
             return [dataclasses.replace(t, type_id=type_id) for t in tokens]
@@ -455,17 +432,17 @@ class PretrainedTransformerTokenizer(Tokenizer):
 
         if tokens2 is None:
             return (
-                self.single_sequence_start_tokens
-                + with_new_type_id(tokens1, self.single_sequence_token_type_id)  # type: ignore
-                + self.single_sequence_end_tokens
+                    self.single_sequence_start_tokens
+                    + with_new_type_id(tokens1, self.single_sequence_token_type_id)  # type: ignore
+                    + self.single_sequence_end_tokens
             )
         else:
             return (
-                self.sequence_pair_start_tokens
-                + with_new_type_id(tokens1, self.sequence_pair_first_token_type_id)  # type: ignore
-                + self.sequence_pair_mid_tokens
-                + with_new_type_id(tokens2, self.sequence_pair_second_token_type_id)  # type: ignore
-                + self.sequence_pair_end_tokens
+                    self.sequence_pair_start_tokens
+                    + with_new_type_id(tokens1, self.sequence_pair_first_token_type_id)  # type: ignore
+                    + self.sequence_pair_mid_tokens
+                    + with_new_type_id(tokens2, self.sequence_pair_second_token_type_id)  # type: ignore
+                    + self.sequence_pair_end_tokens
             )
 
     def num_special_tokens_for_sequence(self) -> int:
@@ -473,9 +450,9 @@ class PretrainedTransformerTokenizer(Tokenizer):
 
     def num_special_tokens_for_pair(self) -> int:
         return (
-            len(self.sequence_pair_start_tokens)
-            + len(self.sequence_pair_mid_tokens)
-            + len(self.sequence_pair_end_tokens)
+                len(self.sequence_pair_start_tokens)
+                + len(self.sequence_pair_mid_tokens)
+                + len(self.sequence_pair_end_tokens)
         )
 
     def _to_params(self) -> Dict[str, Any]:
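For orientation, a minimal usage sketch of the tokenizer surface touched above. The import path, the constructor call and the tokenize() entry point are assumptions (only the class body is visible in this diff); the idx/type_id behaviour follows the changed lines.

# Sketch only: module path, constructor signature and tokenize() are assumed, not shown in this diff.
from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-cased")

# Each produced Token now stores its character span as one idx=(start, end) tuple
# (instead of separate idx / idx_end fields) and carries the segment type_id.
for token in tokenizer.tokenize("Hello world"):
    print(token.text, token.idx, token.type_id)

# intra_word_tokenize returns one (start, end) wordpiece span per input word,
# already shifted past the leading special tokens.
tokens, offsets = tokenizer.intra_word_tokenize(["Hello", "world"])
assert len(offsets) == 2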
diff --git a/combo/data/tokenizers/token.py b/combo/data/tokenizers/token.py
index d8f2590b04810424334828f206d89e860a0305f8..b282a72ceed559a8b03d315c287089861a6f0fdb 100644
--- a/combo/data/tokenizers/token.py
+++ b/combo/data/tokenizers/token.py
@@ -31,7 +31,8 @@ class Token:
         "subwords",
         "semrel",
         "embeddings",
-        "text_id"
+        "text_id",
+        "type_id"
     ]
 
     text: Optional[str]
@@ -49,6 +50,7 @@ class Token:
     semrel: Optional[str]
     embeddings: Dict[str, List[float]]
     text_id: Optional[int]
+    type_id: Optional[int]
 
     def __init__(self,
                  text: str = None,
@@ -65,7 +67,8 @@ class Token:
                  subwords: List[str] = None,
                  semrel: str = None,
                  embeddings: Dict[str, List[float]] = None,
-                 text_id: int = None) -> None:
+                 text_id: int = None,
+                 type_id: int = None) -> None:
         _assert_none_or_type(text, str)
 
         self.text = text
@@ -89,6 +92,7 @@ class Token:
             self.embeddings = embeddings
 
         self.text_id = text_id
+        self.type_id = type_id
 
     def __str__(self):
         return self.text
@@ -112,5 +116,6 @@ class Token:
             f"(subwords: {','.join(self.subwords)})"
             f"(semrel: {self.semrel}) "
             f"(embeddings: {self.embeddings}) "
-            f"(text_id: {self.text_id})"
+            f"(text_id: {self.text_id}) "
+            f"(type_id: {self.type_id}) "
         )
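A short sketch of the extended Token; the import path follows the file header above, and only the fields relevant to this change are passed (the remaining keyword arguments keep their defaults).

from combo.data.tokenizers.token import Token

# type_id is the new slot added above; str() is unchanged and still returns the text.
t = Token(text="hello", text_id=7592, type_id=0)
print(t)                     # "hello"
print(t.text_id, t.type_id)  # 7592 0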
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index a2e53b8eff298f7e33550a769d22fffc8ebe9b53..345c66616bb3eb854cfac04993c2dd78646c69ae 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -9,6 +9,10 @@ from torchtext.vocab import vocab as torchtext_vocab
 import logging
 
 from filelock import FileLock
+from transformers import PreTrainedTokenizer
+
+from combo.common import Tqdm
+from combo.data.token_embedders.embedding import EmbeddingsTextFile
 
 logger = logging.Logger(__name__)
 
@@ -21,7 +25,7 @@ DEFAULT_NAMESPACE = "tokens"
 
 def match_namespace(pattern: str, namespace: str):
     if not isinstance(pattern, str):
-        raise ValueError("Pattern and namespace must be string types, got %s and %s." %
+        raise ValueError("Pattern and _namespace must be string types, got %s and %s." %
                          (type(pattern), type(namespace)))
     if pattern == namespace:
         return True
@@ -30,6 +34,23 @@ def match_namespace(pattern: str, namespace: str):
     return False
 
 
+def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]:
+    """Read the token strings (the first whitespace-separated field of each line) from an embeddings file."""
+
+    logger.info("Reading pretrained tokens from: %s", embeddings_file_uri)
+    tokens: List[str] = []
+    with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file:
+        for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1):
+            token_end = line.find(" ")
+            if token_end >= 0:
+                token = line[:token_end]
+                tokens.append(token)
+            else:
+                line_begin = line[:20] + "..." if len(line) > 20 else line
+                logger.warning("Skipping line number %d: %s", line_number, line_begin)
+    return tokens
+
+
 class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
     def __init__(self,
                  non_padded_namespaces: Iterable[str],
@@ -41,7 +62,7 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]):
         super().__init__()
 
     def __missing__(self, namespace: str):
         # Non-padded namespace
         if any([match_namespace(npn, namespace) for npn in self._non_padded_namespaces]):
             value = torchtext_vocab(OrderedDict([]))
         else:
@@ -128,7 +149,7 @@ class Vocabulary:
     def save_to_files(self, directory: str) -> None:
         """
         Persist this Vocabulary to files, so it can be reloaded later.
         Each namespace corresponds to one file.
 
         Adapted from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py
         # Parameters
@@ -149,7 +170,7 @@ class Vocabulary:
                     print(namespace_str, file=namespace_file)
 
             for namespace, vocab in self._vocab.items():
                 # Each namespace gets written to its own file, in index order.
                 with codecs.open(
                         os.path.join(directory, namespace + ".txt"), "w", "utf-8"
                 ) as token_file:
@@ -165,11 +186,11 @@ class Vocabulary:
 
     def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE):
         """
         Add the token if not present and return the index even if token was already in the namespace.
 
         :param token: token to be added
         :param namespace: namespace to add the token to
         :return: index of the token in the namespace
         """
 
         if not isinstance(token, str):
@@ -179,11 +200,11 @@ class Vocabulary:
 
     def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE):
         """
         Add the token if not present and return the index even if token was already in the namespace.
 
         :param tokens: tokens to be added
         :param namespace: namespace to add the token to
         :return: index of the token in the namespace
         """
 
         if not isinstance(tokens, List):
@@ -193,6 +214,24 @@ class Vocabulary:
         for token in tokens:
             self._vocab[namespace].append_token(token)
 
+    def add_transformer_vocab(
+        self, tokenizer: PreTrainedTokenizer, namespace: str = DEFAULT_NAMESPACE
+    ) -> None:
+        """
+        Copies tokens from a transformer tokenizer's vocab into the given namespace.
+        """
+        try:
+            vocab_items = tokenizer.get_vocab().items()
+        except NotImplementedError:
+            vocab_items = (
+                (tokenizer.convert_ids_to_tokens(idx), idx) for idx in range(tokenizer.vocab_size)
+            )
+
+        for word, idx in sorted(vocab_items, key=lambda item: item[1]):  # insert in ascending index order
+            self._vocab[namespace].insert_token(token=word, index=idx)
+
+        self._non_padded_namespaces.add(namespace)
+
     def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]:
         if not isinstance(namespace, str):
             raise ValueError(
diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py
index 791484edb2584a62a905dada70365db332193dcf..b36a4d9c69c70253fa56c6e7fce558a75463e0c5 100644
--- a/combo/utils/__init__.py
+++ b/combo/utils/__init__.py
@@ -1,3 +1,4 @@
 from .checks import *
 from .sequence import *
-from .exceptions import *
\ No newline at end of file
+from .exceptions import *
+from .typing import *
\ No newline at end of file
diff --git a/combo/utils/file_utils.py b/combo/utils/file_utils.py
index 96fe63b9627c48449b967aadf1949127c935807d..fc41d46ba8aaf651a456642d838009907e007352 100644
--- a/combo/utils/file_utils.py
+++ b/combo/utils/file_utils.py
@@ -13,6 +13,12 @@ CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
 CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
 
 
+def get_file_extension(path: str, dot: bool = True, lower: bool = True) -> str:
+    ext = os.path.splitext(path)[1]
+    ext = ext if dot else ext[1:]
+    return ext.lower() if lower else ext
+
+
 def cached_path(
     url_or_filename: Union[str, PathLike],
     cache_dir: Union[str, Path] = None,
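get_file_extension is a pure string helper, so its behaviour can be shown directly; a quick usage sketch:

from combo.utils.file_utils import get_file_extension

get_file_extension("model.tar.GZ")               # ".gz"  (last extension, lower-cased, dot kept)
get_file_extension("model.tar.GZ", dot=False)    # "gz"
get_file_extension("model.tar.GZ", lower=False)  # ".GZ"
get_file_extension("README")                     # ""     (no extension)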
diff --git a/combo/utils/typing.py b/combo/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba566b0c22ac590ff4233f0b0b9ffe3d99e6ff6c
--- /dev/null
+++ b/combo/utils/typing.py
@@ -0,0 +1,3 @@
+def cast(typ, val):
+    """Runtime no-op in the spirit of typing.cast: return val unchanged, only documenting the intended type typ."""
+    return val
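cast() is a runtime no-op in the spirit of typing.cast: it returns the value unchanged and merely documents the intended type at the call site. A minimal sketch:

from typing import List

from combo.utils.typing import cast

def first_token(values) -> str:
    tokens = cast(List[str], values)  # no runtime conversion or check happens here
    return tokens[0]

print(first_token(["hello", "world"]))  # "hello"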