Newer
Older
"""Implementation of the WiNER entity-annotation worker service."""
from winer.datafiles import read_clarin_json, write_clarin_json
from winer.document import create_document_from_clarin_json, \
create_entities_from_hf_outputs
from winer.winer import Winer
import logging
class WinerWorker:
    """Worker that annotates CLARIN JSON documents with entities using a Winer model.

    The model is loaded once in ``__init__`` and stored on ``self.active_model``;
    every ``process`` call reuses that same model instance.
    """

    # Name of the model subdirectory (under ``models_location``) that is
    # loaded in ``__init__``.
    DEFAULT_MODEL = "dummy"

    def __init__(
        self,
        models_location: str,
    ) -> None:
        """Load the default Winer model.

        Args:
            models_location: Directory holding model subdirectories; the model
                is loaded from ``{models_location}/{DEFAULT_MODEL}``.
        """
        # Module-level logger instead of the root logger: calling
        # ``logging.info`` directly would implicitly configure the root logger
        # when the application has not set up logging yet.
        logging.getLogger(__name__).info("Loading models...")
        self.active_model = Winer(f'{models_location}/{self.DEFAULT_MODEL}')

    def process(
        self,
        input_path: str,
        task_options: dict,
        output_path: str
    ) -> None:
        """Annotate one CLARIN JSON file with entities and write the result.

        Args:
            input_path: Path of the CLARIN JSON file to read.
            task_options: Per-task options; not used by this implementation.
            output_path: Path the annotated CLARIN JSON is written to.
        """
        documents = [create_document_from_clarin_json(read_clarin_json(input_path))]
        # ``predict`` receives a batch of texts; presumably it returns one
        # output per input text, in order — verify against Winer.predict.
        outputs = self.active_model.predict(
            [document.get_pretokenized_text() for document in documents]
        )
        for document, output in zip(documents, outputs):
            # NOTE: ``add_entites`` (sic) is the name called on the document
            # object; do not rename it here without renaming it there.
            document.add_entites(create_entities_from_hf_outputs(output))
        # Only a single document is read per call, so only the first is written.
        write_clarin_json(documents[0].as_clarin_json(), output_path)