Skip to content
Snippets Groups Projects
Commit 8f0f24ac authored by Mateusz Klimaszewski
Browse files

Simplify batch sampler.

parent 2b72300b
Branches
Tags
No related merge requests found
......@@ -35,17 +35,4 @@ class TokenCountBatchSampler(allen_data.BatchSampler):
return batches
def get_num_batches(self, instances: Sequence[data.Instance]) -> int:
    """Return how many batches ``get_batch_indices`` produces for *instances*.

    The count is obtained by iterating over ``self.get_batch_indices(instances)``
    and tallying one per yielded batch, so this method can never disagree with
    the batches that are actually produced.

    Args:
        instances: The instances that would be passed to ``get_batch_indices``.

    Returns:
        The number of batches yielded for *instances*.
    """
    # Delegate instead of re-implementing the batching loop here: a duplicated
    # loop has to be kept in sync by hand, and the previous inline copy only
    # counted a batch once the word budget was exceeded, so a trailing partial
    # batch was never included in the total.
    return sum(1 for _ in self.get_batch_indices(instances))
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment