From 76a782fe5a8ef9a6f800bd414e0c66630ce4c408 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 19 Apr 2023 19:54:53 +0200
Subject: [PATCH] Add DatasetReader based on PyTorch

---
 combo/data/__init__.py                        |   2 +
 combo/data/dataset_readers/__init__.py        |   2 +
 .../classification_textfile_dataset_reader.py |  37 ++
 combo/data/dataset_readers/dataset_reader.py  | 339 +++---------------
 combo/data/fields/label_field.py              | 111 ++++++
 combo/data/token_indexers/__init__.py         |   1 +
 ...ed_transformer_fixed_mismatched_indexer.py |   9 +-
 .../pretrained_transformer_indexer.py         |  13 +-
 ...etrained_transformer_mismatched_indexer.py |   8 +-
 .../token_indexers/single_id_token_indexer.py | 116 ++++++
 .../token_characters_indexer.py               |  12 +-
 .../token_indexers/token_features_indexer.py  |   7 +-
 combo/data/token_indexers/token_indexer.py    |  11 +-
 combo/models/encoder.py                       |   2 +-
 test.tsv                                      |   2 +
 15 files changed, 346 insertions(+), 326 deletions(-)
 create mode 100644 combo/data/dataset_readers/classification_textfile_dataset_reader.py
 create mode 100644 combo/data/fields/label_field.py
 create mode 100644 combo/data/token_indexers/single_id_token_indexer.py
 create mode 100644 test.tsv

diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 7708f38..bc3f8fb 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -2,5 +2,7 @@ from .api import (Token, Sentence, sentence2conllu, tokens2conllu, conllu2senten
 from .vocabulary import Vocabulary
 from .samplers import TokenCountBatchSampler
 from .instance import Instance
+from .token_indexers import (SingleIdTokenIndexer, TokenIndexer, TokenFeatsIndexer)
 from .tokenizers import (Tokenizer, TokenizerToken, CharacterTokenizer, PretrainedTransformerTokenizer,
                          SpacyTokenizer, WhitespaceTokenizer)
+from .dataset_readers import DatasetReader, ClassificationTextfileDatasetReader
diff --git a/combo/data/dataset_readers/__init__.py b/combo/data/dataset_readers/__init__.py
index e69de29..3458881 100644
--- a/combo/data/dataset_readers/__init__.py
+++ b/combo/data/dataset_readers/__init__.py
@@ -0,0 +1,2 @@
+from .dataset_reader import DatasetReader
+from .classification_textfile_dataset_reader import ClassificationTextfileDatasetReader
diff --git a/combo/data/dataset_readers/classification_textfile_dataset_reader.py b/combo/data/dataset_readers/classification_textfile_dataset_reader.py
new file mode 100644
index 0000000..9ec7ca9
--- /dev/null
+++ b/combo/data/dataset_readers/classification_textfile_dataset_reader.py
@@ -0,0 +1,37 @@
+from typing import Dict, Iterable, Optional
+
+from .dataset_reader import DatasetReader, DatasetReaderInput
+from overrides import overrides
+
+from .. import Instance, Tokenizer, TokenIndexer
+from ..fields.label_field import LabelField
+from ..fields.text_field import TextField
+
+
+class ClassificationTextfileDatasetReader(DatasetReader):
+    def __init__(self,
+                 file_path: Optional[DatasetReaderInput] = None,
+                 tokenizer: Optional[Tokenizer] = None,
+                 token_indexers: Optional[Dict[str, TokenIndexer]] = None,
+                 separator: str = ',') -> None:
+        super().__init__(file_path, tokenizer, token_indexers)
+        self.__separator = separator
+
+    @property
+    def separator(self) -> str:
+        return self.__separator
+
+    @separator.setter
+    def separator(self, new_separator: str):
+        self.__separator = new_separator
+
+    @overrides
+    def _read(self) -> Iterable[Instance]:
+        with open(self.file_path, 'r') as lines:
+            for line in lines:
+                text, label = line.strip().split(self.separator)
+                text_field = TextField(self.tokenizer.tokenize(text),
+                                       self.token_indexers)
+                label_field = LabelField(label)
+                fields = {'text': text_field, 'label': label_field}
+                yield Instance(fields)
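+
+
+# Usage sketch (illustrative only): reading a comma-separated file such as
+# test.tsv, where every line has the form "<text><separator><label>". Assumes
+# the WhitespaceTokenizer and SingleIdTokenIndexer exported from combo.data
+# can be constructed with no arguments.
+#
+#     from combo.data import WhitespaceTokenizer, SingleIdTokenIndexer
+#
+#     reader = ClassificationTextfileDatasetReader(
+#         file_path='test.tsv',
+#         tokenizer=WhitespaceTokenizer(),
+#         token_indexers={'tokens': SingleIdTokenIndexer()},
+#         separator=',',
+#     )
+#     for instance in reader:
+#         print(instance)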
diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
index 6f0fbd1..0d2c34c 100644
--- a/combo/data/dataset_readers/dataset_reader.py
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -2,174 +2,68 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/main/allennlp/data/dataset_readers/dataset_reader.py
 """
-from dataclasses import dataclass
-import itertools
+import logging
 from os import PathLike
 from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List
-import logging
-import warnings
 
-import torch.distributed as dist
+from overrides import overrides
+from torch.utils.data import IterableDataset
 
-from combo.data import Instance
-from combo.data.dataset_readers import utils
+from combo.data import Instance, Tokenizer
+from combo.data.token_indexers import TokenIndexer
 
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class WorkerInfo:
-    """
-    Contains information about the worker context when a `DatasetReader`
-    is being used within a multi-process `DataLoader`.
-    From a `DatasetReader` this can accessed with the [`get_worker_info()`](#get_worker_info) method.
-    """
-
-    num_workers: int
-    """
-    The total number of workers.
-    """
-
-    id: int
-    """
-    The 0-indexed ID of the current worker.
-    """
-
-
-@dataclass
-class DistributedInfo:
-    """
-    Contains information about the global process rank and total world size when the reader is being
-    used within distributed training.
-    From a `DatasetReader` this can be accessed with the [`get_distributed_info()`](#get_distributed_info) method.
-    """
-
-    world_size: int
-    """
-    The total number of processes in the distributed group.
-    """
-
-    global_rank: int
-    """
-    The 0-indexed ID of the current process within the distributed group.
-    This will be between 0 and `world_size - 1`, inclusive.
-    """
-
-
-_T = TypeVar("_T")
-
 PathOrStr = Union[PathLike, str]
 DatasetReaderInput = Union[PathOrStr, List[PathOrStr], Dict[str, PathOrStr]]
 
 
-class DatasetReader:
+class DatasetReader(IterableDataset):
     """
     A `DatasetReader` knows how to turn a file containing a dataset into a collection
-    of `Instance`s.  To implement your own, just override the [`_read(file_path)`](#_read) method
-    to return an `Iterable` of the instances. Ideally this should be a lazy generator
-    that yields them one at a time.
-    All parameters necessary to `_read` the data apart from the filepath should be passed
-    to the constructor of the `DatasetReader`.
-    You should also implement [`text_to_instance(*inputs)`](#text_to_instance),
-    which should be used to turn raw data into `Instance`s. This method is required
-    in order to use a `Predictor` with your reader.
-    Usually the `_read()` method is implemented to call `text_to_instance()`.
-    # Parameters
-    max_instances : `int`, optional (default=`None`)
-        If given, will stop reading after this many instances. This is a useful setting for debugging.
-        Setting this disables caching.
-    manual_distributed_sharding: `bool`, optional (default=`False`)
-        By default, when used in a distributed setting, `DatasetReader` makes sure that each
-        trainer process only receives a subset of the data. It does this by reading the whole
-        dataset in each worker, but filtering out the instances that are not needed.
-        While this ensures that each worker will recieve unique instances, it's not a very efficient
-        way to do so since each worker still needs to process every single instance.
-        A better way to handle this is to manually handle the filtering within your `_read()`
-        method, in which case you should set `manual_distributed_sharding` to `True` so that
-        the base class knows that you handling the filtering.
-        See the section below about how to do this.
-    manual_multiprocess_sharding : `bool`, optional (default=`False`)
-        This is similar to the `manual_distributed_sharding` parameter, but applies to
-        multi-process data loading. By default, when this reader is used by a multi-process
-        data loader (i.e. a `DataLoader` with `num_workers > 1`), each worker will
-        filter out all but a subset of the instances that are needed so that you
-        don't end up with duplicates.
-        However, there is really no benefit to using multiple workers in your `DataLoader`
-        unless you implement the sharding within your `_read()` method, in which
-        case you should set `manual_multiprocess_sharding` to `True`, just as with
-        `manual_distributed_sharding`.
-        See the section below about how to do this.
-    serialization_dir: `str`, optional (default=`None`)
-        The directory in which the training output is saved to, or the directory the model is loaded from.
-        !!! Note
-            This is typically not given an entry in a configuration file. It will be set automatically
-            when using the built-in `allennp` commands.
-    # Using your reader with multi-process or distributed data loading
-    There are two things you may need to update in your `DatasetReader` in order for
-    it to be efficient in the multi-process or distributed data loading context.
-    1. The `_read()` method should handle filtering out all but the instances that
-        each particular worker should generate.
-        This is important because the default mechanism for filtering out `Instance`s in
-        the distributed or multi-process `DataLoader` setting is not very efficient, since every
-        worker would still need to process every single `Instance` in your dataset.
-        But by manually handling the filtering / sharding within your `_read()` method, each
-        worker only needs to perform a subset of the work required to create instances.
-        For example, if you were training using 2 GPUs and your `_read()` method reads a file
-        line-by-line, creating one `Instance` for each line, you could just check the node
-        rank within `_read()` and then throw away every other line starting at the line number
-        corresponding to the node rank.
-        The helper method [`shard_iterable()`](#shard_iterable) is there to make this easy for you.
-        You can wrap this around any iterable object in your `_read()` method, and it will
-        return an iterator that skips the right items based on the distributed training
-        or multi-process loading context. This method can always be called regardless
-        of whether or not you're actually using distributed training or multi-process loading.
-        Remember though that when you handle the sharding manually within `_read()`, you need
-        to let the `DatasetReader` know about this so that it doesn't do any additional
-        filtering. Therefore you need to ensure that both `self.manual_distributed_sharding` and
-        `self.manual_multiprocess_sharding` are set to `True`.
-        If you call the helper method `shard_iterable()` without setting these to `True`,
-        you'll get an exception.
-    2. If the instances generated by `_read()` contain `TextField`s, those `TextField`s
-        should not have any token indexers assigned. The token indexers need to be applied
-        in the [`apply_token_indexers()`](#apply_token_indexers) method instead.
-        This is highly recommended because if the instances generated by your `_read()` method
-        have token indexers attached, those indexers will be duplicated when they are sent across
-        processes. If your token indexers contain large objects (such as `PretrainedTransformerTokenIndexer`s)
-        this could take up a massive amount of memory.
+    of `Instance`s. To implement your own, override the [`_read()`](#_read) method
+    so that it lazily yields `Instance`s one at a time; the file path, tokenizer
+    and token indexers it needs are passed to the constructor.
     """
+    def __init__(self,
+                 file_path: Optional[DatasetReaderInput] = None,
+                 tokenizer: Optional[Tokenizer] = None,
+                 token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None:
+        super().__init__()
+        self.__file_path = file_path
+        self.__tokenizer = tokenizer
+        self.__token_indexers = token_indexers
+
+    @property
+    def file_path(self) -> DatasetReaderInput:
+        return self.__file_path
+
+    @file_path.setter
+    def file_path(self, new_file_path: DatasetReaderInput):
+        self.__file_path = new_file_path
+
+    @property
+    def tokenizer(self) -> Optional[Tokenizer]:
+        return self.__tokenizer
+
+    @property
+    def token_indexers(self) -> Optional[Dict[str, TokenIndexer]]:
+        return self.__token_indexers
+
+    @overrides
+    def __getitem__(self, item) -> Instance:
+        raise NotImplementedError
 
-    def __init__(
-        self,
-        max_instances: Optional[int] = None,
-        manual_distributed_sharding: bool = False,
-        manual_multiprocess_sharding: bool = False,
-        serialization_dir: Optional[str] = None,
-    ) -> None:
-        # Do some validation.
-        if max_instances is not None and max_instances < 0:
-            raise ValueError("If specified, max_instances should be a positive int")
-
-        self.max_instances = max_instances
-        self.manual_distributed_sharding = manual_distributed_sharding
-        self.manual_multiprocess_sharding = manual_multiprocess_sharding
-        self.serialization_dir = serialization_dir
-        self._worker_info: Optional[WorkerInfo] = None
-        self._distributed_info: Optional[DistributedInfo] = None
-        # If we're actually in the main process, we can find the info using torch utils.
-        if utils.is_distributed():
-            self._distributed_info = DistributedInfo(dist.get_world_size(), dist.get_rank())
-
-    def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
+    @overrides
+    def __iter__(self) -> Iterator[Instance]:
         """
         Returns an iterator of instances that can be read from the file path.
         """
-        for instance in self._multi_worker_islice(self._read(file_path)):  # type: ignore
-            if self._worker_info is None:
-                # If not running in a subprocess, it's safe to apply the token_indexers right away.
-                self.apply_token_indexers(instance)
+        # TODO: add multiprocessing
+        for instance in self._read():
+            self.apply_token_indexers(instance)
             yield instance
 
-    def _read(self, file_path: str) -> Iterable[Instance]:
+    def _read(self) -> Iterable[Instance]:
         """
         Reads the instances from the given `file_path` and returns them as an
         `Iterable`.
@@ -183,23 +77,6 @@ class DatasetReader:
         # according to mypy: https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
         raise NotImplementedError
 
-    def text_to_instance(self, *inputs) -> Instance:
-        """
-        Does whatever tokenization or processing is necessary to go from textual input to an
-        `Instance`.  The primary intended use for this is with a
-        :class:`~allennlp.predictors.predictor.Predictor`, which gets text input as a JSON
-        object and needs to process it to be input to a model.
-        The intent here is to share code between :func:`_read` and what happens at
-        model serving time, or any other time you want to make a prediction from new data.  We need
-        to process the data in the same way it was done at training time.  Allowing the
-        `DatasetReader` to process new text lets us accomplish this, as we can just call
-        `DatasetReader.text_to_instance` when serving predictions.
-        The input type here is rather vaguely specified, unfortunately.  The `Predictor` will
-        have to make some assumptions about the kind of `DatasetReader` that it's using, in order
-        to pass it the right information.
-        """
-        raise NotImplementedError
-
     def apply_token_indexers(self, instance: Instance) -> None:
         """
         If `Instance`s created by this reader contain `TextField`s without `token_indexers`,
@@ -217,137 +94,3 @@ class DatasetReader:
         ```
         """
         pass
-
-    def get_worker_info(self) -> Optional[WorkerInfo]:
-        """
-        Provides a [`WorkerInfo`](#WorkerInfo) object when the reader is being used within a
-        worker of a multi-process `DataLoader`.
-        If the reader is in the main process, this is just `None`.
-        !!! NOTE
-            This is different than distributed training. If the `DatasetReader`
-            is being used within distributed training, `get_worker_info()` will only
-            provide information on the `DataLoader` worker within its node.
-            Use [`get_distributed_info`](#get_distributed_info) to get information on distributed
-            training context.
-        """
-        return self._worker_info
-
-    def get_distributed_info(self) -> Optional[DistributedInfo]:
-        """
-        Provides a [`DistributedInfo`](#DistributedInfo) object when the reader is being
-        used within distributed training.
-        If not in distributed training, this is just `None`.
-        """
-        return self._distributed_info
-
-    def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
-        """
-        Should only be used internally.
-        """
-        self._worker_info = info
-
-    def _set_distributed_info(self, info: Optional[DistributedInfo]) -> None:
-        """
-        Should only be used internally.
-        """
-        self._distributed_info = info
-
-    def shard_iterable(self, iterable: Iterable[_T]) -> Iterator[_T]:
-        """
-        Helper method that determines which items in an iterable object to skip based
-        on the current node rank (for distributed training) and worker ID (for multi-process data loading).
-        """
-        if not self.manual_distributed_sharding or not self.manual_multiprocess_sharding:
-            raise ValueError(
-                "self.shard_iterable() was called but self.manual_distributed_sharding and "
-                "self.manual_multiprocess_sharding was not set to True. Did you forget to call "
-                "super().__init__(manual_distributed_sharding=True, manual_multiprocess_sharding=True) "
-                "in your constructor?"
-            )
-
-        sharded_slice: Iterator[_T] = iter(iterable)
-
-        if utils.is_distributed():
-            sharded_slice = itertools.islice(
-                sharded_slice, dist.get_rank(), None, dist.get_world_size()
-            )
-
-        if self._worker_info is not None:
-            sharded_slice = itertools.islice(
-                sharded_slice, self._worker_info.id, None, self._worker_info.num_workers
-            )
-
-        # We don't know for sure how many instances we have to produce.
-        # _multi_worker_islice() figures that out. But we know for sure
-        # it won't be more than max_instances.
-        if self.max_instances is not None:
-            sharded_slice = itertools.islice(sharded_slice, self.max_instances)
-
-        return sharded_slice
-
-    def _multi_worker_islice(
-        self,
-        iterable: Iterable[_T],
-    ) -> Iterator[_T]:
-        """
-        This is just like `shard_iterable` but is for internal use only.
-        It has some additional logic to handle `max_instances` based on the distributed
-        or multi-process context, and whether or not sharding is handled manually
-        in the `_read()` method.
-        """
-        # This has some complicated logic because any given reader may or may not
-        # implement manual multi-process and manual distributed sharding itself.
-        # We have to handle all possibilities.
-
-        sharded_slice: Iterator[_T] = iter(iterable)
-
-        # We'll adjust max_instances as we go, depending on what sort of sharding is done.
-        # At the end, we want to ensure the total number of instances collected across
-        # all workers processes is equal to self.max_instances.
-        max_instances = self.max_instances
-
-        if self._distributed_info is not None:
-            if max_instances is not None:
-                # Need to scale down max_instances because otherwise each node would read self.max_instances,
-                # but we really want self.max_instances total across all nodes.
-                if self._distributed_info.global_rank < (
-                    max_instances % self._distributed_info.world_size
-                ):
-                    max_instances = max_instances // self._distributed_info.world_size + 1
-                else:
-                    max_instances = max_instances // self._distributed_info.world_size
-
-            if not self.manual_distributed_sharding:
-                sharded_slice = itertools.islice(
-                    sharded_slice,
-                    self._distributed_info.global_rank,
-                    None,
-                    self._distributed_info.world_size,
-                )
-
-        if self._worker_info is not None:
-            if max_instances is not None:
-                # Like in the distributed case above, we need to adjust max_instances.
-                if self._worker_info.id < (max_instances % self._worker_info.num_workers):
-                    max_instances = max_instances // self._worker_info.num_workers + 1
-                else:
-                    max_instances = max_instances // self._worker_info.num_workers
-
-            if not self.manual_multiprocess_sharding:
-                warnings.warn(
-                    "Using multi-process data loading without setting "
-                    "DatasetReader.manual_multiprocess_sharding to True.\n"
-                    "Did you forget to set this?\n"
-                    "If you're not handling the multi-process sharding logic within your "
-                    "_read() method, there is probably no benefit to using more than one "
-                    "worker.",
-                    UserWarning,
-                )
-                sharded_slice = itertools.islice(
-                    sharded_slice, self._worker_info.id, None, self._worker_info.num_workers
-                )
-
-        if max_instances is not None:
-            sharded_slice = itertools.islice(sharded_slice, max_instances)
-
-        return sharded_slice
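+
+
+# Minimal subclassing sketch (illustrative only; `LineReader` is hypothetical,
+# and `tokenizer` / `token_indexers` are assumed to be built as for
+# ClassificationTextfileDatasetReader). Because `DatasetReader` is a torch
+# `IterableDataset`, a reader can be handed straight to
+# `torch.utils.data.DataLoader`; `collate_fn=list` keeps the yielded
+# `Instance`s as plain Python objects instead of applying default tensor collation.
+#
+#     from torch.utils.data import DataLoader
+#     from combo.data.fields.text_field import TextField
+#
+#     class LineReader(DatasetReader):
+#         def _read(self) -> Iterable[Instance]:
+#             with open(self.file_path, 'r') as lines:
+#                 for line in lines:
+#                     text_field = TextField(self.tokenizer.tokenize(line.strip()),
+#                                            self.token_indexers)
+#                     yield Instance({'text': text_field})
+#
+#     loader = DataLoader(LineReader('test.tsv', tokenizer, token_indexers),
+#                         batch_size=2, collate_fn=list)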
diff --git a/combo/data/fields/label_field.py b/combo/data/fields/label_field.py
new file mode 100644
index 0000000..12bad9a
--- /dev/null
+++ b/combo/data/fields/label_field.py
@@ -0,0 +1,111 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/data/fields/label_field.py
+"""
+
+from typing import Dict, Union, Set
+import logging
+
+
+import torch
+
+from combo.data import Vocabulary
+from combo.data.fields import Field
+from combo.utils import ConfigurationError
+
+logger = logging.getLogger(__name__)
+
+
+class LabelField(Field[torch.Tensor]):
+    """
+    A `LabelField` is a categorical label of some kind, where the labels are either strings of
+    text or 0-indexed integers (if you wish to skip indexing by passing skip_indexing=True).
+    If the labels need indexing, we will use a :class:`Vocabulary` to convert the string labels
+    into integers.
+    This field will get converted into an integer index representing the class label.
+    # Parameters
+    label : `Union[str, int]`
+    label_namespace : `str`, optional (default=`"labels"`)
+        The namespace to use for converting label strings into integers.  We map label strings to
+        integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...),
+        and this namespace tells the `Vocabulary` object which mapping from strings to integers
+        to use (so "entailment" as a label doesn't get the same integer id as "entailment" as a
+        word).  If you have multiple different label fields in your data, you should make sure you
+        use different namespaces for each one, always using the suffix "labels" (e.g.,
+        "passage_labels" and "question_labels").
+    skip_indexing : `bool`, optional (default=`False`)
+        If your labels are 0-indexed integers, you can pass in this flag, and we'll skip the indexing
+        step.  If this is `False` and your labels are not strings, this throws a `ConfigurationError`.
+    """
+
+    __slots__ = ["label", "_label_namespace", "_label_id", "_skip_indexing"]
+
+    # Most often, you probably don't want to have OOV/PAD tokens with a LabelField, so we warn you
+    # about it when you pick a namespace that will getting these tokens by default.  It is
+    # possible, however, that you _do_ actually want OOV/PAD tokens with this Field.  This class
+    # variable is used to make sure that we only log a single warning for this per namespace, and
+    # not every time you create one of these Field objects.
+    _already_warned_namespaces: Set[str] = set()
+
+    def __init__(
+        self, label: Union[str, int], label_namespace: str = "labels", skip_indexing: bool = False
+    ) -> None:
+        self.label = label
+        self._label_namespace = label_namespace
+        self._label_id = None
+        self._maybe_warn_for_namespace(label_namespace)
+        self._skip_indexing = skip_indexing
+
+        if skip_indexing:
+            if not isinstance(label, int):
+                raise ConfigurationError(
+                    "In order to skip indexing, your labels must be integers. "
+                    "Found label = {}".format(label)
+                )
+            self._label_id = label
+        elif not isinstance(label, str):
+            raise ConfigurationError(
+                "LabelFields must be passed a string label if skip_indexing=False. "
+                "Found label: {} with type: {}.".format(label, type(label))
+            )
+
+    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
+        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
+            if label_namespace not in self._already_warned_namespaces:
+                logger.warning(
+                    "Your label namespace was '%s'. We recommend you use a namespace "
+                    "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
+                    "default to your vocabulary.  See documentation for "
+                    "`non_padded_namespaces` parameter in Vocabulary.",
+                    self._label_namespace,
+                )
+                self._already_warned_namespaces.add(label_namespace)
+
+    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
+        if self._label_id is None:
+            counter[self._label_namespace][self.label] += 1  # type: ignore
+
+    def index(self, vocab: Vocabulary):
+        if not self._skip_indexing:
+            self._label_id = vocab.get_token_index(
+                self.label, self._label_namespace  # type: ignore
+            )
+
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {}
+
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
+        tensor = torch.tensor(self._label_id, dtype=torch.long)
+        return tensor
+
+    def empty_field(self):
+        return LabelField(-1, self._label_namespace, skip_indexing=True)
+
+    def human_readable_repr(self) -> Union[str, int]:
+        return self.label
+
+    def __str__(self) -> str:
+        return f"LabelField with label: {self.label} in namespace: '{self._label_namespace}'."
+
+    def __len__(self):
+        return 1
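+
+
+# Usage sketch (illustrative only; assumes `vocab` is a populated combo
+# `Vocabulary` with a "labels" namespace):
+#
+#     from collections import defaultdict
+#
+#     field = LabelField('entailment', label_namespace='labels')
+#     counter = {'labels': defaultdict(int)}
+#     field.count_vocab_items(counter)   # counter['labels']['entailment'] == 1
+#     field.index(vocab)                 # resolves the label to an integer id
+#     tensor = field.as_tensor(field.get_padding_lengths())  # 0-dim long tensor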
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 9fd4ead..6e993b9 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,2 +1,3 @@
 from .token_indexer import IndexedTokenList, TokenIndexer
 from .token_features_indexer import TokenFeatsIndexer
+from .single_id_token_indexer import SingleIdTokenIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
index 4a27336..90b3ff6 100644
--- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py
@@ -7,7 +7,8 @@ from typing import Optional, Dict, Any, List, Tuple
 
 from overrides import overrides
 
-from combo.data import Vocabulary, Token
+from combo.data import Vocabulary
+from combo.data.tokenizers import TokenizerToken
 from combo.data.token_indexers import IndexedTokenList
 from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
 from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
@@ -101,8 +102,8 @@ class PretrainedTransformerTokenizer(PretrainedTransformerTokenizer):
 
     def _intra_word_tokenize(
             self, string_tokens: List[str]
-    ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]:
-        tokens: List[Token] = []
+    ) -> Tuple[List[TokenizerToken], List[Optional[Tuple[int, int]]]]:
+        tokens: List[TokenizerToken] = []
         offsets: List[Optional[Tuple[int, int]]] = []
         for token_string in string_tokens:
             wordpieces = self.tokenizer.encode_plus(
@@ -117,7 +118,7 @@ class PretrainedTransformerTokenizer(PretrainedTransformerTokenizer):
             if len(wp_ids) > 0:
                 offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1))
                 tokens.extend(
-                    Token(text=wp_text, text_id=wp_id)
+                    TokenizerToken(text=wp_text, text_id=wp_id)
                     for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids))
                 )
             else:
diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py
index c81596e..6af8347 100644
--- a/combo/data/token_indexers/pretrained_transformer_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_indexer.py
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple, Any
 import logging
 import torch
 
-from combo.data import Vocabulary, Token
+from combo.data import Vocabulary
+from combo.data.tokenizers import TokenizerToken
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
 from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
 from combo.utils import pad_sequence_to_length
@@ -83,11 +84,11 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
         self._added_to_vocabulary = True
 
-    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
         # If we only use pretrained models, we don't need to do anything here.
         pass
 
-    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+    def tokens_to_indices(self, tokens: List[TokenizerToken], vocabulary: Vocabulary) -> IndexedTokenList:
         self._add_encoding_to_vocabulary_if_needed(vocabulary)
 
         indices, type_ids = self._extract_token_and_type_ids(tokens)
@@ -102,14 +103,14 @@ class PretrainedTransformerIndexer(TokenIndexer):
 
     def indices_to_tokens(
         self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
-    ) -> List[Token]:
+    ) -> List[TokenizerToken]:
         self._add_encoding_to_vocabulary_if_needed(vocabulary)
 
         token_ids = indexed_tokens["token_ids"]
         type_ids = indexed_tokens.get("type_ids")
 
         return [
-            Token(
+            TokenizerToken(
                 text=vocabulary.get_token_from_index(token_ids[i], self._namespace),
                 text_id=token_ids[i],
                 type_id=type_ids[i] if type_ids is not None else None,
@@ -117,7 +118,7 @@ class PretrainedTransformerIndexer(TokenIndexer):
             for i in range(len(token_ids))
         ]
 
-    def _extract_token_and_type_ids(self, tokens: List[Token]) -> Tuple[List[int], List[int]]:
+    def _extract_token_and_type_ids(self, tokens: List[TokenizerToken]) -> Tuple[List[int], List[int]]:
         """
         Roughly equivalent to `zip(*[(token.text_id, token.type_id) for token in tokens])`,
         with some checks.
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 20411dd..6b5043b 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -8,8 +8,10 @@ import logging
 
 import torch
 
-from combo.data import Vocabulary, Token
+from combo.data import Vocabulary
+from combo.data.tokenizers import TokenizerToken
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
+from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer
 from combo.utils import pad_sequence_to_length
 
 logger = logging.getLogger(__name__)
@@ -65,10 +67,10 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
         self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
         self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens
 
-    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
         return self._matched_indexer.count_vocab_items(token, counter)
 
-    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+    def tokens_to_indices(self, tokens: List[TokenizerToken], vocabulary: Vocabulary) -> IndexedTokenList:
         self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
 
         wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
diff --git a/combo/data/token_indexers/single_id_token_indexer.py b/combo/data/token_indexers/single_id_token_indexer.py
new file mode 100644
index 0000000..ed03cb2
--- /dev/null
+++ b/combo/data/token_indexers/single_id_token_indexer.py
@@ -0,0 +1,116 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/data/token_indexers/single_id_token_indexer.py
+"""
+
+from typing import Dict, List, Optional, Any
+import itertools
+
+from combo.data import Vocabulary
+from combo.data.tokenizers import TokenizerToken
+from combo.data.token_indexers import TokenIndexer, IndexedTokenList
+
+_DEFAULT_VALUE = "THIS IS A REALLY UNLIKELY VALUE THAT HAS TO BE A STRING"
+
+class SingleIdTokenIndexer(TokenIndexer):
+    """
+    This :class:`TokenIndexer` represents tokens as single integers.
+    Registered as a `TokenIndexer` with name "single_id".
+    # Parameters
+    namespace : `Optional[str]`, optional (default=`"tokens"`)
+        We will use this namespace in the :class:`Vocabulary` to map strings to indices.  If you
+        explicitly pass in `None` here, we will skip indexing and vocabulary lookups.  This means
+        that the `feature_name` you use must correspond to an integer value (like `text_id`, for
+        instance, which gets set by some tokenizers, such as when using byte encoding).
+    lowercase_tokens : `bool`, optional (default=`False`)
+        If `True`, we will call `token.lower()` before getting an index for the token from the
+        vocabulary.
+    start_tokens : `List[str]`, optional (default=`None`)
+        These are prepended to the tokens provided to `tokens_to_indices`.
+    end_tokens : `List[str]`, optional (default=`None`)
+        These are appended to the tokens provided to `tokens_to_indices`.
+    feature_name : `str`, optional (default=`"text"`)
+        We will use the :class:`Token` attribute with this name as input.  This is potentially
+        useful, e.g., for using NER tags instead of (or in addition to) surface forms as your inputs
+        (passing `ent_type_` here would do that).  If you use a non-default value here, you almost
+        certainly want to also change the `namespace` parameter, and you might want to give a
+        `default_value`.
+    default_value : `str`, optional
+        When you want to use a non-default `feature_name`, you sometimes want to have a default
+        value to go with it, e.g., in case you don't have an NER tag for a particular token, for
+        some reason.  This value will get used if we don't find a value in `feature_name`.  If this
+        is not given, we will crash if a token doesn't have a value for the given `feature_name`, so
+        that you don't get weird, silent errors by default.
+    token_min_padding_length : `int`, optional (default=`0`)
+        See :class:`TokenIndexer`.
+    """
+
+    def __init__(
+        self,
+        namespace: Optional[str] = "tokens",
+        lowercase_tokens: bool = False,
+        start_tokens: List[str] = None,
+        end_tokens: List[str] = None,
+        feature_name: str = "text",
+        default_value: str = _DEFAULT_VALUE,
+        token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        self.namespace = namespace
+        self.lowercase_tokens = lowercase_tokens
+
+        self._start_tokens = [TokenizerToken(st) for st in (start_tokens or [])]
+        self._end_tokens = [TokenizerToken(et) for et in (end_tokens or [])]
+        self._feature_name = feature_name
+        self._default_value = default_value
+
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
+        if self.namespace is not None:
+            text = self._get_feature_value(token)
+            if self.lowercase_tokens:
+                text = text.lower()
+            counter[self.namespace][text] += 1
+
+    def tokens_to_indices(
+        self, tokens: List[TokenizerToken], vocabulary: Vocabulary
+    ) -> Dict[str, List[int]]:
+        indices: List[int] = []
+
+        for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
+            text = self._get_feature_value(token)
+            if self.namespace is None:
+                # We could have a check here that `text` is an int; not sure it's worth it.
+                indices.append(text)  # type: ignore
+            else:
+                if self.lowercase_tokens:
+                    text = text.lower()
+                indices.append(vocabulary.get_token_index(text, self.namespace))
+
+        return {"tokens": indices}
+
+    def get_empty_token_list(self) -> IndexedTokenList:
+        return {"tokens": []}
+
+    def _get_feature_value(self, token: TokenizerToken) -> str:
+        text = getattr(token, self._feature_name)
+        if text is None:
+            if self._default_value is not _DEFAULT_VALUE:
+                text = self._default_value
+            else:
+                raise ValueError(
+                    f"{token} did not have attribute {self._feature_name}. If you "
+                    "want to ignore this kind of error, give a default value in the "
+                    "constructor of this indexer."
+                )
+        return text
+
+    def _to_params(self) -> Dict[str, Any]:
+        return {
+            "namespace": self.namespace,
+            "lowercase_tokens": self.lowercase_tokens,
+            "start_tokens": [t.text for t in self._start_tokens],
+            "end_tokens": [t.text for t in self._end_tokens],
+            "feature_name": self._feature_name,
+            "default_value": self._default_value,
+            "token_min_padding_length": self._token_min_padding_length,
+        }
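+
+
+# Usage sketch (illustrative only; assumes `counter` is a nested counting dict
+# and `vocab` is a combo `Vocabulary` populated from it):
+#
+#     indexer = SingleIdTokenIndexer(namespace='tokens', lowercase_tokens=True)
+#     tokens = [TokenizerToken('The'), TokenizerToken('dog')]
+#     for token in tokens:
+#         indexer.count_vocab_items(token, counter)
+#     indexer.tokens_to_indices(tokens, vocab)  # {'tokens': [<id of 'the'>, <id of 'dog'>]}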
diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py
index 3be23a4..2b62e33 100644
--- a/combo/data/token_indexers/token_characters_indexer.py
+++ b/combo/data/token_indexers/token_characters_indexer.py
@@ -10,9 +10,9 @@ import warnings
 from overrides import overrides
 import torch
 
-from combo.data import Token, Vocabulary
+from combo.data import Vocabulary
 from combo.data.token_indexers import TokenIndexer, IndexedTokenList
-from combo.data.tokenizers import CharacterTokenizer
+from combo.data.tokenizers import TokenizerToken, CharacterTokenizer
 from combo.utils import ConfigurationError, pad_sequence_to_length
 
 
@@ -66,11 +66,11 @@ class TokenCharactersIndexer(TokenIndexer):
         self._namespace = namespace
         self._character_tokenizer = character_tokenizer
 
-        self._start_tokens = [Token(st) for st in (start_tokens or [])]
-        self._end_tokens = [Token(et) for et in (end_tokens or [])]
+        self._start_tokens = [TokenizerToken(st) for st in (start_tokens or [])]
+        self._end_tokens = [TokenizerToken(et) for et in (end_tokens or [])]
 
     @overrides
-    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
         if token.text is None:
             raise ConfigurationError("TokenCharactersIndexer needs a tokenizer that retains text")
         for character in self._character_tokenizer.tokenize(token.text):
@@ -81,7 +81,7 @@ class TokenCharactersIndexer(TokenIndexer):
 
     @overrides
     def tokens_to_indices(
-        self, tokens: List[Token], vocabulary: Vocabulary
+        self, tokens: List[TokenizerToken], vocabulary: Vocabulary
     ) -> Dict[str, List[List[int]]]:
         indices: List[List[int]] = []
         for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index 9ae6547..b3ad468 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -9,7 +9,8 @@ from typing import List, Dict
 import torch
 from overrides import overrides
 
-from combo.data import Token, Vocabulary
+from combo.data import Vocabulary
+from combo.data.tokenizers.tokenizer import TokenizerToken
 from combo.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
 from combo.utils import pad_sequence_to_length
 
@@ -27,13 +28,13 @@ class TokenFeatsIndexer(TokenIndexer, ABC):
         self._feature_name = feature_name
 
     @overrides
-    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
         feats = self._feat_values(token)
         for feat in feats:
             counter[self.namespace][feat] += 1
 
     @overrides
-    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+    def tokens_to_indices(self, tokens: List[TokenizerToken], vocabulary: Vocabulary) -> IndexedTokenList:
         indices: List[List[int]] = []
         vocab_size = vocabulary.get_vocab_size(self.namespace)
         for token in tokens:
diff --git a/combo/data/token_indexers/token_indexer.py b/combo/data/token_indexers/token_indexer.py
index eaaa01a..0e2921c 100644
--- a/combo/data/token_indexers/token_indexer.py
+++ b/combo/data/token_indexers/token_indexer.py
@@ -7,7 +7,8 @@ from typing import Any, Dict, List
 
 import torch
 
-from combo.data import Token, Vocabulary
+from combo.data.tokenizers.tokenizer import TokenizerToken
+from combo.data.vocabulary import Vocabulary
 from combo.utils import pad_sequence_to_length
 
 # An indexed token list represents the arguments that will be passed to a TokenEmbedder
@@ -42,7 +43,7 @@ class TokenIndexer:
     def __init__(self, token_min_padding_length: int = 0) -> None:
         self._token_min_padding_length: int = token_min_padding_length
 
-    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
+    def count_vocab_items(self, token: TokenizerToken, counter: Dict[str, Dict[str, int]]):
         """
         The :class:`Vocabulary` needs to assign indices to whatever strings we see in the training
         data (possibly doing some frequency filtering and using an OOV, or out of vocabulary,
@@ -53,7 +54,7 @@ class TokenIndexer:
         """
         raise NotImplementedError
 
-    def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
+    def tokens_to_indices(self, tokens: List[TokenizerToken], vocabulary: Vocabulary) -> IndexedTokenList:
         """
         Takes a list of tokens and converts them to an `IndexedTokenList`.
         This could be just an ID for each token from the vocabulary.
@@ -66,7 +67,7 @@ class TokenIndexer:
 
     def indices_to_tokens(
         self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
-    ) -> List[Token]:
+    ) -> List[TokenizerToken]:
         """
         Inverse operations of tokens_to_indices. Takes an `IndexedTokenList` and converts it back
         into a list of tokens.
@@ -121,4 +122,4 @@ class TokenIndexer:
     def __eq__(self, other) -> bool:
         if isinstance(self, other.__class__):
             return self.__dict__ == other.__dict__
-        return NotImplemented
\ No newline at end of file
+        return NotImplemented
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 2d525fa..1904cb8 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -21,7 +21,7 @@ class StackedBidirectionalLstm(torch.nn.Module):
     """
     A standard stacked Bidirectional LSTM where the LSTM layers
     are concatenated between each layer. The only difference between
-    this and a regular bidirectional LSTM is the application of
+    this and a regular bidirectional LSTM is the application of
     variational dropout to the hidden states and outputs of each layer apart
     from the last layer of the LSTM. Note that this will be slower, as it
     doesn't use CUDNN.
diff --git a/test.tsv b/test.tsv
new file mode 100644
index 0000000..9bc21c4
--- /dev/null
+++ b/test.tsv
@@ -0,0 +1,2 @@
+Dog,one
+Cat,two
-- 
GitLab