From 8f0f24ac07f0b927783658054e2406f1c81aa8fc Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 14 Apr 2021 11:27:54 +0200
Subject: [PATCH] Simplify batch sampler.

---
 combo/data/samplers/samplers.py | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index dcb83ee..a175498 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -35,17 +35,4 @@ class TokenCountBatchSampler(allen_data.BatchSampler):
         return batches
 
     def get_num_batches(self, instances: Sequence[data.Instance]) -> int:
-        dataset = list(instances)
-        batches = []
-        batch = []
-        words_count = 0
-        lengths = [len(instance.fields["sentence"].tokens) for instance in dataset]
-        argsorted_lengths = np.argsort(lengths)
-        for idx in argsorted_lengths:
-            words_count += lengths[idx]
-            batch.append(idx)
-            if words_count > self._word_batch_size:
-                batches.append(batch)
-                words_count = 0
-                batch = []
-        return len(batches)
+        return sum(1 for _ in self.get_batch_indices(instances))
-- 
GitLab