From 801830222e121a4a02e40e8a101cad98e51d4409 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Fri, 24 Mar 2023 20:52:06 +0100
Subject: [PATCH] Add TokenIndexers and Tokenizers

---
 combo/data/__init__.py                        |   7 +-
 combo/data/token.py                           |   2 -
 combo/data/token_indexers/__init__.py         |   4 +-
 combo/data/token_indexers/base_indexer.py     |   6 -
 ...ed_transformer_fixed_mismatched_indexer.py | 120 +++++
 .../pretrained_transformer_indexer.py         | 243 +++++++++
 ...etrained_transformer_mismatched_indexer.py | 119 ++++-
 .../token_characters_indexer.py               | 149 +++++-
 .../token_const_padding_characters_indexer.py |  62 +++
 .../token_indexers/token_features_indexer.py  |  76 ++-
 combo/data/token_indexers/token_indexer.py    | 124 +++++
 combo/data/tokenizers/__init__.py             |   2 +
 combo/data/tokenizers/character_tokenizer.py  |  11 +
 .../pretrained_transformer_tokenizer.py       | 488 ++++++++++++++++++
 combo/data/tokenizers/tokenizer.py            |  15 +
 combo/utils/cached_transformers.py            |   0
 combo/utils/sequence.py                       |  14 +
 requirements.txt                              |   5 +-
 18 files changed, 1418 insertions(+), 29 deletions(-)
 delete mode 100644 combo/data/token.py
 delete mode 100644 combo/data/token_indexers/base_indexer.py
 create mode 100644 combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
 create mode 100644 combo/data/token_indexers/pretrained_transformer_indexer.py
 create mode 100644 combo/data/token_indexers/token_const_padding_characters_indexer.py
 create mode 100644 combo/data/token_indexers/token_indexer.py
 create mode 100644 combo/data/tokenizers/__init__.py
 create mode 100644 combo/data/tokenizers/character_tokenizer.py
 create mode 100644 combo/data/tokenizers/pretrained_transformer_tokenizer.py
 create mode 100644 combo/data/tokenizers/tokenizer.py
 create mode 100644 combo/utils/cached_transformers.py

diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 3c35c82..42bda68 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,5 +1,2 @@
-from .samplers import TokenCountBatchSampler
-from .token import Token
-from .token_indexers import *
-from .api import *
-from .vocabulary import Vocabulary
\ No newline at end of file
+from .api import Token
+from .vocabulary import Vocabulary
diff --git a/combo/data/token.py b/combo/data/token.py
deleted file mode 100644
index 048fab0..0000000
--- a/combo/data/token.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Token:
-    pass
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index d179540..9fd4ead 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,4 +1,2 @@
-from .base_indexer import TokenIndexer
-from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
-from .token_characters_indexer import TokenCharactersIndexer
+from .token_indexer import IndexedTokenList, TokenIndexer
 from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py
deleted file mode 100644
index fa70d63..0000000
--- a/combo/data/token_indexers/base_indexer.py
+++ /dev/null
@@ -1,6 +0,0 @@
-class TokenIndexer:
-    pass
-
-
-class PretrainedTransformerMismatchedIndexer(TokenIndexer):
-    pass
diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
new file mode 100644
index 0000000..493f7f4
--- /dev/null
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -0,0 +1,120 @@
+from typing import Optional, Dict, Any, List, Tuple
+
+from overrides import overrides
+
+from combo.data import Vocabulary, Token
+from combo.data.token_indexers import IndexedTokenList
+from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
+from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
+from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
+
+
+class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatchedIndexer):
+
+    def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
+                 tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
+        # The matched version vs. the mismatched one
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._matched_indexer = PretrainedTransformerIndexer(
+            model_name,
+            namespace=namespace,
+            max_length=max_length,
+            tokenizer_kwargs=tokenizer_kwargs,
+            **kwargs,
+        )
+        self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
+        self._tokenizer = self._matched_indexer._tokenizer
+        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
+        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+
+    @overrides
+    def tokens_to_indices(self,
+                          tokens,
+                          vocabulary: Vocabulary) -> IndexedTokenList:
+        """
+        Overridden to raise an error when the number of wordpiece tokens needed to embed
+        a sentence exceeds the model's maximum input length.
+        """
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens])
+
+        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
+            raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n"
+                             + " ".join([str(x) for x in tokens[:10]]) + " ... \n"
+                             + f"Maximum input: {self._tokenizer.max_len_single_sentence}\n"
+                             + f"Current input: {len(wordpieces)}")
+
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
+
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
+
+        return self._matched_indexer._postprocess_output(output)
+
+
+class PretrainedTransformerIndexer(PretrainedTransformerIndexer):
+
+    def __init__(
+            self,
+            model_name: str,
+            namespace: str = "tags",
+            max_length: int = None,
+            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+            **kwargs,
+    ) -> None:
+        super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs)
+        self._namespace = namespace
+        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
+            model_name, tokenizer_kwargs=tokenizer_kwargs
+        )
+        self._tokenizer = self._allennlp_tokenizer.tokenizer
+        self._added_to_vocabulary = False
+
+        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
+
+        self._max_length = max_length
+        if self._max_length is not None:
+            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
+            self._effective_max_length = (  # we need to take into account special tokens
+                    self._max_length - num_added_tokens
+            )
+            if self._effective_max_length <= 0:
+                raise ValueError(
+                    "max_length needs to be greater than the number of special tokens inserted."
+                )
+
+
+class PretrainedTransformerTokenizer(PretrainedTransformerTokenizer):
+
+    def _intra_word_tokenize(
+            self, string_tokens: List[str]
+    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
+        tokens: List[Token] = []
+        offsets: List[Optional[Tuple[int, int]]] = []
+        for token_string in string_tokens:
+            wordpieces = self.tokenizer.encode_plus(
+                token_string,
+                add_special_tokens=False,
+                return_tensors=None,
+                return_offsets_mapping=False,
+                return_attention_mask=False,
+            )
+            wp_ids = wordpieces["input_ids"]
+
+            if len(wp_ids) > 0:
+                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
+                tokens.extend(
+                    Token(text=wp_text, text_id=wp_id)
+                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
+                )
+            else:
+                offsets.append(None)
+        return tokens, offsets
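+
+
+# Illustrative sketch (not part of the original change): what `_intra_word_tokenize` returns for a
+# hypothetical BERT-style wordpiece vocabulary. The token texts below are made up for clarity.
+#
+#     string_tokens = ["unaffable", "."]
+#     tokens, offsets = tokenizer._intra_word_tokenize(string_tokens)
+#     # tokens  -> [Token("un"), Token("##aff"), Token("##able"), Token(".")]
+#     # offsets -> [(0, 2), (3, 3)]   # inclusive wordpiece spans, one per input word
+#     # A word that produces no wordpieces contributes `None` to `offsets`.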
diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py
new file mode 100644
index 0000000..e00a9c2
--- /dev/null
+++ b/combo/data/token_indexers/pretrained_transformer_indexer.py
@@ -0,0 +1,243 @@
+from typing import Dict, List, Optional, Tuple, Any
+import logging
+import torch
+
+from combo.data import Vocabulary, Token
+from combo.data.token_indexers import TokenIndexer, IndexedTokenList
+from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
+from combo.utils import pad_sequence_to_length
+
+logger = logging.getLogger(__name__)
+
+
+class PretrainedTransformerIndexer(TokenIndexer):
+    """
+    This `TokenIndexer` assumes that Tokens already carry their indices (see the `text_id` field).
+    We still require `model_name` because we want to build the vocabulary from the pretrained one.
+    This `Indexer` is only really appropriate to use if you've also used a
+    corresponding :class:`PretrainedTransformerTokenizer` to tokenize your input.  Otherwise you'll
+    have a mismatch between your tokens and your vocabulary, and you'll get a lot of UNK tokens.
+    Registered as a `TokenIndexer` with name "pretrained_transformer".
+    # Parameters
+    model_name : `str`
+        The name of the `transformers` model to use.
+    namespace : `str`, optional (default=`tags`)
+        We will add the tokens from the `transformers` model's vocabulary to this vocabulary namespace.
+        We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
+        tokens to this namespace, which would break on loading because we wouldn't find our default
+        OOV token.
+    max_length : `int`, optional (default = `None`)
+        If not None, split the document into segments of this many tokens (including special tokens)
+        before feeding them into the embedder. The embedder embeds these segments independently and
+        concatenates the results to get the original document representation. Should be set to
+        the same value as the `max_length` option on the `PretrainedTransformerEmbedder`.
+    tokenizer_kwargs : `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        model_name: str,
+        namespace: str = "tags",
+        max_length: int = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._namespace = namespace
+        self._allennlp_tokenizer = PretrainedTransformerTokenizer(
+            model_name, tokenizer_kwargs=tokenizer_kwargs
+        )
+        self._tokenizer = self._allennlp_tokenizer.tokenizer
+        self._added_to_vocabulary = False
+
+        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)
+
+        self._max_length = max_length
+        if self._max_length is not None:
+            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
+            self._effective_max_length = (  # we need to take into account special tokens
+                self._max_length - num_added_tokens
+            )
+            if self._effective_max_length <= 0:
+                raise ValueError(
+                    "max_length needs to be greater than the number of special tokens inserted."
+                )
+
+    def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None:
+        """
+        Copies tokens from the `transformers` model's vocab to the specified namespace.
+        """
+        if self._added_to_vocabulary:
+            return
+
+        vocab.add_transformer_vocab(self._tokenizer, self._namespace)
+
+        self._added_to_vocabulary = True
+
+    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+        # If we only use pretrained models, we don't need to do anything here.
+        pass
+
+    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+        self._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        indices, type_ids = self._extract_token_and_type_ids(tokens)
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
+        output: IndexedTokenList = {
+            "token_ids": indices,
+            "mask": [True] * len(indices),
+            "type_ids": type_ids or [0] * len(indices),
+        }
+
+        return self._postprocess_output(output)
+
+    def indices_to_tokens(
+        self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
+    ) -> List[Token]:
+        self._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        token_ids = indexed_tokens["token_ids"]
+        type_ids = indexed_tokens.get("type_ids")
+
+        return [
+            Token(
+                text=vocabulary.get_token_from_index(token_ids[i], self._namespace),
+                text_id=token_ids[i],
+                type_id=type_ids[i] if type_ids is not None else None,
+            )
+            for i in range(len(token_ids))
+        ]
+
+    def _extract_token_and_type_ids(self, tokens: List[Token]) -> Tuple[List[int], List[int]]:
+        """
+        Roughly equivalent to `zip(*[(token.text_id, token.type_id) for token in tokens])`,
+        with some checks.
+        """
+        indices: List[int] = []
+        type_ids: List[int] = []
+        for token in tokens:
+            indices.append(
+                token.text_id
+                if token.text_id is not None
+                else self._tokenizer.convert_tokens_to_ids(token.text)
+            )
+            type_ids.append(token.type_id if token.type_id is not None else 0)
+        return indices, type_ids
+
+    def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
+        """
+        Takes an IndexedTokenList about to be returned by `tokens_to_indices()` and applies any
+        necessary postprocessing, e.g. long-sequence splitting.
+        The input should have a `"token_ids"` key corresponding to the token indices, with the
+        special tokens already inserted.
+        """
+        if self._max_length is not None:
+            # We prepare long indices by converting them to (assuming max_length == 5)
+            # [CLS] A B C [SEP] [CLS] D E F [SEP] ...
+            # Embedder is responsible for folding this 1-d sequence to 2-d and feed to the
+            # transformer model.
+            # TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.
+
+            indices = output["token_ids"]
+            type_ids = output.get("type_ids", [0] * len(indices))
+
+            # Strips original special tokens
+            indices = indices[
+                self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
+            ]
+            type_ids = type_ids[
+                self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
+            ]
+
+            # Folds indices
+            folded_indices = [
+                indices[i : i + self._effective_max_length]
+                for i in range(0, len(indices), self._effective_max_length)
+            ]
+            folded_type_ids = [
+                type_ids[i : i + self._effective_max_length]
+                for i in range(0, len(type_ids), self._effective_max_length)
+            ]
+
+            # Adds special tokens to each segment
+            folded_indices = [
+                self._tokenizer.build_inputs_with_special_tokens(segment)
+                for segment in folded_indices
+            ]
+            single_sequence_start_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
+            ]
+            single_sequence_end_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
+            ]
+            folded_type_ids = [
+                single_sequence_start_type_ids + segment + single_sequence_end_type_ids
+                for segment in folded_type_ids
+            ]
+            assert all(
+                len(segment_indices) == len(segment_type_ids)
+                for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
+            )
+
+            # Flattens
+            indices = [i for segment in folded_indices for i in segment]
+            type_ids = [i for segment in folded_type_ids for i in segment]
+
+            output["token_ids"] = indices
+            output["type_ids"] = type_ids
+            output["segment_concat_mask"] = [True] * len(indices)
+
+        return output
+
+    def get_empty_token_list(self) -> IndexedTokenList:
+        output: IndexedTokenList = {"token_ids": [], "mask": [], "type_ids": []}
+        if self._max_length is not None:
+            output["segment_concat_mask"] = []
+        return output
+
+    def as_padded_tensor_dict(
+        self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tensor_dict = {}
+        for key, val in tokens.items():
+            if key == "type_ids":
+                padding_value = 0
+                mktensor = torch.LongTensor
+            elif key == "mask" or key == "wordpiece_mask":
+                padding_value = False
+                mktensor = torch.BoolTensor
+            elif len(val) > 0 and isinstance(val[0], bool):
+                padding_value = False
+                mktensor = torch.BoolTensor
+            else:
+                padding_value = self._tokenizer.pad_token_id
+                if padding_value is None:
+                    padding_value = (
+                        0  # Some tokenizers don't have padding tokens and rely on the mask only.
+                    )
+                mktensor = torch.LongTensor
+
+            tensor = mktensor(
+                pad_sequence_to_length(
+                    val, padding_lengths[key], default_value=lambda: padding_value
+                )
+            )
+
+            tensor_dict[key] = tensor
+        return tensor_dict
+
+    def __eq__(self, other):
+        if isinstance(other, PretrainedTransformerIndexer):
+            for key in self.__dict__:
+                if key == "_tokenizer":
+                    # This is a reference to a function in the huggingface code, which we can't
+                    # really modify to make this clean.  So we special-case it.
+                    continue
+                if self.__dict__[key] != other.__dict__[key]:
+                    return False
+            return True
+        return NotImplemented
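+
+
+# Usage sketch (illustrative only; the model name below is an assumption, and `vocab` stands for a
+# combo `Vocabulary`, whose construction is outside this patch):
+#
+#     tokenizer = PretrainedTransformerTokenizer("bert-base-cased")
+#     indexer = PretrainedTransformerIndexer("bert-base-cased", max_length=512)
+#     wordpieces = tokenizer.tokenize("AllenNLP is awesome")
+#     indexed = indexer.tokens_to_indices(wordpieces, vocab)
+#     # indexed["token_ids"], indexed["mask"], and indexed["type_ids"] all have one entry per wordpiece;
+#     # with max_length set, indexed also contains "segment_concat_mask", and the ids are re-split
+#     # into [CLS] ... [SEP] segments of at most `max_length` tokens each.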
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index ea6a663..6c88343 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -1,13 +1,120 @@
-from .base_indexer import TokenIndexer
+from typing import Dict, List, Any, Optional
+import logging
+
+import torch
+
+from combo.data import Vocabulary, Token
+from combo.data.token_indexers import TokenIndexer, IndexedTokenList
+from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
+from combo.utils import pad_sequence_to_length
+
+logger = logging.getLogger(__name__)
 
 
 class PretrainedTransformerMismatchedIndexer(TokenIndexer):
-    pass
+    """
+    Use this indexer when (for whatever reason) you are not using a corresponding
+    `PretrainedTransformerTokenizer` on your input. We assume that you used a tokenizer that splits
+    strings into words, while the transformer expects wordpieces as input. This indexer splits the
+    words into wordpieces and flattens them out. You should use the corresponding
+    `PretrainedTransformerMismatchedEmbedder` to embed these wordpieces and then pull out a single
+    vector for each original word.
+    Registered as a `TokenIndexer` with name "pretrained_transformer_mismatched".
+    # Parameters
+    model_name : `str`
+        The name of the `transformers` model to use.
+    namespace : `str`, optional (default=`tags`)
+        We will add the tokens from the `transformers` model's vocabulary to this vocabulary namespace.
+        We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
+        tokens to this namespace, which would break on loading because we wouldn't find our default
+        OOV token.
+    max_length : `int`, optional (default = `None`)
+        If positive, split the document into segments of this many tokens (including special tokens)
+        before feeding them into the embedder. The embedder embeds these segments independently and
+        concatenates the results to get the original document representation. Should be set to
+        the same value as the `max_length` option on the `PretrainedTransformerMismatchedEmbedder`.
+    tokenizer_kwargs : `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        model_name: str,
+        namespace: str = "tags",
+        max_length: int = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        # The matched version vs. the mismatched one
+        self._matched_indexer = PretrainedTransformerIndexer(
+            model_name,
+            namespace=namespace,
+            max_length=max_length,
+            tokenizer_kwargs=tokenizer_kwargs,
+            **kwargs,
+        )
+        self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer
+        self._tokenizer = self._matched_indexer._tokenizer
+        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
+        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
+
+    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+        return self._matched_indexer.count_vocab_items(token, counter)
+
+    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
+            [t.ensure_text() for t in tokens]
+        )
+
+        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
+        # That results in the embedding for the token to be all zeros.
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
+
+        output: IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
+
+        return self._matched_indexer._postprocess_output(output)
+
+    def get_empty_token_list(self) -> IndexedTokenList:
+        output = self._matched_indexer.get_empty_token_list()
+        output["offsets"] = []
+        output["wordpiece_mask"] = []
+        return output
 
+    def as_padded_tensor_dict(
+        self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tokens = tokens.copy()
+        padding_lengths = padding_lengths.copy()
 
-class PretrainedTransformerIndexer(TokenIndexer):
-    pass
+        offsets_tokens = tokens.pop("offsets")
+        offsets_padding_lengths = padding_lengths.pop("offsets")
 
+        tensor_dict = self._matched_indexer.as_padded_tensor_dict(tokens, padding_lengths)
+        tensor_dict["offsets"] = torch.LongTensor(
+            pad_sequence_to_length(
+                offsets_tokens, offsets_padding_lengths, default_value=lambda: (0, 0)
+            )
+        )
+        return tensor_dict
 
-class PretrainedTransformerTokenizer(TokenIndexer):
-    pass
+    def __eq__(self, other):
+        if isinstance(other, PretrainedTransformerMismatchedIndexer):
+            for key in self.__dict__:
+                if key == "_tokenizer":
+                    # This is a reference to a function in the huggingface code, which we can't
+                    # really modify to make this clean.  So we special-case it.
+                    continue
+                if self.__dict__[key] != other.__dict__[key]:
+                    return False
+            return True
+        return NotImplemented
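+
+
+# Shape sketch for an already-constructed indexer (illustrative; token texts are made up, and `vocab`
+# stands for a combo `Vocabulary`): pre-tokenized words are indexed at the wordpiece level, with
+# word-level offsets so the mismatched embedder can pool one vector per original word.
+#
+#     tokens = [Token("unaffable"), Token(".")]
+#     indexed = indexer.tokens_to_indices(tokens, vocab)
+#     # indexed["mask"]           -> one entry per original word   (len == 2)
+#     # indexed["wordpiece_mask"] -> one entry per wordpiece
+#     # indexed["offsets"]        -> (start, end) wordpiece span per word; (-1, -1) if a word
+#     #                              produced no wordpieces (its embedding then comes out as zeros).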
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index f99923e..3be23a4 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -1,6 +1,149 @@
-from .base_indexer import TokenIndexer
+"""
+Adapted from AllenNLP.
+https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/token_characters_indexer.py
+"""
+
+from typing import Dict, List
+import itertools
+import warnings
+
+from overrides import overrides
+import torch
+
+from combo.data import Token, Vocabulary
+from combo.data.token_indexers import TokenIndexer, IndexedTokenList
+from combo.data.tokenizers import CharacterTokenizer
+from combo.utils import ConfigurationError, pad_sequence_to_length
 
 
 class TokenCharactersIndexer(TokenIndexer):
-    """Wrapper around allennlp token indexer with const padding."""
-    pass
+    """
+    This :class:`TokenIndexer` represents tokens as lists of character indices.
+
+    Registered as a `TokenIndexer` with name "characters".
+
+    # Parameters
+
+    namespace : `str`, optional (default=`token_characters`)
+        We will use this namespace in the :class:`Vocabulary` to map the characters in each token
+        to indices.
+    character_tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`)
+        We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has
+        options for byte encoding and other things.  The default here is to instantiate a
+        `CharacterTokenizer` with its default parameters, which uses unicode characters and
+        retains casing.
+    start_tokens : `List[str]`, optional (default=`None`)
+        These are prepended to the tokens provided to `tokens_to_indices`.
+    end_tokens : `List[str]`, optional (default=`None`)
+        These are appended to the tokens provided to `tokens_to_indices`.
+    min_padding_length : `int`, optional (default=`0`)
+        We use this value as the minimum length of padding. When used with a :class:`CnnEncoder`,
+        it should be set to the maximum value of `ngram_filter_sizes`.
+    token_min_padding_length : `int`, optional (default=`0`)
+        See :class:`TokenIndexer`.
+    """
+
+    def __init__(
+        self,
+        namespace: str = "token_characters",
+        character_tokenizer: CharacterTokenizer = CharacterTokenizer(),
+        start_tokens: List[str] = None,
+        end_tokens: List[str] = None,
+        min_padding_length: int = 0,
+        token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        if min_padding_length == 0:
+            url = "https://github.com/allenai/allennlp/issues/1954"
+            warnings.warn(
+                "You are using the default value (0) of `min_padding_length`, "
+                f"which can cause some subtle bugs (for more info, see {url}). "
+                "We strongly recommend setting a value, usually the maximum filter size "
+                "of the convolutional layer when using CnnEncoder.",
+                UserWarning,
+            )
+        self._min_padding_length = min_padding_length
+        self._namespace = namespace
+        self._character_tokenizer = character_tokenizer
+
+        self._start_tokens = [Token(st) for st in (start_tokens or [])]
+        self._end_tokens = [Token(et) for et in (end_tokens or [])]
+
+    @overrides
+    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+        if token.text is None:
+            raise ConfigurationError("TokenCharactersIndexer needs a tokenizer that retains text")
+        for character in self._character_tokenizer.tokenize(token.text):
+            # If `text_id` is set on the character token (e.g., if we're using byte encoding), we
+            # will not be using the vocab for this character.
+            if getattr(character, "text_id", None) is None:
+                counter[self._namespace][character.text] += 1
+
+    @overrides
+    def tokens_to_indices(
+        self, tokens: List[Token], vocabulary: Vocabulary
+    ) -> Dict[str, List[List[int]]]:
+        indices: List[List[int]] = []
+        for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
+            token_indices: List[int] = []
+            if token.text is None:
+                raise ConfigurationError(
+                    "TokenCharactersIndexer needs a tokenizer that retains text"
+                )
+            for character in self._character_tokenizer.tokenize(token.text):
+                if getattr(character, "text_id", None) is not None:
+                    # `text_id` being set on the token means that we aren't using the vocab, we just
+                    # use this id instead.
+                    index = character.text_id
+                else:
+                    index = vocabulary.get_token_index(character.text, self._namespace)
+                token_indices.append(index)
+            indices.append(token_indices)
+        return {"token_characters": indices}
+
+    @overrides
+    def get_padding_lengths(self, indexed_tokens: IndexedTokenList) -> Dict[str, int]:
+        padding_lengths = {}
+        padding_lengths["token_characters"] = max(
+            len(indexed_tokens["token_characters"]), self._token_min_padding_length
+        )
+        max_num_characters = self._min_padding_length
+        for token in indexed_tokens["token_characters"]:
+            max_num_characters = max(len(token), max_num_characters)  # type: ignore
+        padding_lengths["num_token_characters"] = max_num_characters
+        return padding_lengths
+
+    @overrides
+    def as_padded_tensor_dict(
+        self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        # Pad the tokens.
+        padded_tokens = pad_sequence_to_length(
+            tokens["token_characters"],
+            padding_lengths["token_characters"],
+            default_value=lambda: [],
+        )
+
+        # Pad the characters within the tokens.
+        desired_token_length = padding_lengths["num_token_characters"]
+        longest_token: List[int] = max(tokens["token_characters"], key=len, default=[])  # type: ignore
+        padding_value = 0
+        if desired_token_length > len(longest_token):
+            # Since we want to pad to greater than the longest token, we add a
+            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
+            padded_tokens.append([padding_value] * desired_token_length)
+        # pad the list of lists to the longest sublist, appending 0's
+        padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value)))
+        if desired_token_length > len(longest_token):
+            # Removes the "dummy token".
+            padded_tokens.pop()
+        # Truncate all the tokens to the desired length and return the result.
+        return {
+            "token_characters": torch.LongTensor(
+                [list(token[:desired_token_length]) for token in padded_tokens]
+            )
+        }
+
+    @overrides
+    def get_empty_token_list(self) -> IndexedTokenList:
+        return {"token_characters": []}
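+
+
+# Worked sketch (illustrative; the character ids are hypothetical vocabulary lookups, and `vocab`
+# stands for a combo `Vocabulary`):
+#
+#     indexer = TokenCharactersIndexer(min_padding_length=3)
+#     indexed = indexer.tokens_to_indices([Token("cat"), Token("to")], vocab)
+#     # indexed -> {"token_characters": [[c, a, t], [t, o]]}   (one id per character)
+#     indexer.get_padding_lengths(indexed)
+#     # -> {"token_characters": 2, "num_token_characters": 3}
+#     # as_padded_tensor_dict then pads "to" with zeros up to 3 characters.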
diff --git a/combo/data/token_indexers/token_const_padding_characters_indexer.py b/combo/data/token_indexers/token_const_padding_characters_indexer.py
new file mode 100644
index 0000000..db5e9b7
--- /dev/null
+++ b/combo/data/token_indexers/token_const_padding_characters_indexer.py
@@ -0,0 +1,62 @@
+"""Custom character token indexer."""
+import itertools
+from typing import List, Dict
+
+import torch
+from combo.data.token_indexers import IndexedTokenList
+from overrides import overrides
+
+from combo.data.token_indexers.token_characters_indexer import TokenCharactersIndexer
+from combo.data.tokenizers import CharacterTokenizer
+from combo.utils import pad_sequence_to_length
+
+
+class TokenConstPaddingCharactersIndexer(TokenCharactersIndexer):
+    """Wrapper around allennlp token indexer with const padding."""
+
+    def __init__(self,
+                 namespace: str = "token_characters",
+                 character_tokenizer: CharacterTokenizer = CharacterTokenizer(),
+                 start_tokens: List[str] = None,
+                 end_tokens: List[str] = None,
+                 min_padding_length: int = 0,
+                 token_min_padding_length: int = 0):
+        super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length,
+                         token_min_padding_length)
+
+    @overrides
+    def get_padding_lengths(self, indexed_tokens: IndexedTokenList) -> Dict[str, int]:
+        padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]),
+                           "num_token_characters": self._min_padding_length}
+        return padding_lengths
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        # Pad the tokens.
+        padded_tokens = pad_sequence_to_length(
+            tokens["token_characters"],
+            padding_lengths["token_characters"],
+            default_value=lambda: [],
+        )
+
+        # Pad the characters within the tokens.
+        desired_token_length = padding_lengths["num_token_characters"]
+        longest_token: List[int] = max(tokens["token_characters"], key=len, default=[])  # type: ignore
+        padding_value = 0
+        if desired_token_length > len(longest_token):
+            # Since we want to pad to greater than the longest token, we add a
+            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
+            padded_tokens.append([padding_value] * desired_token_length)
+        # pad the list of lists to the longest sublist, appending 0's
+        padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value)))
+        if desired_token_length > len(longest_token):
+            # Removes the "dummy token".
+            padded_tokens.pop()
+        # Truncate all the tokens to the desired length and return the result.
+        return {
+            "token_characters": torch.LongTensor(
+                [list(token[:desired_token_length]) for token in padded_tokens]
+            )
+        }
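+
+
+# Difference from the parent class, sketched (illustrative values; `vocab` stands for a combo
+# `Vocabulary`): the character padding length is always the configured `min_padding_length`,
+# independent of the longest token in the batch.
+#
+#     indexer = TokenConstPaddingCharactersIndexer(min_padding_length=5)
+#     indexed = indexer.tokens_to_indices([Token("cat"), Token("wordpiece")], vocab)
+#     indexer.get_padding_lengths(indexed)
+#     # -> {"token_characters": 2, "num_token_characters": 5}
+#     # so "wordpiece" (9 characters) is truncated to 5 ids and "cat" is padded to 5.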
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index 901ffdb..4835941 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -1,7 +1,77 @@
 """Features indexer."""
+import collections
+from abc import ABC
+from typing import List, Dict
 
-from .base_indexer import TokenIndexer
+import torch
+from overrides import overrides
 
+from combo.data import Token, Vocabulary
+from combo.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
+from combo.utils import pad_sequence_to_length
 
-class TokenFeatsIndexer(TokenIndexer):
-    pass
+
+class TokenFeatsIndexer(TokenIndexer, ABC):
+
+    def __init__(
+            self,
+            namespace: str = "feats",
+            feature_name: str = "feats_",
+            token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        self.namespace = namespace
+        self._feature_name = feature_name
+
+    @overrides
+    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+        feats = self._feat_values(token)
+        for feat in feats:
+            counter[self.namespace][feat] += 1
+
+    @overrides
+    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+        indices: List[List[int]] = []
+        vocab_size = vocabulary.get_vocab_size(self.namespace)
+        for token in tokens:
+            token_indices = []
+            feats = self._feat_values(token)
+            for feat in feats:
+                token_indices.append(vocabulary.get_token_index(feat, self.namespace))
+            indices.append(pad_sequence_to_length(token_indices, vocab_size))
+        return {"tokens": indices}
+
+    @overrides
+    def get_empty_token_list(self) -> IndexedTokenList:
+        return {"tokens": [[]]}
+
+    def _feat_values(self, token):
+        feats = getattr(token, self._feature_name)
+        if feats is None:
+            feats = collections.OrderedDict()
+        features = []
+        for feat, value in feats.items():
+            if feat in ["_", "__ROOT__"]:
+                continue
+            # Handle the case where the feature is binary (it has no associated value).
+            if value:
+                features.append(feat + "=" + value)
+            else:
+                features.append(feat)
+        return features
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tensor_dict = {}
+        for key, val in tokens.items():
+            vocab_size = len(val[0])
+            tensor = torch.tensor(pad_sequence_to_length(val,
+                                                         padding_lengths[key],
+                                                         default_value=lambda: [0] * vocab_size,
+                                                         )
+                                  )
+            tensor_dict[key] = tensor
+        return tensor_dict
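+
+
+# Sketch of how morphological features are flattened and indexed (illustrative; assumes the token
+# object carries a `feats_` mapping, which is the default `feature_name`):
+#
+#     token.feats_ = {"Case": "Nom", "Number": "Sing", "Foreign": None}
+#     indexer._feat_values(token)  # -> ["Case=Nom", "Number=Sing", "Foreign"]
+#     # tokens_to_indices then maps each string to its id in the "feats" namespace and pads every
+#     # token's id list to the namespace's vocabulary size, giving a fixed-width row per token.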
diff --git a/combo/data/token_indexers/token_indexer.py b/combo/data/token_indexers/token_indexer.py
new file mode 100644
index 0000000..eaaa01a
--- /dev/null
+++ b/combo/data/token_indexers/token_indexer.py
@@ -0,0 +1,124 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/token_indexers/token_indexer.py
+"""
+
+from typing import Any, Dict, List
+
+import torch
+
+from combo.data import Token, Vocabulary
+from combo.utils import pad_sequence_to_length
+
+# An indexed token list represents the arguments that will be passed to a TokenEmbedder
+# corresponding to this TokenIndexer.  Each argument that the TokenEmbedder needs will have one
+# entry in the IndexedTokenList dictionary, and that argument will typically be a list of integers
+# (for single ID word embeddings) or a nested list of integers (for character ID word embeddings),
+# though it could also be a mask, or any other data that you want to pass.
+IndexedTokenList = Dict[str, List[Any]]
+
+
+class TokenIndexer:
+    """
+    A `TokenIndexer` determines how string tokens get represented as arrays of indices in a model.
+    This class both converts strings into numerical values, with the help of a
+    :class:`~allennlp.data.vocabulary.Vocabulary`, and it produces actual arrays.
+    Tokens can be represented as single IDs (e.g., the word "cat" gets represented by the number
+    34), or as lists of character IDs (e.g., "cat" gets represented by the numbers [23, 10, 18]),
+    or in some other way that you can come up with (e.g., if you have some structured input you
+    want to represent in a special way in your data arrays, you can do that here).
+    # Parameters
+    token_min_padding_length : `int`, optional (default=`0`)
+        The minimum padding length required for the :class:`TokenIndexer`. For example,
+        the minimum padding length of :class:`SingleIdTokenIndexer` is the largest size of
+        filter when using :class:`CnnEncoder`.
+        Note that if you set this for one TokenIndexer, you likely have to set it for all
+        :class:`TokenIndexer` for the same field, otherwise you'll get mismatched tensor sizes.
+    """
+
+    default_implementation = "single_id"
+    has_warned_for_as_padded_tensor = False
+
+    def __init__(self, token_min_padding_length: int = 0) -> None:
+        self._token_min_padding_length: int = token_min_padding_length
+
+    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+        """
+        The :class:`Vocabulary` needs to assign indices to whatever strings we see in the training
+        data (possibly doing some frequency filtering and using an OOV, or out of vocabulary,
+        token).  This method takes a token and a dictionary of counts and increments counts for
+        whatever vocabulary items are present in the token.  If this is a single token ID
+        representation, the vocabulary item is likely the token itself.  If this is a token
+        characters representation, the vocabulary items are all of the characters in the token.
+        """
+        raise NotImplementedError
+
+    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+        """
+        Takes a list of tokens and converts them to an `IndexedTokenList`.
+        This could be just an ID for each token from the vocabulary.
+        Or it could split each token into characters and return one ID per character.
+        Or (for instance, in the case of byte-pair encoding) there might not be a clean
+        mapping from individual tokens to indices, and the `IndexedTokenList` could be a complex
+        data structure.
+        """
+        raise NotImplementedError
+
+    def indices_to_tokens(
+        self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
+    ) -> List[Token]:
+        """
+        Inverse operations of tokens_to_indices. Takes an `IndexedTokenList` and converts it back
+        into a list of tokens.
+        """
+        raise NotImplementedError
+
+    def get_empty_token_list(self) -> IndexedTokenList:
+        """
+        Returns an `already indexed` version of an empty token list.  This is typically just an
+        empty list for whatever keys are used in the indexer.
+        """
+        raise NotImplementedError
+
+    def get_padding_lengths(self, indexed_tokens: IndexedTokenList) -> Dict[str, int]:
+        """
+        This method returns a padding dictionary for the given `indexed_tokens` specifying all
+        lengths that need padding.  If all you have is a list of single ID tokens, this is just the
+        length of the list, and that's what the default implementation will give you.  If you have
+        something more complicated, like a list of character ids for token, you'll need to override
+        this.
+        """
+        padding_lengths = {}
+        for key, token_list in indexed_tokens.items():
+            padding_lengths[key] = max(len(token_list), self._token_min_padding_length)
+        return padding_lengths
+
+    def as_padded_tensor_dict(
+        self, tokens: IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This method pads a list of tokens given the input padding lengths (which could actually
+        truncate things, depending on settings) and returns that padded list of input tokens as a
+        `Dict[str, torch.Tensor]`.  This is a dictionary because there should be one key per
+        argument that the `TokenEmbedder` corresponding to this class expects in its `forward()`
+        method (where the argument name in the `TokenEmbedder` needs to match the key in this
+        dictionary).
+        The base class implements the case when all you want to do is create a padded `LongTensor`
+        for every list in the `tokens` dictionary.  If your `TokenIndexer` needs more complex
+        logic than that, you need to override this method.
+        """
+        tensor_dict = {}
+        for key, val in tokens.items():
+            if val and isinstance(val[0], bool):
+                tensor = torch.BoolTensor(
+                    pad_sequence_to_length(val, padding_lengths[key], default_value=lambda: False)
+                )
+            else:
+                tensor = torch.LongTensor(pad_sequence_to_length(val, padding_lengths[key]))
+            tensor_dict[key] = tensor
+        return tensor_dict
+
+    def __eq__(self, other) -> bool:
+        if isinstance(self, other.__class__):
+            return self.__dict__ == other.__dict__
+        return NotImplemented
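+
+
+# Minimal subclass sketch (illustrative only; a real single-id indexer would also handle
+# namespaces, start/end tokens, etc.):
+#
+#     class SingleIdIndexer(TokenIndexer):
+#         def count_vocab_items(self, token, counter):
+#             counter["tokens"][token.text] += 1
+#
+#         def tokens_to_indices(self, tokens, vocabulary):
+#             return {"tokens": [vocabulary.get_token_index(t.text, "tokens") for t in tokens]}
+#
+#         def get_empty_token_list(self):
+#             return {"tokens": []}
+#
+#     # `get_padding_lengths` and `as_padded_tensor_dict` are inherited: single ids pad to a LongTensor.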
diff --git a/combo/data/tokenizers/__init__.py b/combo/data/tokenizers/__init__.py
new file mode 100644
index 0000000..71d32e6
--- /dev/null
+++ b/combo/data/tokenizers/__init__.py
@@ -0,0 +1,2 @@
+from .tokenizer import Tokenizer
+from .character_tokenizer import CharacterTokenizer
diff --git a/combo/data/tokenizers/character_tokenizer.py b/combo/data/tokenizers/character_tokenizer.py
new file mode 100644
index 0000000..5302e7b
--- /dev/null
+++ b/combo/data/tokenizers/character_tokenizer.py
@@ -0,0 +1,11 @@
+from typing import List
+
+from combo.data import Token
+from combo.data.tokenizers import Tokenizer
+from overrides import override
+
+
+class CharacterTokenizer(Tokenizer):
+    @override
+    def tokenize(self, text: str) -> List[Token]:
+        return [Token(c) for c in list(text)]
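+
+
+# Usage sketch (illustrative):
+#
+#     CharacterTokenizer().tokenize("cat")
+#     # -> [Token("c"), Token("a"), Token("t")], one Token per unicode character, casing preserved.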
diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
new file mode 100644
index 0000000..f9e2ba8
--- /dev/null
+++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -0,0 +1,488 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
+"""
+
+import copy
+import dataclasses
+import hashlib
+import io
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Iterable
+import base58
+import dill
+
+from transformers import PreTrainedTokenizer, AutoTokenizer
+
+from combo.data import Token
+from combo.data.tokenizers import Tokenizer
+from combo.utils import sanitize_wordpiece
+
+logger = logging.getLogger(__name__)
+
+def hash_object(o: Any) -> str:
+    """Returns a character hash code of arbitrary Python objects."""
+    m = hashlib.blake2b()
+    with io.BytesIO() as buffer:
+        dill.dump(o, buffer)
+        m.update(buffer.getbuffer())
+        return base58.b58encode(m.digest()).decode()
+
+
+# Module-level cache so that repeated calls with the same arguments reuse the same tokenizer.
+_tokenizer_cache: Dict[Tuple[str, str], PreTrainedTokenizer] = {}
+
+
+def get_tokenizer(model_name: str, **kwargs) -> PreTrainedTokenizer:
+    cache_key = (model_name, hash_object(kwargs))
+
+    global _tokenizer_cache
+    tokenizer = _tokenizer_cache.get(cache_key, None)
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            **kwargs,
+        )
+        _tokenizer_cache[cache_key] = tokenizer
+    return tokenizer
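+
+# Caching sketch (illustrative): repeated calls with identical arguments reuse the same tokenizer
+# instance instead of re-loading it from disk or the HuggingFace hub.
+#
+#     t1 = get_tokenizer("bert-base-cased", use_fast=True)
+#     t2 = get_tokenizer("bert-base-cased", use_fast=True)
+#     assert t1 is t2
+#     get_tokenizer("bert-base-cased", use_fast=False)  # different kwargs -> a separate cache entry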
+
+
+class PretrainedTransformerTokenizer(Tokenizer):
+    """
+    A `PretrainedTransformerTokenizer` uses a model from HuggingFace's
+    `transformers` library to tokenize some input text.  This often means wordpieces
+    (where `'AllenNLP is awesome'` might get split into `['Allen', '##NL', '##P', 'is',
+    'awesome']`), but it could also use byte-pair encoding, or some other tokenization, depending
+    on the pretrained model that you're using.
+    We take a model name as an input parameter, which we will pass to
+    `AutoTokenizer.from_pretrained`.
+    We also add special tokens relative to the pretrained model and truncate the sequences.
+    This tokenizer also indexes tokens and adds the indexes to the `Token` fields so that
+    they can be picked up by `PretrainedTransformerIndexer`.
+    Registered as a `Tokenizer` with name "pretrained_transformer".
+    # Parameters
+    model_name : `str`
+        The name of the pretrained wordpiece tokenizer to use.
+    add_special_tokens : `bool`, optional, (default=`True`)
+        If set to `True`, the sequences will be encoded with the special tokens relative
+        to their model.
+    max_length : `int`, optional (default=`None`)
+        If set to a number, will limit the total sequence returned so that it has a maximum length.
+    tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
+        Dictionary with
+        [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)
+        for `AutoTokenizer.from_pretrained`.
+    verification_tokens: `Tuple[str, str]`, optional (default = `None`)
+        A pair of tokens with different token IDs, used for reverse-engineering the special tokens.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        model_name: str,
+        add_special_tokens: bool = True,
+        max_length: Optional[int] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        verification_tokens: Optional[Tuple[str, str]] = None,
+    ) -> None:
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+        else:
+            tokenizer_kwargs = tokenizer_kwargs.copy()
+        # Note: Just because we request a fast tokenizer doesn't mean we get one.
+        tokenizer_kwargs.setdefault("use_fast", True)
+
+        self._tokenizer_kwargs = tokenizer_kwargs
+        self._model_name = model_name
+
+        self.tokenizer = get_tokenizer(
+            self._model_name, add_special_tokens=False, **self._tokenizer_kwargs
+        )
+
+        self._add_special_tokens = add_special_tokens
+        self._max_length = max_length
+
+        self._tokenizer_lowercases = self.tokenizer_lowercases(self.tokenizer)
+
+        if verification_tokens is None:
+            try:
+                self._reverse_engineer_special_tokens("a", "b", model_name, tokenizer_kwargs)
+            except AssertionError:
+                # For most transformer models, "a" and "b" work just fine as dummy tokens.  For a few,
+                # they don't, and so we use "1" and "2" instead.
+                self._reverse_engineer_special_tokens("1", "2", model_name, tokenizer_kwargs)
+        else:
+            token_a, token_b = verification_tokens
+            self._reverse_engineer_special_tokens(token_a, token_b, model_name, tokenizer_kwargs)
+
+    def _reverse_engineer_special_tokens(
+        self,
+        token_a: str,
+        token_b: str,
+        model_name: str,
+        tokenizer_kwargs: Optional[Dict[str, Any]],
+    ):
+        # storing the special tokens
+        self.sequence_pair_start_tokens = []
+        self.sequence_pair_mid_tokens = []
+        self.sequence_pair_end_tokens = []
+        # storing token type ids for the sequences
+        self.sequence_pair_first_token_type_id = None
+        self.sequence_pair_second_token_type_id = None
+
+        # storing the special tokens
+        self.single_sequence_start_tokens = []
+        self.single_sequence_end_tokens = []
+        # storing token type id for the sequence
+        self.single_sequence_token_type_id = None
+
+        # Reverse-engineer the tokenizer for two sequences
+
+        tokenizer_with_special_tokens = get_tokenizer(
+            model_name, add_special_tokens=True, **(tokenizer_kwargs or {})
+        )
+        dummy_output = tokenizer_with_special_tokens.encode_plus(
+            token_a,
+            token_b,
+            add_special_tokens=True,
+            return_token_type_ids=True,
+            return_attention_mask=False,
+        )
+        if len(dummy_output["token_type_ids"]) != len(dummy_output["input_ids"]):
+            logger.warning(
+                "Tokenizer library did not return valid token type ids. We will assume they are all zero."
+            )
+            dummy_output["token_type_ids"] = [0] * len(dummy_output["input_ids"])
+
+        dummy_a = self.tokenizer.encode(token_a, add_special_tokens=False)[0]
+        assert dummy_a in dummy_output["input_ids"]
+        dummy_b = self.tokenizer.encode(token_b, add_special_tokens=False)[0]
+        assert dummy_b in dummy_output["input_ids"]
+        assert dummy_a != dummy_b
+
+        seen_dummy_a = False
+        seen_dummy_b = False
+        for token_id, token_type_id in zip(
+            dummy_output["input_ids"], dummy_output["token_type_ids"]
+        ):
+            if token_id == dummy_a:
+                if seen_dummy_a or seen_dummy_b:  # seeing a twice or b before a
+                    raise ValueError("Cannot auto-determine the number of special tokens added.")
+                seen_dummy_a = True
+                assert (
+                    self.sequence_pair_first_token_type_id is None
+                    or self.sequence_pair_first_token_type_id == token_type_id
+                ), "multiple different token type ids found for the first sequence"
+                self.sequence_pair_first_token_type_id = token_type_id
+                continue
+
+            if token_id == dummy_b:
+                if seen_dummy_b:  # seeing b twice
+                    raise ValueError("Cannot auto-determine the number of special tokens added.")
+                seen_dummy_b = True
+                assert (
+                    self.sequence_pair_second_token_type_id is None
+                    or self.sequence_pair_second_token_type_id == token_type_id
+                ), "multiple different token type ids found for the second sequence"
+                self.sequence_pair_second_token_type_id = token_type_id
+                continue
+
+            token = Token(
+                tokenizer_with_special_tokens.convert_ids_to_tokens(token_id),
+                text_id=token_id,
+                type_id=token_type_id,
+            )
+            if not seen_dummy_a:
+                self.sequence_pair_start_tokens.append(token)
+            elif not seen_dummy_b:
+                self.sequence_pair_mid_tokens.append(token)
+            else:
+                self.sequence_pair_end_tokens.append(token)
+
+        assert (
+            len(self.sequence_pair_start_tokens)
+            + len(self.sequence_pair_mid_tokens)
+            + len(self.sequence_pair_end_tokens)
+        ) == self.tokenizer.num_special_tokens_to_add(pair=True)
+
+        # Reverse-engineer the tokenizer for one sequence
+        dummy_output = tokenizer_with_special_tokens.encode_plus(
+            token_a,
+            add_special_tokens=True,
+            return_token_type_ids=True,
+            return_attention_mask=False,
+        )
+        if len(dummy_output["token_type_ids"]) != len(dummy_output["input_ids"]):
+            logger.warning(
+                "Tokenizer library did not return valid token type ids. We will assume they are all zero."
+            )
+            dummy_output["token_type_ids"] = [0] * len(dummy_output["input_ids"])
+
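+        # Same walk as above, but with a single dummy token: specials before it are "start" tokens
+        # and specials after it are "end" tokens.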
+        seen_dummy_a = False
+        for token_id, token_type_id in zip(
+            dummy_output["input_ids"], dummy_output["token_type_ids"]
+        ):
+            if token_id == dummy_a:
+                if seen_dummy_a:
+                    raise ValueError("Cannot auto-determine the number of special tokens added.")
+                seen_dummy_a = True
+                assert (
+                    self.single_sequence_token_type_id is None
+                    or self.single_sequence_token_type_id == token_type_id
+                ), "multiple different token type ids found for the sequence"
+                self.single_sequence_token_type_id = token_type_id
+                continue
+
+            token = Token(
+                tokenizer_with_special_tokens.convert_ids_to_tokens(token_id),
+                text_id=token_id,
+                type_id=token_type_id,
+            )
+            if not seen_dummy_a:
+                self.single_sequence_start_tokens.append(token)
+            else:
+                self.single_sequence_end_tokens.append(token)
+
+        assert (
+            len(self.single_sequence_start_tokens) + len(self.single_sequence_end_tokens)
+        ) == self.tokenizer.num_special_tokens_to_add(pair=False)
+
+    @staticmethod
+    def tokenizer_lowercases(tokenizer: PreTrainedTokenizer) -> bool:
+        # Huggingface tokenizers have different ways of remembering whether they lowercase or not. Detecting it
+        # this way seems like the least brittle way to do it.
+        tokenized = tokenizer.tokenize(
+            "A"
+        )  # Use a single character that won't be cut into word pieces.
+        detokenized = " ".join(tokenized)
+        return "a" in detokenized
+
+    def tokenize(self, text: str) -> List[Token]:
+        """
+        This method only handles a single sentence (or sequence) of text.
+        """
+        max_length = self._max_length
+        if max_length is not None and not self._add_special_tokens:
+            max_length += self.num_special_tokens_for_sequence()
+
+        encoded_tokens = self.tokenizer.encode_plus(
+            text=text,
+            add_special_tokens=True,
+            max_length=max_length,
+            truncation=True if max_length is not None else False,
+            return_tensors=None,
+            return_offsets_mapping=self.tokenizer.is_fast,
+            return_attention_mask=False,
+            return_token_type_ids=True,
+            return_special_tokens_mask=True,
+        )
+        # token_ids is the final list of ids, covering both regular and special tokens
+        token_ids, token_type_ids, special_tokens_mask, token_offsets = (
+            encoded_tokens["input_ids"],
+            encoded_tokens["token_type_ids"],
+            encoded_tokens["special_tokens_mask"],
+            encoded_tokens.get("offset_mapping"),
+        )
+
+        # If we don't have token offsets, try to calculate them ourselves.
+        if token_offsets is None:
+            token_offsets = self._estimate_character_indices(text, token_ids)
+
+        tokens = []
+        for token_id, token_type_id, special_token_mask, offsets in zip(
+            token_ids, token_type_ids, special_tokens_mask, token_offsets
+        ):
+            # In `special_tokens_mask`, 1s indicate special tokens and 0s indicate regular tokens.
+            # NOTE: in transformers v3.4.0 (and probably older versions) the docstring
+            # for `encode_plus` was incorrect as it had the 0s and 1s reversed.
+            # https://github.com/huggingface/transformers/pull/7949 fixed this.
+            if not self._add_special_tokens and special_token_mask == 1:
+                continue
+
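+            # Special tokens and other pieces that do not map back to the text come with missing
+            # or degenerate offsets; represent those as None indices on the Token.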
+            if offsets is None or offsets[0] >= offsets[1]:
+                start = None
+                end = None
+            else:
+                start, end = offsets
+
+            tokens.append(
+                Token(
+                    text=self.tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=False),
+                    text_id=token_id,
+                    type_id=token_type_id,
+                    idx=start,
+                    idx_end=end,
+                )
+            )
+
+        return tokens
+
+    def _estimate_character_indices(
+        self, text: str, token_ids: List[int]
+    ) -> List[Optional[Tuple[int, int]]]:
+        """
+        The huggingface tokenizers produce tokens that may or may not be slices from the
+        original text.  Differences arise from lowercasing, Unicode normalization, and other
+        kinds of normalization, as well as special characters that are included to denote
+        various situations, such as "##" in BERT for word pieces from the middle of a word, or
+        "Ä " in RoBERTa for the beginning of words not at the start of a sentence.
+        This code attempts to calculate character offsets while being tolerant to these
+        differences. It scans through the text and the tokens in parallel, trying to match up
+        positions in both. If it gets out of sync, it backs off to not adding any token
+        indices, and attempts to catch back up afterwards. This procedure is approximate.
+        Don't rely on precise results, especially in non-English languages that are far more
+        affected by Unicode normalization.
+        """
+
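+        # For illustration (assuming a lowercasing BERT-style tokenizer): for the text
+        # "John lives in Berlin" and wordpieces ["john", "lives", "in", "berlin"], the estimated
+        # end-exclusive offsets would be [(0, 4), (5, 10), (11, 13), (14, 20)].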
+        token_texts = [
+            sanitize_wordpiece(t) for t in self.tokenizer.convert_ids_to_tokens(token_ids)
+        ]
+        token_offsets: List[Optional[Tuple[int, int]]] = [None] * len(token_ids)
+        if self._tokenizer_lowercases:
+            text = text.lower()
+            token_texts = [t.lower() for t in token_texts]
+
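+        # Despite the name, this bounds the number of *non-whitespace* characters we are willing
+        # to skip between consecutive matched tokens before giving up on a token.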
+        min_allowed_skipped_whitespace = 3
+        allowed_skipped_whitespace = min_allowed_skipped_whitespace
+
+        text_index = 0
+        token_index = 0
+        while text_index < len(text) and token_index < len(token_ids):
+            token_text = token_texts[token_index]
+            token_start_index = text.find(token_text, text_index)
+
+            # Did we not find it at all?
+            if token_start_index < 0:
+                token_index += 1
+                # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
+                allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
+                continue
+
+            # Did we jump too far?
+            non_whitespace_chars_skipped = sum(
+                1 for c in text[text_index:token_start_index] if not c.isspace()
+            )
+            if non_whitespace_chars_skipped > allowed_skipped_whitespace:
+                # Too many skipped characters. Something is wrong. Ignore this token.
+                token_index += 1
+                # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
+                allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
+                continue
+            allowed_skipped_whitespace = min_allowed_skipped_whitespace
+
+            token_offsets[token_index] = (
+                token_start_index,
+                token_start_index + len(token_text),
+            )
+            text_index = token_start_index + len(token_text)
+            token_index += 1
+        return token_offsets
+
+    def _intra_word_tokenize(
+        self, string_tokens: List[str]
+    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
+        tokens: List[Token] = []
+        offsets: List[Optional[Tuple[int, int]]] = []
+        for token_string in string_tokens:
+            wordpieces = self.tokenizer.encode_plus(
+                token_string,
+                add_special_tokens=False,
+                return_tensors=None,
+                return_offsets_mapping=False,
+                return_attention_mask=False,
+            )
+            wp_ids = wordpieces["input_ids"]
+
+            if len(wp_ids) > 0:
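+                # Record the inclusive range of wordpiece positions produced for this word.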
+                offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
+                tokens.extend(
+                    Token(text=wp_text, text_id=wp_id)
+                    for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
+                )
+            else:
+                offsets.append(None)
+        return tokens, offsets
+
+    @staticmethod
+    def _increment_offsets(
+        offsets: Iterable[Optional[Tuple[int, int]]], increment: int
+    ) -> List[Optional[Tuple[int, int]]]:
+        return [
+            None if offset is None else (offset[0] + increment, offset[1] + increment)
+            for offset in offsets
+        ]
+
+    def intra_word_tokenize(
+        self, string_tokens: List[str]
+    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
+        """
+        Tokenizes each word into wordpieces separately and returns the wordpiece IDs.
+        Also calculates offsets such that tokens[offsets[i][0]:offsets[i][1] + 1]
+        corresponds to the original i-th token.
+        This function inserts special tokens.
+        """
+        tokens, offsets = self._intra_word_tokenize(string_tokens)
+        tokens = self.add_special_tokens(tokens)
+        offsets = self._increment_offsets(offsets, len(self.single_sequence_start_tokens))
+        return tokens, offsets
+
+    def intra_word_tokenize_sentence_pair(
+        self, string_tokens_a: List[str], string_tokens_b: List[str]
+    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]], List[Optional[Tuple[int, int]]]]:
+        """
+        Tokenizes each word into wordpieces separately and returns the wordpiece IDs.
+        Also calculates offsets such that wordpieces[offsets[i][0]:offsets[i][1] + 1]
+        corresponds to the original i-th token.
+        This function inserts special tokens.
+        """
+        tokens_a, offsets_a = self._intra_word_tokenize(string_tokens_a)
+        tokens_b, offsets_b = self._intra_word_tokenize(string_tokens_b)
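+        # Sequence B's wordpieces land after the pair-start specials, all of sequence A's
+        # wordpieces and the mid specials, so shift their offsets by that much.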
+        offsets_b = self._increment_offsets(
+            offsets_b,
+            (
+                len(self.sequence_pair_start_tokens)
+                + len(tokens_a)
+                + len(self.sequence_pair_mid_tokens)
+            ),
+        )
+        tokens_a = self.add_special_tokens(tokens_a, tokens_b)
+        offsets_a = self._increment_offsets(offsets_a, len(self.sequence_pair_start_tokens))
+
+        return tokens_a, offsets_a, offsets_b
+
+    def add_special_tokens(
+        self, tokens1: List[Token], tokens2: Optional[List[Token]] = None
+    ) -> List[Token]:
+        def with_new_type_id(tokens: List[Token], type_id: int) -> List[Token]:
+            return [dataclasses.replace(t, type_id=type_id) for t in tokens]
+
+        # Make sure we don't change the input parameters
+        tokens2 = copy.deepcopy(tokens2)
+
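+        # Content tokens take the sequence-level type id discovered during the reverse-engineering
+        # above; the special tokens already carry their own type ids.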
+        if tokens2 is None:
+            return (
+                self.single_sequence_start_tokens
+                + with_new_type_id(tokens1, self.single_sequence_token_type_id)  # type: ignore
+                + self.single_sequence_end_tokens
+            )
+        else:
+            return (
+                self.sequence_pair_start_tokens
+                + with_new_type_id(tokens1, self.sequence_pair_first_token_type_id)  # type: ignore
+                + self.sequence_pair_mid_tokens
+                + with_new_type_id(tokens2, self.sequence_pair_second_token_type_id)  # type: ignore
+                + self.sequence_pair_end_tokens
+            )
+
+    def num_special_tokens_for_sequence(self) -> int:
+        return len(self.single_sequence_start_tokens) + len(self.single_sequence_end_tokens)
+
+    def num_special_tokens_for_pair(self) -> int:
+        return (
+            len(self.sequence_pair_start_tokens)
+            + len(self.sequence_pair_mid_tokens)
+            + len(self.sequence_pair_end_tokens)
+        )
+
+    def _to_params(self) -> Dict[str, Any]:
+        return {
+            "type": "pretrained_transformer",
+            "model_name": self._model_name,
+            "add_special_tokens": self._add_special_tokens,
+            "max_length": self._max_length,
+            "tokenizer_kwargs": self._tokenizer_kwargs,
+        }
diff --git a/combo/data/tokenizers/tokenizer.py b/combo/data/tokenizers/tokenizer.py
new file mode 100644
index 0000000..b9bd901
--- /dev/null
+++ b/combo/data/tokenizers/tokenizer.py
@@ -0,0 +1,15 @@
+from typing import List
+import logging
+
+from combo.data import Token
+
+
+logger = logging.getLogger(__name__)
+
+
+class Tokenizer:
+    def tokenize(self, text: str) -> List[Token]:
+        raise NotImplementedError
+
+    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
+        return [self.tokenize(text) for text in texts]
diff --git a/combo/utils/cached_transformers.py b/combo/utils/cached_transformers.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/utils/sequence.py b/combo/utils/sequence.py
index 0bdcb0e..155bb89 100644
--- a/combo/utils/sequence.py
+++ b/combo/utils/sequence.py
@@ -6,6 +6,20 @@ https://github.com/allenai/allennlp/blob/main/allennlp/common/util.py
 from typing import Any, Callable, List, Sequence
 
 
+def sanitize_wordpiece(wordpiece: str) -> str:
+    """
+    Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers.
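+    For example, "##piece", "Ġpiece" and "▁piece" all become "piece".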
+    """
+    if wordpiece.startswith("##"):
+        return wordpiece[2:]
+    elif wordpiece.startswith("Ġ"):
+        return wordpiece[1:]
+    elif wordpiece.startswith("▁"):
+        return wordpiece[1:]
+    else:
+        return wordpiece
+
+
 def pad_sequence_to_length(
     sequence: Sequence,
     desired_length: int,
diff --git a/requirements.txt b/requirements.txt
index 75adf56..07dc582 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
 absl-py~=1.4.0
+base58~=2.1.1
 conllu~=4.4.1
 dependency-injector~=4.41.0
+dill~=0.3.6
 overrides~=7.3.1
 torch~=1.13.1
 torchtext~=0.14.1
@@ -10,4 +12,5 @@ requests~=2.28.2
 tqdm~=4.64.1
 urllib3~=1.26.14
 filelock~=3.9.0
-pytest~=7.2.2
\ No newline at end of file
+pytest~=7.2.2
+transformers~=4.27.3
\ No newline at end of file
-- 
GitLab