diff --git a/src/winer_worker.py b/src/winer_worker.py
index 77c56ed6673634003083ad526ee97f4839ca75cd..c62ddc660efb3b087cc087884b24c33f3597c709 100644
--- a/src/winer_worker.py
+++ b/src/winer_worker.py
@@ -49,6 +49,43 @@ class WinerWorker(nlp_ws.NLPWorker):
             worker will store result.
         :type output_path: str
         """
+
+        def pack_sentences_to_max_tokens(plain_inputs, tokenized_inputs, sent_starts, max_tokens=512):
+            packed_plain_inputs = []
+            packed_tokenized_inputs = []
+            packed_sent_starts = []
+
+            current_plain_inputs = []
+            current_tokenized_inputs = []
+            current_sent_start = []
+            current_token_count = 0
+
+            for sentence, sentence_tokens, sent_start in zip(plain_inputs, tokenized_inputs, sent_starts):
+
+                if not current_tokenized_inputs or current_token_count + len(sentence_tokens) <= max_tokens:
+                    current_plain_inputs.append(sentence)
+                    current_tokenized_inputs.extend(sentence_tokens)
+                    current_sent_start.append(sent_start)
+                    current_token_count += len(sentence_tokens)
+                else:
+                    packed_plain_inputs.append(' '.join(current_plain_inputs))
+                    packed_tokenized_inputs.append(current_tokenized_inputs)
+                    packed_sent_starts.append(current_sent_start[0])
+
+                    # Start a new batch with the sentence that did not fit
+                    current_plain_inputs = [sentence]
+                    current_tokenized_inputs = list(sentence_tokens)
+                    current_sent_start = [sent_start]
+                    current_token_count = len(sentence_tokens)
+
+            # Flush the last batch
+            if current_plain_inputs:
+                packed_plain_inputs.append(' '.join(current_plain_inputs))
+                packed_tokenized_inputs.append(current_tokenized_inputs)
+                packed_sent_starts.append(current_sent_start[0])
+
+            return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
+
         # Read inputs and open output
         F_ASCII = task_options.get("ensure_ascii", False)
         with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout:
@@ -58,15 +95,20 @@ class WinerWorker(nlp_ws.NLPWorker):
                 plain_inputs, tokenized_inputs, sent_starts = \
                     get_sentences_from_document(document)
 
+                # Pack sentences up to max tokens
+                packed_plain_texts, packed_tokenized_inputs, packed_sent_starts = pack_sentences_to_max_tokens(
+                    plain_inputs, tokenized_inputs, sent_starts, max_tokens=512
+                )
+
                 # Process data
-                aggregations = self._model.process(tokenized_inputs)
+                aggregations = self._model.process(packed_tokenized_inputs)
 
                 # Aggregate results to entities
                 entities = allign_into_entities(
-                    tokenized_inputs,
+                    packed_tokenized_inputs,
                     aggregations,
-                    inputs_sentences=plain_inputs,
-                    sentences_range=sent_starts
+                    inputs_sentences=packed_plain_texts,
+                    sentences_range=packed_sent_starts
                 )
                 document.set_spans(
                     [Span(**entity)
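
A minimal sketch of how the packing helper above is expected to behave, for review purposes only. The helper name and argument order are taken from the diff; the toy data, the sentence-start offsets, and the standalone call (in the diff the function is nested inside the worker method) are assumptions for illustration.

    # Two toy sentences with their token lists and (assumed) start offsets.
    sentences = ["Ala ma kota.", "Kot ma Ale."]
    tokens = [["Ala", "ma", "kota", "."], ["Kot", "ma", "Ale", "."]]
    starts = [0, 13]

    # With max_tokens=5 the second sentence does not fit into the first
    # batch, so the helper flushes the first batch and starts a new one.
    texts, packed_tokens, packed_starts = pack_sentences_to_max_tokens(
        sentences, tokens, starts, max_tokens=5
    )
    # texts         -> ["Ala ma kota.", "Kot ma Ale."]
    # packed_tokens -> [["Ala", "ma", "kota", "."], ["Kot", "ma", "Ale", "."]]
    # packed_starts -> [0, 13]
    # With the default max_tokens=512 both sentences would instead be packed
    # into the single text "Ala ma kota. Kot ma Ale." with one merged token
    # list and the start offset of the first sentence.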