diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
index 1db58a9b3221ba0442c0d0b3321be208ab2c9daa..7e90dca36a0bf3ca8e0c7cbf939ca0f4305d03e8 100644
--- a/combo/data/samplers/__init__.py
+++ b/combo/data/samplers/__init__.py
@@ -1,2 +1,2 @@
-from .base_sampler import Sampler
+from .batch_sampler import BatchSampler
 from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
deleted file mode 100644
index e570a36d5a38abc945e93a133a5b91de376b0489..0000000000000000000000000000000000000000
--- a/combo/data/samplers/base_sampler.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Sampler:
-    pass
diff --git a/combo/data/samplers/batch_sampler.py b/combo/data/samplers/batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c67eee08e16cd3259e1247e411c01929c5705e
--- /dev/null
+++ b/combo/data/samplers/batch_sampler.py
@@ -0,0 +1,21 @@
+"""
+Re-implemented AllenNLP BatchSampler
+https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py
+"""
+from typing import Sequence, Iterable, List, Optional
+from torch import Tensor
+
+
+class BatchSampler:
+    def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
+        raise NotImplementedError
+
+    def get_num_batches(self, instances: Sequence[Tensor]) -> int:
+        raise NotImplementedError
+
+    def get_batch_size(self) -> Optional[int]:
+        """
+        Not all `BatchSamplers` define a consistent `batch_size`, but those that
+        do should override this method.
+        """
+        return None
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index ee32a5e54ef9a58e047182594294103c501003cf..3dcab19fc6621cb25c7d34321f0b4feeab20612c 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -1,11 +1,16 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+"""
+
 from typing import List
 
 import numpy as np
 
-from combo.data.samplers import Sampler
+from combo.data.samplers import BatchSampler
 
 
-class TokenCountBatchSampler(Sampler):
+class TokenCountBatchSampler(BatchSampler):
 
     def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
         self._index = 0