From f8d8a05773d942b266d6000adc6ded7b1f5c6d75 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Thu, 23 Mar 2023 22:36:56 +0100
Subject: [PATCH] Add TokenCountBatchSampler from COMBO

---
 combo/data/samplers/__init__.py      |  2 +-
 combo/data/samplers/base_sampler.py  |  2 --
 combo/data/samplers/batch_sampler.py | 21 +++++++++++++++++++++
 combo/data/samplers/samplers.py      |  9 +++++++--
 4 files changed, 29 insertions(+), 5 deletions(-)
 delete mode 100644 combo/data/samplers/base_sampler.py
 create mode 100644 combo/data/samplers/batch_sampler.py

diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
index 1db58a9..7e90dca 100644
--- a/combo/data/samplers/__init__.py
+++ b/combo/data/samplers/__init__.py
@@ -1,2 +1,2 @@
-from .base_sampler import Sampler
+from .batch_sampler import BatchSampler
 from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
deleted file mode 100644
index e570a36..0000000
--- a/combo/data/samplers/base_sampler.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Sampler:
-    pass
diff --git a/combo/data/samplers/batch_sampler.py b/combo/data/samplers/batch_sampler.py
new file mode 100644
index 0000000..c6c67ee
--- /dev/null
+++ b/combo/data/samplers/batch_sampler.py
@@ -0,0 +1,21 @@
+"""
+Re-implementation of the AllenNLP BatchSampler interface.
+https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py
+"""
+from typing import Sequence, Iterable, List, Optional
+from torch import Tensor
+
+
+class BatchSampler:  # base interface; concrete samplers override the methods below
+    def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
+        raise NotImplementedError  # no default batching here; subclasses must override
+
+    def get_num_batches(self, instances: Sequence[Tensor]) -> int:
+        raise NotImplementedError  # no default batch count here; subclasses must override
+
+    def get_batch_size(self) -> Optional[int]:
+        """
+        Not all `BatchSamplers` define a consistent `batch_size`, but those that
+        do should override this method.
+        """
+        return None  # default: this base class reports no consistent batch size
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index ee32a5e..3dcab19 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -1,11 +1,16 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+"""
+
 from typing import List
 
 import numpy as np
 
-from combo.data.samplers import Sampler
+from combo.data.samplers import BatchSampler
 
 
-class TokenCountBatchSampler(Sampler):
+class TokenCountBatchSampler(BatchSampler):
 
     def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
         self._index = 0
-- 
GitLab