diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a674b422558321cbcf728371af710ed6ae3dc143..7c25d257e438ea1bf0d34b710b9bbe3d3601e2ce 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,7 +6,6 @@ cache: stages: - check_style - - testing - build - deploy @@ -18,11 +17,6 @@ pep8: script: - tox -v -e pep8 -unittest: - stage: testing - script: - - tox -v -e unittest - build_image: stage: build image: 'docker:18.09.7' diff --git a/entrypoint.py b/entrypoint.py index ed2b810ea1806d0227ca5eb572a628284a8a5480..923f19c8b1308662e247e53c55a9355b4c8f29d4 100755 --- a/entrypoint.py +++ b/entrypoint.py @@ -5,13 +5,18 @@ import configparser import sys parser = configparser.ConfigParser() -parser.read('config.ini') +parser.read("config.ini") -s3_endpoint = parser['deployment'].get('s3_endpoint', 'https://s3.clarin-pl.eu') -s3_location = parser["deployment"].get("models_s3_location", "s3://workers/punctuator/models_2_0") +s3_endpoint = parser["deployment"].get("s3_endpoint", "https://s3.clarin-pl.eu") +s3_location = parser["deployment"].get( + "models_s3_location", "s3://workers/punctuator/models_2_0" +) local_models_location = parser["deployment"].get("models_cache_dir", "/tmp/models") -cmd = f"aws --no-sign-request --endpoint-url \"{s3_endpoint}\" s3 sync --delete \"{s3_location}\" \"{local_models_location}\"" +cmd = ( + f'aws --no-sign-request --endpoint-url "{s3_endpoint}" s3 sync --delete' + f' "{s3_location}" "{local_models_location}"' +) run(cmd, shell=True) -run(["python", "worker.py"] + sys.argv[1:]) \ No newline at end of file +run(["python", "worker.py"] + sys.argv[1:]) diff --git a/src/language_tool.py b/src/language_tool.py index 607a91a92eb92a8a8b0b5cfe19ab2e1170e25c62..abb47a44ef6e4a6c977141b5f627b816ab6a756c 100644 --- a/src/language_tool.py +++ b/src/language_tool.py @@ -6,24 +6,23 @@ class LanguageToolFixer: def __init__(self, lt_cache: str): self.lt_cache = lt_cache os.environ["LTP_PATH"] = lt_cache - + self.tools = { - 'pl': language_tool_python.LanguageTool('pl-PL'), - 'en': language_tool_python.LanguageTool('en-US'), - 'ru': language_tool_python.LanguageTool('ru-RU'), + "pl": language_tool_python.LanguageTool("pl-PL"), + "en": language_tool_python.LanguageTool("en-US"), + "ru": language_tool_python.LanguageTool("ru-RU"), } - + def _post_process(self, text: str, tool): matches = tool.check(text) matches = [rule for rule in matches if not self._is_punctuation_rule(rule)] return language_tool_python.utils.correct(text, matches) def _is_punctuation_rule(self, rule): - lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements) - + lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements) + def fix_punctuation(self, text: str, lang: str) -> str: if lang in self.tools.keys(): return self._post_process(text, self.tools[lang]) else: return text - \ No newline at end of file diff --git a/src/punctuator_worker.py b/src/punctuator_worker.py index a0a1bd8a7ce48080e9c5f705581e45f400ae06ac..4c69777587b5cdff35268ccd0befdd415271ef3e 100644 --- a/src/punctuator_worker.py +++ b/src/punctuator_worker.py @@ -9,24 +9,27 @@ from punctuator import Punctuator from src.language_tool import LanguageToolFixer import logging -class PunctuatorWorker: - DEFAULT_MODEL = 'pl' - - def __init__(self, - models_location: str, - languagetool_location: Optional[str], - max_context_size: int = 256, - overlap: int = 20, - device: str = 'cpu', - ): - logging.info('Loading models...') +class PunctuatorWorker: + DEFAULT_MODEL = "pl" + + def __init__( + self, + models_location: str, + languagetool_location: Optional[str], + max_context_size: int = 256, + overlap: int = 20, + device: str = "cpu", + ): + + logging.info("Loading models...") self.models = { Path(language_model_dir).stem: Punctuator( - language_model_dir, max_context_size, overlap) + language_model_dir, max_context_size, overlap + ) for language_model_dir in glob(models_location + "/*") } - + if languagetool_location is not None: self.lt = LanguageToolFixer(languagetool_location) else: @@ -53,7 +56,7 @@ class PunctuatorWorker: f.write(punctuated_text) def _set_active_model(self, model_language): - self.models[self.active_model].to('cpu') + self.models[self.active_model].to("cpu") self.models[model_language].to(self.device) self.active_model = model_language diff --git a/tox.ini b/tox.ini index 5de4ec618b54b8d3d04492c1e2d7031974480430..29e0f9793f7cbc750d87774ceb06ff77173ecc0e 100644 --- a/tox.ini +++ b/tox.ini @@ -2,16 +2,10 @@ envlist = unittest,pep8 skipsdist = True -[testenv] -deps = -rrequirements.txt - pytest >= 6.0.1 - -[testenv:unittest] -commands = pytest - [flake8] exclude = - .tox, + venv, + .tox, .git, __pycache__, docs/source/conf.py, diff --git a/worker.py b/worker.py index f20e17430b732a55ab25904eb3ce9a16d8dc7513..33109baaee5e3a873efc7ce883a8984fe7880d97 100644 --- a/worker.py +++ b/worker.py @@ -13,19 +13,17 @@ class Worker(nlp_ws.NLPWorker): config = configparser.ConfigParser() config.read("config.ini") config = config["deployment"] - - models_cache_dir = config.get("models_cache_dir", '/home/worker/models') - languagetool_cache_dir = config.get("languagetool_cache_dir", '/home/worker/languagetool') - max_context_size = int(config.get("max_context_size", '256')) - overlap = int(config.get("overlap", '20')) - device = config.get("device", 'cpu') - + + models_cache_dir = config.get("models_cache_dir", "/home/worker/models") + languagetool_cache_dir = config.get( + "languagetool_cache_dir", "/home/worker/languagetool" + ) + max_context_size = int(config.get("max_context_size", "256")) + overlap = int(config.get("overlap", "20")) + device = config.get("device", "cpu") + self.punctuator = PunctuatorWorker( - models_cache_dir, - languagetool_cache_dir, - max_context_size, - overlap, - device + models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device ) def process(self, input_path: str, task_options: dict, output_path: str) -> None: @@ -36,31 +34,23 @@ def perform_fast_test(): config = configparser.ConfigParser() config.read("config.ini") config = config["deployment"] - - models_cache_dir = config.get("models_cache_dir", '/home/worker/models') - languagetool_cache_dir = config.get("languagetool_cache_dir", '/home/worker/languagetool') - max_context_size = int(config.get("max_context_size", '256')) - overlap = int(config.get("overlap", '20')) - device = config.get("device", 'cpu') - - punctuator = PunctuatorWorker( - models_cache_dir, - languagetool_cache_dir, - max_context_size, - overlap, - device - ) - punctuator.process( - "/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt" + models_cache_dir = config.get("models_cache_dir", "/home/worker/models") + languagetool_cache_dir = config.get( + "languagetool_cache_dir", "/home/worker/languagetool" ) - punctuator.process( - "/test/input/en.txt", {"language": "en"}, "/test/output/en.txt" - ) - punctuator.process( - "/test/input/ru.txt", {"language": "ru"}, "/test/output/ru.txt" + max_context_size = int(config.get("max_context_size", "256")) + overlap = int(config.get("overlap", "20")) + device = config.get("device", "cpu") + + punctuator = PunctuatorWorker( + models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device ) + punctuator.process("/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt") + punctuator.process("/test/input/en.txt", {"language": "en"}, "/test/output/en.txt") + punctuator.process("/test/input/ru.txt", {"language": "ru"}, "/test/output/ru.txt") + if __name__ == "__main__": parser = argparse.ArgumentParser()