Commit 44b886c3 authored by Marek Maziarz

docstring added to pack_sentences_to_max_tokens method

parent ec05cf26
Pipeline #19505 failed
@@ -34,6 +34,25 @@ class WinerWorker(nlp_ws.NLPWorker):
     @staticmethod
     def pack_sentences_to_max_tokens(plain_inputs, tokenized_inputs, sent_starts, max_tokens=512):
"""
Pack sentences into chunks, ensuring that the token count per chunk does not exceed a given maximum.
This method takes in plain text sentences, their tokenized versions, and sentence start indices,
and it creates batches of sentences such that the total token count per batch does not exceed the
specified max_tokens limit.
Args:
plain_inputs (list): List of plain text sentences.
tokenized_inputs (list): List of tokenized versions of the sentences.
sent_starts (list): List of sentence start indices.
max_tokens (int, optional): The maximum number of tokens allowed per chunk. Defaults to 512.
Returns:
tuple:
- packed_plain_inputs (list): List of packed plain text sentences, where each element is a chunk of sentences.
- packed_tokenized_inputs (list): List of packed tokenized inputs corresponding to the plain text chunks.
- packed_sent_starts (list): List of sentence start indices for each chunk.
"""
         packed_plain_inputs = []
         packed_tokenized_inputs = []
         packed_sent_starts = []
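
The diff collapses the body of the packing loop, so for readers skimming the commit, here is a minimal sketch of how such a greedy packing pass might look. This is an assumption, not the committed code: the real method may represent chunks differently or rebase sentence-start offsets per chunk.

    # Hypothetical sketch of a greedy packing loop (assumed, not the committed code).
    def pack_sentences_to_max_tokens(plain_inputs, tokenized_inputs,
                                     sent_starts, max_tokens=512):
        packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts = [], [], []
        cur_plain, cur_tokens, cur_starts = [], [], []  # chunk currently being filled
        for plain, tokens, start in zip(plain_inputs, tokenized_inputs, sent_starts):
            # Close the current chunk if adding this sentence would overflow it.
            if cur_plain and len(cur_tokens) + len(tokens) > max_tokens:
                packed_plain_inputs.append(cur_plain)
                packed_tokenized_inputs.append(cur_tokens)
                packed_sent_starts.append(cur_starts)
                cur_plain, cur_tokens, cur_starts = [], [], []
            cur_plain.append(plain)
            cur_tokens.extend(tokens)
            cur_starts.append(start)
        if cur_plain:  # flush the final, possibly partial, chunk
            packed_plain_inputs.append(cur_plain)
            packed_tokenized_inputs.append(cur_tokens)
            packed_sent_starts.append(cur_starts)
        return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts

Note that in this sketch a single sentence longer than max_tokens still becomes its own oversized chunk; whether the real implementation splits or truncates such sentences is not visible in the diff.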
@@ -87,7 +106,6 @@ class WinerWorker(nlp_ws.NLPWorker):
             worker will store result.
         :type output_path: str
         """
-
         # Read inputs and open output
         F_ASCII = task_options.get("ensure_ascii", False)
         with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout:
@@ -98,7 +116,7 @@ class WinerWorker(nlp_ws.NLPWorker):
                 get_sentences_from_document(document)
             packed_plain_texts, packed_tokenized_inputs, packed_sent_starts = pack_sentences_to_max_tokens(
-                plain_inputs, tokenized_inputs, sent_starts, max_tokens=512
+                plain_inputs, tokenized_inputs, sent_starts, max_tokens=510
             )
             # Process data
...
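
The only functional change in this commit is the max_tokens argument dropping from 512 to 510. A plausible reading, not stated in the commit itself: BERT-style encoders with a 512-position limit add two special tokens ([CLS] and [SEP]) around each input, so packing to 510 content tokens keeps each chunk plus its special tokens within the model's window.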