"""Implementation of punctuator service"""

from winer import Winer
from winer import allign_into_entities, allign_into_words
from winer.io import create_document_from_clarin_json, create_entities_from_hf_outputs
from winer.io import read_clarin_json, write_clarin_json

import logging


class WinerWorker:
    """Worker that annotates a CLARIN-JSON document with entities from a Winer model.

    The model is created once in the constructor and reused by every
    ``process`` call.
    """

    def __init__(
        self,
        models_location: str,
    ):
        """Create the underlying ``Winer`` model.

        Args:
            models_location: Passed unchanged to the ``Winer`` constructor.
        """
        logging.info("Loading models...")
        self._model = Winer(models_location)

    def process(
        self,
        input_path: str,
        task_options: dict,
        output_path: str
    ) -> None:
        """Read one document, predict entities for it and write the result.

        Args:
            input_path: Path handed to ``read_clarin_json``.
            task_options: Accepted for interface compatibility; not used
                by this method.
            output_path: Path handed to ``write_clarin_json``.
        """
        # Read inputs. A single document is loaded, but it is kept in a list
        # because the calls below operate on batches.
        source_json = read_clarin_json(input_path)
        docs = [create_document_from_clarin_json(source_json)]
        pretokenized = [doc.get_pretokenized_text() for doc in docs]
        raw_texts = [doc.get_text() for doc in docs]

        # Prediction
        encoding = self._model.tokenize(pretokenized)
        token_predictions = self._model.predict(encoding)

        # Aggregation: collect the word ids of every batch item first, then
        # pair each of them with the per-token predictions of the same item.
        batch_size = len(token_predictions)
        word_ids_per_item = [
            encoding.word_ids(batch_index) for batch_index in range(batch_size)
        ]
        word_level = [
            allign_into_words(word_ids, labels)
            for word_ids, labels in zip(word_ids_per_item, token_predictions)
        ]
        entities = allign_into_entities(
            pretokenized,
            word_level,
            inputs_sentences=raw_texts,
        )

        # Attach the converted entities to the document at the same position.
        for position, doc in enumerate(docs):
            converted = create_entities_from_hf_outputs(entities[position])
            doc.add_entites(converted)

        # Only the first (and only) document is written out.
        first_doc = docs[0]
        write_clarin_json(first_doc.as_clarin_json(), output_path)