"""Winer worker implementation."""
import gc
import logging

import clarin_json
from clarin_json.containers import Span
import nlp_ws
import torch

from winer import Winer, allign_into_entities, get_sentences_from_document

_log = logging.getLogger(__name__)


class WinerWorker(nlp_ws.NLPWorker):
"""Class implementing Winer worker."""
@classmethod
def static_init(cls, config):
"""Initialize process.
:param config: configuration from ini file
:type config: dict
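
        Example ``[tool]`` section of the ini file (values shown are
        the defaults used in this method, for illustration only)::

            [tool]
            models_cache_dir = /home/worker/models/base
            batch_size = 128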
"""
cls.models_location = config["tool"].get("models_cache_dir",
"/home/worker/models/base")
cls.batch_size = int(config["tool"].get("batch_size", "128"))
def __init__(self):
"""Constructor."""
_log.info("Loading models...")
self._model = Winer(self.models_location,
batch_size=self.batch_size)
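        # Cap tokenizer inputs at 512 tokens; set directly on the wrapped
        # tokenizer, which is only reachable via its private attribute.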
self._model._tokenizer.model_max_length = 512
@staticmethod
def pack_sentences_to_max_tokens(
plain_inputs: list,
tokenized_inputs: list,
sent_starts: list,
max_tokens: int = 510
):
"""Pack sentences without exceeding a maximum token count.
This method takes in plain text sentences, their tokenized
versions, and sentence start indices,
and it creates batches of sentences such that the total
token count per batch does not exceed the
specified max_tokens limit.
Args:
plain_inputs (list): List of plain text sentences.
tokenized_inputs (list): List of tokenized versions
of the sentences.
sent_starts (list): List of sentence start indices.
max_tokens (int, optional): The maximum number of
tokens allowed per chunk. Defaults to 510.
Returns:
tuple:
- packed_plain_inputs (list): List of packed
plain text sentences, where each element
is a chunk of sentences.
- packed_tokenized_inputs (list): List of packed
tokenized inputs corresponding to the plain
text chunks.
- packed_sent_starts (list): List of sentence
start indices for each chunk.
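
        Example (illustrative only; mock word-level tokens rather than
        the model's real subword tokenization)::

            plain = ["Ala ma kota.", "Kot ma Ale."]
            tokens = [["Ala", "ma", "kota", "."],
                      ["Kot", "ma", "Ale", "."]]
            starts = [0, 13]
            # With max_tokens=5 each sentence becomes its own chunk:
            # -> (["Ala ma kota.", "Kot ma Ale."],
            #     [tokens[0], tokens[1]], [0, 13])
            WinerWorker.pack_sentences_to_max_tokens(
                plain, tokens, starts, max_tokens=5)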
"""
packed_plain_inputs = []
packed_tokenized_inputs = []
packed_sent_starts = []
current_plain_inputs = []
current_tokenized_inputs = []
current_sent_start = []
current_token_count = 0
for sentence, sentence_tokens, sent_start in zip(
plain_inputs, tokenized_inputs, sent_starts):
if current_token_count + len(sentence_tokens) <= max_tokens:
current_plain_inputs.append(sentence)
current_tokenized_inputs.extend(sentence_tokens)
current_sent_start.append(sent_start)
current_token_count += len(sentence_tokens)
            else:
                # Flush the current chunk (if any) before starting a new one.
                if current_plain_inputs:
                    packed_plain_inputs.append(' '.join(current_plain_inputs))
                    packed_tokenized_inputs.append(current_tokenized_inputs)
                    packed_sent_starts.append(current_sent_start[0])
                # Start the new chunk with the sentence that did not fit;
                # a single sentence longer than max_tokens becomes its own
                # chunk and is left for the tokenizer to truncate.
                current_plain_inputs = [sentence]
                current_tokenized_inputs = list(sentence_tokens)
                current_sent_start = [sent_start]
                current_token_count = len(sentence_tokens)
        # Flush the final, possibly partial, chunk
if current_plain_inputs:
packed_plain_inputs.append(' '.join(current_plain_inputs))
packed_tokenized_inputs.append(current_tokenized_inputs)
packed_sent_starts.append(current_sent_start[0])
return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
def process(
self,
input_path: str,
task_options: dict,
output_path: str
) -> None:
"""Called for each request made to the worker.
:param input_path: Path to a file with text documents from which
the worker should read text.
:type input_path: str
:param task_options: no task options
:type task_options: dict
:param output_path: Path to directory where the
worker will store result.
:type output_path: str
"""
# Read inputs and open output
F_ASCII = task_options.get("ensure_ascii", False)
with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout:
with clarin_json.open(input_path, "r") as fin:
for document in fin:
# Pre-process document
plain_inputs, tokenized_inputs, sent_starts = \
get_sentences_from_document(document)
(
packed_plain_texts,
packed_tokenized_inputs,
packed_sent_starts
) = (
self.pack_sentences_to_max_tokens(
plain_inputs,
tokenized_inputs,
sent_starts,
max_tokens=510
)
)
# Process data
aggregations = self._model.process(packed_tokenized_inputs)
# Aggregate results to entities
entities = allign_into_entities(
packed_tokenized_inputs,
aggregations,
inputs_sentences=packed_plain_texts,
sentences_range=packed_sent_starts
)
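                    # Flatten the per-sentence entity lists into one flat
                    # list of Span objects and store them as "ner" spans.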
document.set_spans(
[Span(**entity)
for sent_entities in entities
for entity in sent_entities],
"ner")
# Write processed document
fout.write(document)
# Clean the memory
gc.collect()
torch.cuda.empty_cache()
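

# Typical nlp_ws entry point, shown as a minimal sketch: it assumes the
# framework's usual ``nlp_ws.NLPService.main`` helper; adjust if this
# deployment starts the worker from a separate main module.
if __name__ == "__main__":
    nlp_ws.NLPService.main(WinerWorker)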