diff --git a/src/winer_worker.py b/src/winer_worker.py
index 07932766510d1b7c3bd63c170c1f5c6ff170df9d..c0265e4b32e4ce57a56a097baf44e6b44e61400e 100644
--- a/src/winer_worker.py
+++ b/src/winer_worker.py
@@ -71,14 +71,14 @@ class WinerWorker(nlp_ws.NLPWorker):
         packed_plain_inputs = []
         packed_tokenized_inputs = []
         packed_sent_starts = []
-        
+
         current_plain_inputs = []
         current_tokenized_inputs = []
         current_sent_start = []
         current_token_count = 0
 
         for sentence, sentence_tokens, sent_start in zip(
-            plain_inputs, tokenized_inputs, sent_starts):
+                plain_inputs, tokenized_inputs, sent_starts):
 
             if current_token_count + len(sentence_tokens) <= max_tokens:
                 current_plain_inputs.append(sentence)
@@ -89,7 +89,7 @@ class WinerWorker(nlp_ws.NLPWorker):
                 packed_plain_inputs.append(' '.join(current_plain_inputs))
                 packed_tokenized_inputs.append(current_tokenized_inputs)
                 packed_sent_starts.append(current_sent_start[0])
-                
+
                 # Reset for a new batch
                 current_plain_inputs = []
                 current_tokenized_inputs = []
@@ -101,9 +101,8 @@ class WinerWorker(nlp_ws.NLPWorker):
             packed_plain_inputs.append(' '.join(current_plain_inputs))
             packed_tokenized_inputs.append(current_tokenized_inputs)
             packed_sent_starts.append(current_sent_start[0])
-        
-        return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
 
+        return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
 
     def process(
         self,
@@ -136,7 +135,7 @@ class WinerWorker(nlp_ws.NLPWorker):
             packed_tokenized_inputs,
             packed_sent_starts
         ) = (
-            pack_sentences_to_max_tokens(
+            self.pack_sentences_to_max_tokens(
                 plain_inputs,
                 tokenized_inputs,
                 sent_starts,
@@ -165,4 +164,4 @@ class WinerWorker(nlp_ws.NLPWorker):
 
         # Clean the memory
         gc.collect()
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
\ No newline at end of file