Skip to content
Snippets Groups Projects
Commit b03b8e20 authored by Marek Maziarz's avatar Marek Maziarz
Browse files

added sentence packing to the predict function, added pack_sentences_to_max_tokens method

parent 0bfcbd18
No related merge requests found
Pipeline #19502 failed with stages
in 36 seconds
...@@ -49,6 +49,43 @@ class WinerWorker(nlp_ws.NLPWorker): ...@@ -49,6 +49,43 @@ class WinerWorker(nlp_ws.NLPWorker):
worker will store result. worker will store result.
:type output_path: str :type output_path: str
""" """
def pack_sentences_to_max_tokens(plain_inputs, tokenized_inputs, sent_starts, max_tokens=512):
    """Greedily pack consecutive sentences into batches of at most *max_tokens* tokens.

    :param plain_inputs: plain-text sentences, one string per sentence.
    :param tokenized_inputs: token lists, parallel to ``plain_inputs``.
    :param sent_starts: per-sentence start markers, parallel to ``plain_inputs``;
        each packed batch keeps the marker of its first sentence.
    :param max_tokens: soft upper bound on tokens per batch (default 512,
        the usual transformer input limit).
    :return: tuple ``(packed_plain, packed_tokens, packed_starts)`` where
        ``packed_plain`` joins each batch's sentences with a space,
        ``packed_tokens`` concatenates each batch's token lists, and
        ``packed_starts`` holds the first sentence's start marker per batch.

    Every input sentence appears in exactly one output batch.  A single
    sentence longer than ``max_tokens`` forms its own (oversized) batch
    rather than being discarded.
    """
    packed_plain_inputs = []
    packed_tokenized_inputs = []
    packed_sent_starts = []
    current_plain_inputs = []
    current_tokenized_inputs = []
    current_sent_start = []
    current_token_count = 0
    for sentence, sentence_tokens, sent_start in zip(plain_inputs, tokenized_inputs, sent_starts):
        # Flush the current batch first if this sentence would overflow it.
        # (Guarding on a non-empty batch also avoids an IndexError on
        # current_sent_start[0] when the first sentence alone exceeds the cap.)
        if current_plain_inputs and current_token_count + len(sentence_tokens) > max_tokens:
            packed_plain_inputs.append(' '.join(current_plain_inputs))
            packed_tokenized_inputs.append(current_tokenized_inputs)
            packed_sent_starts.append(current_sent_start[0])
            current_plain_inputs = []
            current_tokenized_inputs = []
            current_sent_start = []
            current_token_count = 0
        # BUG FIX: the original only appended inside the "fits" branch, so the
        # sentence that triggered a flush was dropped from all three outputs.
        # Always add the sentence to the (possibly fresh) batch.
        current_plain_inputs.append(sentence)
        current_tokenized_inputs.extend(sentence_tokens)
        current_sent_start.append(sent_start)
        current_token_count += len(sentence_tokens)
    # Emit the final, partially filled batch.
    if current_plain_inputs:
        packed_plain_inputs.append(' '.join(current_plain_inputs))
        packed_tokenized_inputs.append(current_tokenized_inputs)
        packed_sent_starts.append(current_sent_start[0])
    return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
# Read inputs and open output # Read inputs and open output
F_ASCII = task_options.get("ensure_ascii", False) F_ASCII = task_options.get("ensure_ascii", False)
with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout: with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout:
...@@ -58,15 +95,20 @@ class WinerWorker(nlp_ws.NLPWorker): ...@@ -58,15 +95,20 @@ class WinerWorker(nlp_ws.NLPWorker):
plain_inputs, tokenized_inputs, sent_starts = \ plain_inputs, tokenized_inputs, sent_starts = \
get_sentences_from_document(document) get_sentences_from_document(document)
# Pack sentences up to max tokens
packed_plain_texts, packed_tokenized_inputs, packed_sent_starts = pack_sentences_to_max_tokens(
plain_inputs, tokenized_inputs, sent_starts, max_tokens=512
)
# Process data # Process data
aggregations = self._model.process(tokenized_inputs) aggregations = nlp.process(packed_tokenized_inputs)
# Aggregate results to entities # Aggregate results to entities
entities = allign_into_entities( entities = allign_into_entities(
tokenized_inputs, packed_tokenized_inputs,
aggregations, aggregations,
inputs_sentences=plain_inputs, inputs_sentences=packed_plain_texts,
sentences_range=sent_starts sentences_range=packed_sent_starts
) )
document.set_spans( document.set_spans(
[Span(**entity) [Span(**entity)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment