Commit f8d8a057 authored by Maja Jabłońska

Add TokenCountBatchSampler from COMBO

parent c698ce09
1 merge request: !46 Merge COMBO 3.0 into master
from .base_sampler import Sampler
from .batch_sampler import BatchSampler
from .samplers import TokenCountBatchSampler
class Sampler:
    pass
"""
Re-implemented AllenNLP BatchSampler
https://github.com/allenai/allennlp/blob/main/allennlp/data/samplers/batch_sampler.py
"""
from typing import Sequence, Iterable, List, Optional
from torch import Tensor
class BatchSampler:
def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
raise NotImplementedError
def get_num_batches(self, instances: Sequence[Tensor]) -> int:
raise NotImplementedError
def get_batch_size(self) -> Optional[int]:
"""
Not all `BatchSamplers` define a consistent `batch_size`, but those that
do should override this method.
"""
return None
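
# --- Illustration (not part of this commit): a minimal sketch of a concrete
# BatchSampler with a fixed batch size. The class `FixedSizeBatchSampler` and
# its details are hypothetical, shown only to demonstrate how the interface
# above can be overridden, including `get_batch_size` for samplers that do
# have a consistent batch size.
import math


class FixedSizeBatchSampler(BatchSampler):
    def __init__(self, batch_size: int = 32):
        self.batch_size = batch_size

    def get_batch_indices(self, instances: Sequence[Tensor]) -> Iterable[List[int]]:
        # Consecutive chunks of at most `batch_size` instance indices.
        indices = list(range(len(instances)))
        for start in range(0, len(indices), self.batch_size):
            yield indices[start:start + self.batch_size]

    def get_num_batches(self, instances: Sequence[Tensor]) -> int:
        return math.ceil(len(instances) / self.batch_size)

    def get_batch_size(self) -> Optional[int]:
        # This sampler does have a consistent batch size, so it overrides the default.
        return self.batch_size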
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import List
import numpy as np
from combo.data.samplers import Sampler
from combo.data.samplers import BatchSampler
class TokenCountBatchSampler(BatchSampler):
    def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
        self._index = 0
......
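
# --- Illustration (not part of the commit; the implementation above is
# truncated in this view): a minimal sketch of the usual token-count batching
# strategy, assuming `dataset` is a sequence of token sequences. The helper
# name `_batch_by_token_count` and its details are hypothetical: sort
# instances by length, fill each batch until it holds roughly
# `word_batch_size` tokens, and optionally shuffle the resulting batches.
def _batch_by_token_count(dataset, word_batch_size: int, shuffle_dataset: bool) -> List[List[int]]:
    lengths = [len(instance) for instance in dataset]
    order = np.argsort(lengths)  # group similarly sized instances to limit padding

    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx in order:
        current.append(int(idx))
        current_tokens += lengths[idx]
        if current_tokens >= word_batch_size:
            batches.append(current)
            current, current_tokens = [], 0
    if current:
        batches.append(current)

    if shuffle_dataset:
        np.random.shuffle(batches)
    return batches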