Skip to content
Snippets Groups Projects
Commit f8d8a057 authored by Maja Jabłońska's avatar Maja Jabłońska
Browse files

Add TokenCountBatchSampler from COMBO

parent c698ce09
No related branches found
No related tags found
1 merge request!46Merge COMBO 3.0 into master
from .base_sampler import Sampler
from .batch_sampler import BatchSampler
from .samplers import TokenCountBatchSampler
class Sampler:
pass
"""
Re-implemented AllenNLP BatchSampler
https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py
"""
from typing import Sequence, Iterable, List, Optional
from torch import Tensor
class BatchSampler:
def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
raise NotImplementedError
def get_num_batches(self, instances: Sequence[Tensor]) -> int:
raise NotImplementedError
def get_batch_size(self) -> Optional[int]:
"""
Not all `BatchSamplers` define a consistent `batch_size`, but those that
do should override this method.
"""
return None
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import List
import numpy as np
from combo.data.samplers import Sampler
from combo.data.samplers import BatchSampler
class TokenCountBatchSampler(Sampler):
class TokenCountBatchSampler(BatchSampler):
def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
self._index = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment