"""Entity-annotation worker built on the Winer model.

Reads a CLARIN JSON document, runs Winer token predictions over it, groups
those predictions into entities, attaches them to the document and writes the
result back out as CLARIN JSON.
"""
import logging

from winer import Winer
from winer import allign_into_entities, allign_into_words
from winer.io import create_document_from_clarin_json, create_entities_from_hf_outputs
from winer.io import read_clarin_json, write_clarin_json


class WinerWorker:
    """Worker that annotates CLARIN JSON documents with Winer entity predictions."""

    def __init__(
        self,
        models_location: str,
    ):
        """Load the Winer model.

        Args:
            models_location: Location passed straight to ``Winer`` to load
                its models.
        """
        logging.info("Loading models...")
        self._model = Winer(models_location)

    def process(
        self, input_path: str, task_options: dict, output_path: str
    ) -> None:
        """Annotate the document at ``input_path`` and write it to ``output_path``.

        Args:
            input_path: Path of the input CLARIN JSON file.
            task_options: Task options; currently unused by this worker.
            output_path: Path the annotated CLARIN JSON document is written to.
        """
        # Read inputs: exactly one document is built per call.
        documents = [create_document_from_clarin_json(read_clarin_json(input_path))]
        tok_inputs = [document.get_pretokenized_text() for document in documents]
        plain_inputs = [document.get_text() for document in documents]

        # Prediction
        encoded_inputs = self._model.tokenize(tok_inputs)
        predictions_per_token = self._model.predict(encoded_inputs)

        # Aggregation: map token-level predictions onto words, then words onto
        # entities. ``word_ids`` presumably comes from a Hugging Face
        # BatchEncoding (token -> word index mapping) — TODO confirm.
        word_ids_per_input = [
            encoded_inputs.word_ids(index)
            for index in range(len(predictions_per_token))
        ]
        aggregations = [
            allign_into_words(tokens, labels)
            for tokens, labels in zip(word_ids_per_input, predictions_per_token)
        ]
        entities = allign_into_entities(
            tok_inputs, aggregations, inputs_sentences=plain_inputs
        )

        # Index into ``entities`` (rather than zipping) so a length mismatch
        # still raises instead of being silently truncated.
        for idx, document in enumerate(documents):
            document.add_entites(create_entities_from_hf_outputs(entities[idx]))

        # Only one document is ever read above, so only that one is written.
        write_clarin_json(documents[0].as_clarin_json(), output_path)