"""Implementation of punctuator service""" from glob import glob from pathlib import Path from typing import Optional from punctuator import Punctuator from src.language_tool import LanguageToolFixer import logging class PunctuatorWorker: DEFAULT_MODEL = "pl" def __init__( self, models_location: str, languagetool_location: Optional[str], max_context_size: int = 256, overlap: int = 20, device: str = "cpu", ): logging.info("Loading models...") self.models = { Path(language_model_dir).stem: Punctuator( language_model_dir, max_context_size, overlap ) for language_model_dir in glob(models_location + "/*") } if languagetool_location is not None: self.lt = LanguageToolFixer(languagetool_location) else: self.lt = None self.device = device self.active_model = self.DEFAULT_MODEL def process(self, input_path: str, task_options: dict, output_path: str) -> None: language = task_options.get("language", self.DEFAULT_MODEL) if self.active_model != language: self._set_active_model(language) with open(input_path, "r") as f: text = f.read() punctuated_text = self.models[self.active_model].punctuate(text) if self.lt is not None: punctuated_text = self.lt.fix_punctuation(punctuated_text, language) with open(output_path, "w") as f: f.write(punctuated_text) def _set_active_model(self, model_language): self.models[self.active_model].to("cpu") self.models[model_language].to(self.device) self.active_model = model_language