diff --git a/src/winer_worker.py b/src/winer_worker.py
index 47d522a3e46f594b20db5533b575bcf01fd78a1e..07932766510d1b7c3bd63c170c1f5c6ff170df9d 100644
--- a/src/winer_worker.py
+++ b/src/winer_worker.py
@@ -33,25 +33,40 @@ class WinerWorker(nlp_ws.NLPWorker):
         self._model._tokenizer.model_max_length = 512
 
     @staticmethod
-    def pack_sentences_to_max_tokens(plain_inputs, tokenized_inputs, sent_starts, max_tokens=512):
+    def pack_sentences_to_max_tokens(
+        plain_inputs,
+        tokenized_inputs,
+        sent_starts,
+        max_tokens=512
+    ):
         """
-        Pack sentences into chunks, ensuring that the token count per chunk does not exceed a given maximum.
+        Pack sentences into chunks, ensuring that the token count
+        per chunk does not exceed a given maximum.
 
-        This method takes in plain text sentences, their tokenized versions, and sentence start indices,
-        and it creates batches of sentences such that the total token count per batch does not exceed the
+        This method takes in plain text sentences, their tokenized
+        versions, and sentence start indices,
+        and it creates batches of sentences such that the total
+        token count per batch does not exceed the
         specified max_tokens limit.
 
         Args:
             plain_inputs (list): List of plain text sentences.
-            tokenized_inputs (list): List of tokenized versions of the sentences.
+            tokenized_inputs (list): List of tokenized versions
+                of the sentences.
             sent_starts (list): List of sentence start indices.
-            max_tokens (int, optional): The maximum number of tokens allowed per chunk. Defaults to 512.
+            max_tokens (int, optional): The maximum number of
+                tokens allowed per chunk. Defaults to 512.
 
         Returns:
             tuple:
-                - packed_plain_inputs (list): List of packed plain text sentences, where each element is a chunk of sentences.
-                - packed_tokenized_inputs (list): List of packed tokenized inputs corresponding to the plain text chunks.
-                - packed_sent_starts (list): List of sentence start indices for each chunk.
+                - packed_plain_inputs (list): List of packed
+                    plain text sentences, where each element
+                    is a chunk of sentences.
+                - packed_tokenized_inputs (list): List of packed
+                    tokenized inputs corresponding to the plain
+                    text chunks.
+                - packed_sent_starts (list): List of sentence
+                    start indices for each chunk.
         """
         packed_plain_inputs = []
         packed_tokenized_inputs = []
@@ -62,7 +77,8 @@ class WinerWorker(nlp_ws.NLPWorker):
         current_sent_start = []
         current_token_count = 0
 
-        for sentence, sentence_tokens, sent_start in zip(plain_inputs, tokenized_inputs, sent_starts):
+        for sentence, sentence_tokens, sent_start in zip(
+                plain_inputs, tokenized_inputs, sent_starts):
 
             if current_token_count + len(sentence_tokens) <= max_tokens:
                 current_plain_inputs.append(sentence)
@@ -115,8 +131,17 @@ class WinerWorker(nlp_ws.NLPWorker):
         plain_inputs, tokenized_inputs, sent_starts = \
             get_sentences_from_document(document)
 
-        packed_plain_texts, packed_tokenized_inputs, packed_sent_starts = pack_sentences_to_max_tokens(
-            plain_inputs, tokenized_inputs, sent_starts, max_tokens=510
+        (
+            packed_plain_texts,
+            packed_tokenized_inputs,
+            packed_sent_starts
+        ) = (
+            pack_sentences_to_max_tokens(
+                plain_inputs,
+                tokenized_inputs,
+                sent_starts,
+                max_tokens=510
+            )
         )
 
         # Process data
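
For reference, below is a minimal standalone sketch of the greedy packing strategy the reformatted method implements. The signature, the `current_*`/`packed_*` names, and the accumulate branch are taken from the hunks above; the `else` (flush) branch, the trailing flush, and the name `current_tokenized_inputs` are assumptions, since those lines fall outside the diff context.

```python
# Standalone sketch of pack_sentences_to_max_tokens. Only the
# accumulate branch is visible in the diff; the flush logic below
# is an assumed reconstruction.
def pack_sentences_to_max_tokens(
    plain_inputs,
    tokenized_inputs,
    sent_starts,
    max_tokens=512
):
    packed_plain_inputs = []
    packed_tokenized_inputs = []
    packed_sent_starts = []

    current_plain_inputs = []
    current_tokenized_inputs = []  # name assumed, not shown in the diff
    current_sent_start = []
    current_token_count = 0

    for sentence, sentence_tokens, sent_start in zip(
            plain_inputs, tokenized_inputs, sent_starts):

        if current_token_count + len(sentence_tokens) <= max_tokens:
            # Sentence still fits: extend the current chunk.
            current_plain_inputs.append(sentence)
            current_tokenized_inputs.extend(sentence_tokens)
            current_sent_start.append(sent_start)
            current_token_count += len(sentence_tokens)
        else:
            # Assumed flush: close the full chunk (if non-empty) and
            # start a new one seeded with the sentence that did not
            # fit. A single sentence longer than max_tokens still
            # becomes its own oversized chunk in this sketch.
            if current_plain_inputs:
                packed_plain_inputs.append(current_plain_inputs)
                packed_tokenized_inputs.append(current_tokenized_inputs)
                packed_sent_starts.append(current_sent_start)
            current_plain_inputs = [sentence]
            current_tokenized_inputs = list(sentence_tokens)
            current_sent_start = [sent_start]
            current_token_count = len(sentence_tokens)

    # Flush whatever remains after the last sentence.
    if current_plain_inputs:
        packed_plain_inputs.append(current_plain_inputs)
        packed_tokenized_inputs.append(current_tokenized_inputs)
        packed_sent_starts.append(current_sent_start)

    return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts


# Toy usage: with max_tokens=5, two 4-token sentences cannot share
# a chunk, so each lands in its own chunk.
sents = ["Ala ma kota.", "Kot ma Ale."]
toks = [["Ala", "ma", "kota", "."], ["Kot", "ma", "Ale", "."]]
starts = [0, 13]
print(pack_sentences_to_max_tokens(sents, toks, starts, max_tokens=5))
```

Note that the call site passes max_tokens=510 even though model_max_length is set to 512; presumably this reserves two positions for the special tokens (e.g. [CLS]/[SEP]) the tokenizer adds around each chunk.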