Skip to content
Snippets Groups Projects
winer_worker.py 1.01 KiB
Newer Older
Wiktor Walentynowicz's avatar
Wiktor Walentynowicz committed
"""Implementation of the Winer worker service."""

from winer.datafiles import read_clarin_json, write_clarin_json
from winer.document import create_document_from_clarin_json, \
    create_entities_from_hf_outputs
from winer.winer import Winer

import logging


class WinerWorker:
    """Worker that runs one ``Winer`` model over a single input file per call.

    The model is instantiated once in the constructor and reused by every
    subsequent :meth:`process` call.
    """

    # Last path component appended to ``models_location`` when loading.
    DEFAULT_MODEL = "dummy"

    def __init__(self, models_location: str):
        """Instantiate the default model.

        Args:
            models_location: Base location of the models; it is joined with
                ``DEFAULT_MODEL`` by a ``/`` and handed to ``Winer``.
        """
        logging.info("Loading models...")
        model_path = f'{models_location}/{self.DEFAULT_MODEL}'
        self.active_model = Winer(model_path)

    def process(
        self,
        input_path: str,
        task_options: dict,
        output_path: str,
    ) -> None:
        """Annotate the document at ``input_path`` and save it to ``output_path``.

        Args:
            input_path: Path passed to ``read_clarin_json``.
            task_options: Accepted as part of the worker interface; this
                method does not read it.
            output_path: Path passed to ``write_clarin_json``.
        """
        document = create_document_from_clarin_json(read_clarin_json(input_path))

        # ``predict`` is given a list of texts; here the batch holds exactly
        # one text, so its result is read back from position 0.
        predictions = self.active_model.predict(
            [document.get_pretokenized_text()]
        )

        entities = create_entities_from_hf_outputs(predictions[0])
        document.add_entites(entities)

        write_clarin_json(document.as_clarin_json(), output_path)