"""Implementation of punctuator service"""

from glob import glob
from pathlib import Path
from typing import Optional

from punctuator import Punctuator

from src.language_tool import LanguageToolFixer
import logging


class PunctuatorWorker:
    """Restore punctuation in text files using per-language Punctuator models.

    One ``Punctuator`` is created for every entry matched by
    ``<models_location>/*``; the entry's stem (name without suffix) is used as
    its language key. Switching languages moves the previously active model to
    ``"cpu"`` and the newly active one to ``device``. When a LanguageTool
    location is supplied, ``LanguageToolFixer.fix_punctuation`` is applied to
    the punctuated text as a post-processing step.
    """

    DEFAULT_MODEL = "pl"

    def __init__(
        self,
        models_location: str,
        languagetool_location: Optional[str],
        max_context_size: int = 256,
        overlap: int = 20,
        device: str = "cpu",
    ):
        """Load all language models and, optionally, the LanguageTool fixer.

        Args:
            models_location: Directory whose entries are individual language
                models (matched with the glob pattern ``<models_location>/*``).
            languagetool_location: Location passed to ``LanguageToolFixer``,
                or ``None`` to disable LanguageTool post-processing.
            max_context_size: Forwarded to every ``Punctuator``.
            overlap: Forwarded to every ``Punctuator``.
            device: Device the active model is moved to with ``.to(device)``.
        """
        logging.info("Loading models...")
        self.models = {
            Path(language_model_dir).stem: Punctuator(
                language_model_dir, max_context_size, overlap
            )
            for language_model_dir in glob(models_location + "/*")
        }

        if languagetool_location is not None:
            self.lt = LanguageToolFixer(languagetool_location)
        else:
            self.lt = None

        self.device = device
        self.active_model = self.DEFAULT_MODEL

        # The default model is marked active, so place it on ``device`` just
        # like ``_set_active_model`` does for any later-activated model;
        # otherwise it would stay wherever Punctuator put it on load.
        default_model = self.models.get(self.active_model)
        if default_model is not None:
            default_model.to(self.device)

    def process(self, input_path: str, task_options: dict, output_path: str) -> None:
        """Punctuate the text in ``input_path`` and write it to ``output_path``.

        Args:
            input_path: UTF-8 text file to punctuate.
            task_options: Task settings; the optional ``"language"`` key
                selects the model (defaults to ``DEFAULT_MODEL``).
            output_path: File the punctuated text is written to (UTF-8).

        Raises:
            KeyError: If no model is loaded for the requested language.
        """
        language = task_options.get("language", self.DEFAULT_MODEL)

        if self.active_model != language:
            self._set_active_model(language)

        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # handle non-ASCII text such as Polish diacritics.
        with open(input_path, "r", encoding="utf-8") as f:
            text = f.read()

        punctuated_text = self.models[self.active_model].punctuate(text)

        if self.lt is not None:
            punctuated_text = self.lt.fix_punctuation(punctuated_text, language)

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(punctuated_text)

    def _set_active_model(self, model_language: str) -> None:
        """Make ``model_language`` the active model, placing it on ``self.device``.

        The previously active model (if it is loaded) is moved back to
        ``"cpu"``.

        Raises:
            KeyError: If no model is loaded for ``model_language``. It is
                raised before any model is moved, so the current active model
                and its placement are left untouched.
        """
        # Resolve the target first: failing after offloading the current
        # model would leave it on CPU while still recorded as active.
        try:
            new_model = self.models[model_language]
        except KeyError:
            raise KeyError(
                f"No punctuation model loaded for language {model_language!r}; "
                f"available: {sorted(self.models)}"
            ) from None

        # The initial active model (DEFAULT_MODEL) may not have been loaded.
        previous_model = self.models.get(self.active_model)
        if previous_model is not None:
            previous_model.to("cpu")
        new_model.to(self.device)

        self.active_model = model_language