From f8d8a05773d942b266d6000adc6ded7b1f5c6d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com> Date: Thu, 23 Mar 2023 22:36:56 +0100 Subject: [PATCH] Add TokenCountBatchSampler from COMBO --- combo/data/samplers/__init__.py | 2 +- combo/data/samplers/base_sampler.py | 2 -- combo/data/samplers/batch_sampler.py | 21 +++++++++++++++++++++ combo/data/samplers/samplers.py | 9 +++++++-- 4 files changed, 29 insertions(+), 5 deletions(-) delete mode 100644 combo/data/samplers/base_sampler.py create mode 100644 combo/data/samplers/batch_sampler.py diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py index 1db58a9..7e90dca 100644 --- a/combo/data/samplers/__init__.py +++ b/combo/data/samplers/__init__.py @@ -1,2 +1,2 @@ -from .base_sampler import Sampler +from .batch_sampler import BatchSampler from .samplers import TokenCountBatchSampler diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py deleted file mode 100644 index e570a36..0000000 --- a/combo/data/samplers/base_sampler.py +++ /dev/null @@ -1,2 +0,0 @@ -class Sampler: - pass diff --git a/combo/data/samplers/batch_sampler.py b/combo/data/samplers/batch_sampler.py new file mode 100644 index 0000000..c6c67ee --- /dev/null +++ b/combo/data/samplers/batch_sampler.py @@ -0,0 +1,21 @@ +""" +Re-implemented AllenNLP BatchSampler +https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py +""" +from typing import Sequence, Iterable, List, Optional +from torch import Tensor + + +class BatchSampler: + def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]: + raise NotImplementedError + + def get_num_batches(self, instances: Sequence[Tensor]) -> int: + raise NotImplementedError + + def get_batch_size(self) -> Optional[int]: + """ + Not all `BatchSamplers` define a consistent `batch_size`, but those that + do should override this method. + """ + return None diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py index ee32a5e..3dcab19 100644 --- a/combo/data/samplers/samplers.py +++ b/combo/data/samplers/samplers.py @@ -1,11 +1,16 @@ +""" +Adapted from COMBO +Author: Mateusz Klimaszewski +""" + from typing import List import numpy as np -from combo.data.samplers import Sampler +from combo.data.samplers import BatchSampler -class TokenCountBatchSampler(Sampler): +class TokenCountBatchSampler(BatchSampler): def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True): self._index = 0 -- GitLab