diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py
index 1db58a9b3221ba0442c0d0b3321be208ab2c9daa..7e90dca36a0bf3ca8e0c7cbf939ca0f4305d03e8 100644
--- a/combo/data/samplers/__init__.py
+++ b/combo/data/samplers/__init__.py
@@ -1,2 +1,2 @@
-from .base_sampler import Sampler
+from .batch_sampler import BatchSampler
 from .samplers import TokenCountBatchSampler
diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py
deleted file mode 100644
index e570a36d5a38abc945e93a133a5b91de376b0489..0000000000000000000000000000000000000000
--- a/combo/data/samplers/base_sampler.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Sampler:
-    pass
diff --git a/combo/data/samplers/batch_sampler.py b/combo/data/samplers/batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c67eee08e16cd3259e1247e411c01929c5705e
--- /dev/null
+++ b/combo/data/samplers/batch_sampler.py
@@ -0,0 +1,21 @@
+"""
+Re-implemented AllenNLP BatchSampler
+https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py
+"""
+from typing import Sequence, Iterable, List, Optional
+from torch import Tensor
+
+
+class BatchSampler:
+    def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
+        raise NotImplementedError
+
+    def get_num_batches(self, instances: Sequence[Tensor]) -> int:
+        raise NotImplementedError
+
+    def get_batch_size(self) -> Optional[int]:
+        """
+        Not all `BatchSamplers` define a consistent `batch_size`, but those that
+        do should override this method.
+        """
+        return None
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index ee32a5e54ef9a58e047182594294103c501003cf..3dcab19fc6621cb25c7d34321f0b4feeab20612c 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -1,11 +1,16 @@
+"""
+Adapted from COMBO
+Author: Mateusz Klimaszewski
+"""
+
 from typing import List
 
 import numpy as np
 
-from combo.data.samplers import Sampler
+from combo.data.samplers import BatchSampler
 
 
-class TokenCountBatchSampler(Sampler):
+class TokenCountBatchSampler(BatchSampler):
 
     def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
         self._index = 0