From 8f0f24ac07f0b927783658054e2406f1c81aa8fc Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 14 Apr 2021 11:27:54 +0200 Subject: [PATCH] Simplify batch sampler. --- combo/data/samplers/samplers.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py index dcb83ee..a175498 100644 --- a/combo/data/samplers/samplers.py +++ b/combo/data/samplers/samplers.py @@ -35,17 +35,4 @@ class TokenCountBatchSampler(allen_data.BatchSampler): return batches def get_num_batches(self, instances: Sequence[data.Instance]) -> int: - dataset = list(instances) - batches = [] - batch = [] - words_count = 0 - lengths = [len(instance.fields["sentence"].tokens) for instance in dataset] - argsorted_lengths = np.argsort(lengths) - for idx in argsorted_lengths: - words_count += lengths[idx] - batch.append(idx) - if words_count > self._word_batch_size: - batches.append(batch) - words_count = 0 - batch = [] - return len(batches) + return sum(1 for _ in self.get_batch_indices(instances)) -- GitLab