diff --git a/src/winer_worker.py b/src/winer_worker.py index 7616b65c58a6d3afeb29013bd0c5f3a4a04748ca..141cffcb1ce86444ca31fe3e8c549398cefad400 100644 --- a/src/winer_worker.py +++ b/src/winer_worker.py @@ -39,32 +39,31 @@ class WinerWorker(nlp_ws.NLPWorker): sent_starts: list, max_tokens: int = 510 ): - """ - Pack sentences without exceeding a maximum token count. + """Pack sentences without exceeding a maximum token count. - This method takes in plain text sentences, their tokenized + This method takes in plain text sentences, their tokenized versions, and sentence start indices, - and it creates batches of sentences such that the total + and it creates batches of sentences such that the total token count per batch does not exceed the specified max_tokens limit. Args: plain_inputs (list): List of plain text sentences. - tokenized_inputs (list): List of tokenized versions + tokenized_inputs (list): List of tokenized versions of the sentences. sent_starts (list): List of sentence start indices. - max_tokens (int, optional): The maximum number of + max_tokens (int, optional): The maximum number of tokens allowed per chunk. Defaults to 510. Returns: - tuple: - - packed_plain_inputs (list): List of packed - plain text sentences, where each element + tuple: + - packed_plain_inputs (list): List of packed + plain text sentences, where each element is a chunk of sentences. - - packed_tokenized_inputs (list): List of packed - tokenized inputs corresponding to the plain + - packed_tokenized_inputs (list): List of packed + tokenized inputs corresponding to the plain text chunks. - - packed_sent_starts (list): List of sentence + - packed_sent_starts (list): List of sentence start indices for each chunk. """ @@ -164,4 +163,4 @@ class WinerWorker(nlp_ws.NLPWorker): # Clean the memory gc.collect() - torch.cuda.empty_cache() \ No newline at end of file + torch.cuda.empty_cache()