diff --git a/combo/data/dataset_readers/__init__.py b/combo/data/dataset_readers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..796496aa4029528e5627d3b0693c648ff3d2825d
--- /dev/null
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -0,0 +1,353 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/dataset_readers/dataset_reader.py
+"""
+from dataclasses import dataclass
+import itertools
+from os import PathLike
+from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List
+import logging
+import warnings
+
+import torch.distributed as dist
+
+from combo.data import Instance
+from combo.data.dataset_readers import utils
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WorkerInfo:
+    """
+    Contains information about the worker context when a `DatasetReader`
+    is being used within a multi-process `DataLoader`.
+    From a `DatasetReader` this can be accessed with the [`get_worker_info()`](#get_worker_info) method.
+    """
+
+    num_workers: int
+    """
+    The total number of workers.
+    """
+
+    id: int
+    """
+    The 0-indexed ID of the current worker.
+    """
+
+
+@dataclass
+class DistributedInfo:
+    """
+    Contains information about the global process rank and total world size when the reader is being
+    used within distributed training.
+    From a `DatasetReader` this can be accessed with the [`get_distributed_info()`](#get_distributed_info) method.
+    """
+
+    world_size: int
+    """
+    The total number of processes in the distributed group.
+    """
+
+    global_rank: int
+    """
+    The 0-indexed ID of the current process within the distributed group.
+    This will be between 0 and `world_size - 1`, inclusive.
+    """
+
+
+_T = TypeVar("_T")
+
+PathOrStr = Union[PathLike, str]
+DatasetReaderInput = Union[PathOrStr, List[PathOrStr], Dict[str, PathOrStr]]
+
+
+class DatasetReader:
+    """
+    A `DatasetReader` knows how to turn a file containing a dataset into a collection
+    of `Instance`s. To implement your own, just override the [`_read(file_path)`](#_read) method
+    to return an `Iterable` of the instances. Ideally this should be a lazy generator
+    that yields them one at a time.
+    All parameters necessary to `_read` the data apart from the filepath should be passed
+    to the constructor of the `DatasetReader`.
+    You should also implement [`text_to_instance(*inputs)`](#text_to_instance),
+    which should be used to turn raw data into `Instance`s. This method is required
+    in order to use a `Predictor` with your reader.
+    Usually the `_read()` method is implemented to call `text_to_instance()`.
+    # Parameters
+    max_instances : `int`, optional (default=`None`)
+        If given, will stop reading after this many instances. This is a useful setting for debugging.
+        Setting this disables caching.
+    manual_distributed_sharding: `bool`, optional (default=`False`)
+        By default, when used in a distributed setting, `DatasetReader` makes sure that each
+        trainer process only receives a subset of the data. It does this by reading the whole
+        dataset in each worker, but filtering out the instances that are not needed.
+        While this ensures that each worker will receive unique instances, it's not a very efficient
+        way to do so since each worker still needs to process every single instance.
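+        For example, with 4 distributed trainer processes, each process would still create
+        every instance from `_read()` but keep only one out of every 4 (those at indices
+        `rank`, `rank + 4`, `rank + 8`, ...), throwing the rest away.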
+        A better way to handle this is to manually handle the filtering within your `_read()`
+        method, in which case you should set `manual_distributed_sharding` to `True` so that
+        the base class knows that you are handling the filtering.
+        See the section below about how to do this.
+    manual_multiprocess_sharding : `bool`, optional (default=`False`)
+        This is similar to the `manual_distributed_sharding` parameter, but applies to
+        multi-process data loading. By default, when this reader is used by a multi-process
+        data loader (i.e. a `DataLoader` with `num_workers > 1`), each worker will
+        filter out all but a subset of the instances that are needed so that you
+        don't end up with duplicates.
+        However, there is really no benefit to using multiple workers in your `DataLoader`
+        unless you implement the sharding within your `_read()` method, in which
+        case you should set `manual_multiprocess_sharding` to `True`, just as with
+        `manual_distributed_sharding`.
+        See the section below about how to do this.
+    serialization_dir: `str`, optional (default=`None`)
+        The directory in which the training output is saved, or the directory the model is loaded from.
+        !!! Note
+            This is typically not given an entry in a configuration file. It will be set automatically
+            when using the built-in `allennlp` commands.
+    # Using your reader with multi-process or distributed data loading
+    There are two things you may need to update in your `DatasetReader` in order for
+    it to be efficient in the multi-process or distributed data loading context.
+    1. The `_read()` method should handle filtering out all but the instances that
+       each particular worker should generate.
+       This is important because the default mechanism for filtering out `Instance`s in
+       the distributed or multi-process `DataLoader` setting is not very efficient, since every
+       worker would still need to process every single `Instance` in your dataset.
+       But by manually handling the filtering / sharding within your `_read()` method, each
+       worker only needs to perform a subset of the work required to create instances.
+       For example, if you were training using 2 GPUs and your `_read()` method reads a file
+       line-by-line, creating one `Instance` for each line, you could just check the node
+       rank within `_read()` and then throw away every other line starting at the line number
+       corresponding to the node rank.
+       The helper method [`shard_iterable()`](#shard_iterable) is there to make this easy for you.
+       You can wrap this around any iterable object in your `_read()` method, and it will
+       return an iterator that skips the right items based on the distributed training
+       or multi-process loading context. This method can always be called regardless
+       of whether or not you're actually using distributed training or multi-process loading.
+       Remember though that when you handle the sharding manually within `_read()`, you need
+       to let the `DatasetReader` know about this so that it doesn't do any additional
+       filtering. Therefore you need to ensure that both `self.manual_distributed_sharding` and
+       `self.manual_multiprocess_sharding` are set to `True`.
+       If you call the helper method `shard_iterable()` without setting these to `True`,
+       you'll get an exception.
+    2. If the instances generated by `_read()` contain `TextField`s, those `TextField`s
+       should not have any token indexers assigned. The token indexers need to be applied
+       in the [`apply_token_indexers()`](#apply_token_indexers) method instead.
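+       For example, a minimal sketch of a `_read()` that respects both points above
+       (it assumes the reader stores a tokenizer and token indexers as `self._tokenizer`
+       and `self._token_indexers`, and uses AllenNLP-style `TextField`s):
+       ```python
+       def _read(self, file_path) -> Iterable[Instance]:
+           with open(file_path) as data_file:
+               for line in self.shard_iterable(data_file):
+                   tokens = self._tokenizer.tokenize(line)
+                   # No token_indexers are passed to the TextField here; they are
+                   # attached later, in apply_token_indexers().
+                   yield Instance({"source": TextField(tokens)})
+       ```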
+       This is highly recommended because if the instances generated by your `_read()` method
+       have token indexers attached, those indexers will be duplicated when they are sent across
+       processes. If your token indexers contain large objects (such as `PretrainedTransformerTokenIndexer`s)
+       this could take up a massive amount of memory.
+    """
+
+    def __init__(
+        self,
+        max_instances: Optional[int] = None,
+        manual_distributed_sharding: bool = False,
+        manual_multiprocess_sharding: bool = False,
+        serialization_dir: Optional[str] = None,
+    ) -> None:
+        # Do some validation.
+        if max_instances is not None and max_instances < 0:
+            raise ValueError("If specified, max_instances should be a positive int")
+
+        self.max_instances = max_instances
+        self.manual_distributed_sharding = manual_distributed_sharding
+        self.manual_multiprocess_sharding = manual_multiprocess_sharding
+        self.serialization_dir = serialization_dir
+        self._worker_info: Optional[WorkerInfo] = None
+        self._distributed_info: Optional[DistributedInfo] = None
+        # If we're actually in the main process, we can find the info using torch utils.
+        if utils.is_distributed():
+            self._distributed_info = DistributedInfo(dist.get_world_size(), dist.get_rank())
+
+    def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
+        """
+        Returns an iterator of instances that can be read from the file path.
+        """
+        for instance in self._multi_worker_islice(self._read(file_path)):  # type: ignore
+            if self._worker_info is None:
+                # If not running in a subprocess, it's safe to apply the token_indexers right away.
+                self.apply_token_indexers(instance)
+            yield instance
+
+    def _read(self, file_path) -> Iterable[Instance]:
+        """
+        Reads the instances from the given `file_path` and returns them as an
+        `Iterable`.
+        You are strongly encouraged to use a generator so that users can
+        read a dataset in a lazy way, if they so choose.
+        """
+        # NOTE: `file_path` is left untyped here on purpose.
+        # Technically the type should be `DatasetReaderInput`, but many subclass
+        # implementations of `DatasetReader` define their `_read()` method to take a more
+        # specific type, such as just `str`. But that would be a type error
+        # according to mypy: https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
+        raise NotImplementedError
+
+    def text_to_instance(self, *inputs) -> Instance:
+        """
+        Does whatever tokenization or processing is necessary to go from textual input to an
+        `Instance`. The primary intended use for this is with a
+        :class:`~allennlp.predictors.predictor.Predictor`, which gets text input as a JSON
+        object and needs to process it to be input to a model.
+        The intent here is to share code between :func:`_read` and what happens at
+        model serving time, or any other time you want to make a prediction from new data. We need
+        to process the data in the same way it was done at training time. Allowing the
+        `DatasetReader` to process new text lets us accomplish this, as we can just call
+        `DatasetReader.text_to_instance` when serving predictions.
+        The input type here is rather vaguely specified, unfortunately. The `Predictor` will
+        have to make some assumptions about the kind of `DatasetReader` that it's using, in order
+        to pass it the right information.
+        """
+        raise NotImplementedError
+
+    def apply_token_indexers(self, instance: Instance) -> None:
+        """
+        If `Instance`s created by this reader contain `TextField`s without `token_indexers`,
+        this method can be overridden to set the `token_indexers` of those fields.
+        E.g. if you have a `"source"` `TextField`, you could implement this method like this:
+        ```python
+        def apply_token_indexers(self, instance: Instance) -> None:
+            instance["source"].token_indexers = self._token_indexers
+        ```
+        If your `TextField`s are wrapped in a `ListField`, you can access them via `field_list`.
+        E.g. if you had a `"source"` field of `ListField[TextField]` objects, you could:
+        ```python
+        for text_field in instance["source"].field_list:
+            text_field.token_indexers = self._token_indexers
+        ```
+        """
+        pass
+
+    def get_worker_info(self) -> Optional[WorkerInfo]:
+        """
+        Provides a [`WorkerInfo`](#WorkerInfo) object when the reader is being used within a
+        worker of a multi-process `DataLoader`.
+        If the reader is in the main process, this is just `None`.
+        !!! NOTE
+            This is different than distributed training. If the `DatasetReader`
+            is being used within distributed training, `get_worker_info()` will only
+            provide information on the `DataLoader` worker within its node.
+            Use [`get_distributed_info`](#get_distributed_info) to get information on distributed
+            training context.
+        """
+        return self._worker_info
+
+    def get_distributed_info(self) -> Optional[DistributedInfo]:
+        """
+        Provides a [`DistributedInfo`](#DistributedInfo) object when the reader is being
+        used within distributed training.
+        If not in distributed training, this is just `None`.
+        """
+        return self._distributed_info
+
+    def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
+        """
+        Should only be used internally.
+        """
+        self._worker_info = info
+
+    def _set_distributed_info(self, info: Optional[DistributedInfo]) -> None:
+        """
+        Should only be used internally.
+        """
+        self._distributed_info = info
+
+    def shard_iterable(self, iterable: Iterable[_T]) -> Iterator[_T]:
+        """
+        Helper method that determines which items in an iterable object to skip based
+        on the current node rank (for distributed training) and worker ID (for multi-process data loading).
+        """
+        if not self.manual_distributed_sharding or not self.manual_multiprocess_sharding:
+            raise ValueError(
+                "self.shard_iterable() was called but self.manual_distributed_sharding and "
+                "self.manual_multiprocess_sharding were not set to True. Did you forget to call "
+                "super().__init__(manual_distributed_sharding=True, manual_multiprocess_sharding=True) "
+                "in your constructor?"
+            )
+
+        sharded_slice: Iterator[_T] = iter(iterable)
+
+        if utils.is_distributed():
+            sharded_slice = itertools.islice(
+                sharded_slice, dist.get_rank(), None, dist.get_world_size()
+            )
+
+        if self._worker_info is not None:
+            sharded_slice = itertools.islice(
+                sharded_slice, self._worker_info.id, None, self._worker_info.num_workers
+            )
+
+        # We don't know for sure how many instances we have to produce.
+        # _multi_worker_islice() figures that out. But we know for sure
+        # it won't be more than max_instances.
+        if self.max_instances is not None:
+            sharded_slice = itertools.islice(sharded_slice, self.max_instances)
+
+        return sharded_slice
+
+    def _multi_worker_islice(
+        self,
+        iterable: Iterable[_T],
+    ) -> Iterator[_T]:
+        """
+        This is just like `shard_iterable` but is for internal use only.
+        It has some additional logic to handle `max_instances` based on the distributed
+        or multi-process context, and whether or not sharding is handled manually
+        in the `_read()` method.
+        """
+        # This has some complicated logic because any given reader may or may not
+        # implement manual multi-process and manual distributed sharding itself.
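+        # For example, a reader that calls `shard_iterable()` from `_read()` will have both
+        # `manual_distributed_sharding` and `manual_multiprocess_sharding` set to True, so no
+        # extra filtering is applied here and only `max_instances` has to be divided up.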
+        # We have to handle all possibilities.
+
+        sharded_slice: Iterator[_T] = iter(iterable)
+
+        # We'll adjust max_instances as we go, depending on what sort of sharding is done.
+        # At the end, we want to ensure the total number of instances collected across
+        # all worker processes is equal to self.max_instances.
+        max_instances = self.max_instances
+
+        if self._distributed_info is not None:
+            if max_instances is not None:
+                # Need to scale down max_instances because otherwise each node would read self.max_instances,
+                # but we really want self.max_instances total across all nodes.
+                if self._distributed_info.global_rank < (
+                    max_instances % self._distributed_info.world_size
+                ):
+                    max_instances = max_instances // self._distributed_info.world_size + 1
+                else:
+                    max_instances = max_instances // self._distributed_info.world_size
+
+            if not self.manual_distributed_sharding:
+                sharded_slice = itertools.islice(
+                    sharded_slice,
+                    self._distributed_info.global_rank,
+                    None,
+                    self._distributed_info.world_size,
+                )
+
+        if self._worker_info is not None:
+            if max_instances is not None:
+                # Like in the distributed case above, we need to adjust max_instances.
+                if self._worker_info.id < (max_instances % self._worker_info.num_workers):
+                    max_instances = max_instances // self._worker_info.num_workers + 1
+                else:
+                    max_instances = max_instances // self._worker_info.num_workers
+
+            if not self.manual_multiprocess_sharding:
+                warnings.warn(
+                    "Using multi-process data loading without setting "
+                    "DatasetReader.manual_multiprocess_sharding to True.\n"
+                    "Did you forget to set this?\n"
+                    "If you're not handling the multi-process sharding logic within your "
+                    "_read() method, there is probably no benefit to using more than one "
+                    "worker.",
+                    UserWarning,
+                )
+                sharded_slice = itertools.islice(
+                    sharded_slice, self._worker_info.id, None, self._worker_info.num_workers
+                )
+
+        if max_instances is not None:
+            sharded_slice = itertools.islice(sharded_slice, max_instances)
+
+        return sharded_slice
diff --git a/combo/data/dataset_readers/utils.py b/combo/data/dataset_readers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e9b6da66bcaea2eb0d019f497139df6004a3cc
--- /dev/null
+++ b/combo/data/dataset_readers/utils.py
@@ -0,0 +1,8 @@
+import torch.distributed as dist
+
+
+def is_distributed() -> bool:
+    """
+    Checks if the distributed process group is available and has been initialized
+    """
+    return dist.is_available() and dist.is_initialized()
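As a usage sketch (not part of the patch above), a subclass might tie these pieces together roughly as follows. The `LinesDatasetReader` name, the tokenizer and token-indexer arguments, and the `TextField` import are illustrative assumptions rather than code that exists in this diff:

```python
from typing import Iterable

from combo.data import Instance
# Assumed to be importable alongside Instance, analogous to AllenNLP's TextField.
from combo.data import TextField
from combo.data.dataset_readers.dataset_reader import DatasetReader


class LinesDatasetReader(DatasetReader):
    """Hypothetical reader: one Instance per non-empty line of a text file."""

    def __init__(self, tokenizer, token_indexers, **kwargs) -> None:
        # Tell the base class that _read() does its own sharding via shard_iterable().
        super().__init__(
            manual_distributed_sharding=True,
            manual_multiprocess_sharding=True,
            **kwargs,
        )
        self._tokenizer = tokenizer
        self._token_indexers = token_indexers

    def _read(self, file_path) -> Iterable[Instance]:
        with open(file_path, "r", encoding="utf-8") as data_file:
            # shard_iterable() skips the lines that belong to other ranks/workers.
            for line in self.shard_iterable(data_file):
                line = line.strip()
                if line:
                    yield self.text_to_instance(line)

    def text_to_instance(self, text: str) -> Instance:
        tokens = self._tokenizer.tokenize(text)
        # Build the TextField without token_indexers; they are attached in
        # apply_token_indexers() so they are not duplicated across worker processes.
        return Instance({"source": TextField(tokens)})

    def apply_token_indexers(self, instance: Instance) -> None:
        instance["source"].token_indexers = self._token_indexers
```

Calling `read()` on such a reader yields instances lazily, splits `max_instances` across distributed ranks and `DataLoader` workers, and (mirroring the check in `read()`) applies the token indexers immediately only when not running inside a worker subprocess.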