From 6123b85e1c88931572a31ad28c013d30a3957e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pogoda?= <mipo57@e-science.pl> Date: Tue, 30 Nov 2021 12:58:43 +0100 Subject: [PATCH] Style changes --- .gitlab-ci.yml | 6 ----- entrypoint.py | 15 +++++++---- src/language_tool.py | 15 +++++------ src/punctuator_worker.py | 31 ++++++++++++---------- tox.ini | 10 ++----- worker.py | 56 +++++++++++++++++----------------------- 6 files changed, 59 insertions(+), 74 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a674b42..7c25d25 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,7 +6,6 @@ cache: stages: - check_style - - testing - build - deploy @@ -18,11 +17,6 @@ pep8: script: - tox -v -e pep8 -unittest: - stage: testing - script: - - tox -v -e unittest - build_image: stage: build image: 'docker:18.09.7' diff --git a/entrypoint.py b/entrypoint.py index ed2b810..923f19c 100755 --- a/entrypoint.py +++ b/entrypoint.py @@ -5,13 +5,18 @@ import configparser import sys parser = configparser.ConfigParser() -parser.read('config.ini') +parser.read("config.ini") -s3_endpoint = parser['deployment'].get('s3_endpoint', 'https://s3.clarin-pl.eu') -s3_location = parser["deployment"].get("models_s3_location", "s3://workers/punctuator/models_2_0") +s3_endpoint = parser["deployment"].get("s3_endpoint", "https://s3.clarin-pl.eu") +s3_location = parser["deployment"].get( + "models_s3_location", "s3://workers/punctuator/models_2_0" +) local_models_location = parser["deployment"].get("models_cache_dir", "/tmp/models") -cmd = f"aws --no-sign-request --endpoint-url \"{s3_endpoint}\" s3 sync --delete \"{s3_location}\" \"{local_models_location}\"" +cmd = ( + f'aws --no-sign-request --endpoint-url "{s3_endpoint}" s3 sync --delete' + f' "{s3_location}" "{local_models_location}"' +) run(cmd, shell=True) -run(["python", "worker.py"] + sys.argv[1:]) \ No newline at end of file +run(["python", "worker.py"] + sys.argv[1:]) diff --git a/src/language_tool.py b/src/language_tool.py index 607a91a..abb47a4 100644 --- a/src/language_tool.py +++ b/src/language_tool.py @@ -6,24 +6,23 @@ class LanguageToolFixer: def __init__(self, lt_cache: str): self.lt_cache = lt_cache os.environ["LTP_PATH"] = lt_cache - + self.tools = { - 'pl': language_tool_python.LanguageTool('pl-PL'), - 'en': language_tool_python.LanguageTool('en-US'), - 'ru': language_tool_python.LanguageTool('ru-RU'), + "pl": language_tool_python.LanguageTool("pl-PL"), + "en": language_tool_python.LanguageTool("en-US"), + "ru": language_tool_python.LanguageTool("ru-RU"), } - + def _post_process(self, text: str, tool): matches = tool.check(text) matches = [rule for rule in matches if not self._is_punctuation_rule(rule)] return language_tool_python.utils.correct(text, matches) def _is_punctuation_rule(self, rule): - lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements) - + lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements) + def fix_punctuation(self, text: str, lang: str) -> str: if lang in self.tools.keys(): return self._post_process(text, self.tools[lang]) else: return text - \ No newline at end of file diff --git a/src/punctuator_worker.py b/src/punctuator_worker.py index a0a1bd8..4c69777 100644 --- a/src/punctuator_worker.py +++ b/src/punctuator_worker.py @@ -9,24 +9,27 @@ from punctuator import Punctuator from src.language_tool import LanguageToolFixer import logging -class PunctuatorWorker: - DEFAULT_MODEL = 'pl' - - def __init__(self, - models_location: str, - languagetool_location: Optional[str], - max_context_size: int = 256, - overlap: int = 20, - device: str = 'cpu', - ): - logging.info('Loading models...') +class PunctuatorWorker: + DEFAULT_MODEL = "pl" + + def __init__( + self, + models_location: str, + languagetool_location: Optional[str], + max_context_size: int = 256, + overlap: int = 20, + device: str = "cpu", + ): + + logging.info("Loading models...") self.models = { Path(language_model_dir).stem: Punctuator( - language_model_dir, max_context_size, overlap) + language_model_dir, max_context_size, overlap + ) for language_model_dir in glob(models_location + "/*") } - + if languagetool_location is not None: self.lt = LanguageToolFixer(languagetool_location) else: @@ -53,7 +56,7 @@ class PunctuatorWorker: f.write(punctuated_text) def _set_active_model(self, model_language): - self.models[self.active_model].to('cpu') + self.models[self.active_model].to("cpu") self.models[model_language].to(self.device) self.active_model = model_language diff --git a/tox.ini b/tox.ini index 5de4ec6..29e0f97 100644 --- a/tox.ini +++ b/tox.ini @@ -2,16 +2,10 @@ envlist = unittest,pep8 skipsdist = True -[testenv] -deps = -rrequirements.txt - pytest >= 6.0.1 - -[testenv:unittest] -commands = pytest - [flake8] exclude = - .tox, + venv, + .tox, .git, __pycache__, docs/source/conf.py, diff --git a/worker.py b/worker.py index f20e174..33109ba 100644 --- a/worker.py +++ b/worker.py @@ -13,19 +13,17 @@ class Worker(nlp_ws.NLPWorker): config = configparser.ConfigParser() config.read("config.ini") config = config["deployment"] - - models_cache_dir = config.get("models_cache_dir", '/home/worker/models') - languagetool_cache_dir = config.get("languagetool_cache_dir", '/home/worker/languagetool') - max_context_size = int(config.get("max_context_size", '256')) - overlap = int(config.get("overlap", '20')) - device = config.get("device", 'cpu') - + + models_cache_dir = config.get("models_cache_dir", "/home/worker/models") + languagetool_cache_dir = config.get( + "languagetool_cache_dir", "/home/worker/languagetool" + ) + max_context_size = int(config.get("max_context_size", "256")) + overlap = int(config.get("overlap", "20")) + device = config.get("device", "cpu") + self.punctuator = PunctuatorWorker( - models_cache_dir, - languagetool_cache_dir, - max_context_size, - overlap, - device + models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device ) def process(self, input_path: str, task_options: dict, output_path: str) -> None: @@ -36,31 +34,23 @@ def perform_fast_test(): config = configparser.ConfigParser() config.read("config.ini") config = config["deployment"] - - models_cache_dir = config.get("models_cache_dir", '/home/worker/models') - languagetool_cache_dir = config.get("languagetool_cache_dir", '/home/worker/languagetool') - max_context_size = int(config.get("max_context_size", '256')) - overlap = int(config.get("overlap", '20')) - device = config.get("device", 'cpu') - - punctuator = PunctuatorWorker( - models_cache_dir, - languagetool_cache_dir, - max_context_size, - overlap, - device - ) - punctuator.process( - "/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt" + models_cache_dir = config.get("models_cache_dir", "/home/worker/models") + languagetool_cache_dir = config.get( + "languagetool_cache_dir", "/home/worker/languagetool" ) - punctuator.process( - "/test/input/en.txt", {"language": "en"}, "/test/output/en.txt" - ) - punctuator.process( - "/test/input/ru.txt", {"language": "ru"}, "/test/output/ru.txt" + max_context_size = int(config.get("max_context_size", "256")) + overlap = int(config.get("overlap", "20")) + device = config.get("device", "cpu") + + punctuator = PunctuatorWorker( + models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device ) + punctuator.process("/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt") + punctuator.process("/test/input/en.txt", {"language": "en"}, "/test/output/en.txt") + punctuator.process("/test/input/ru.txt", {"language": "ru"}, "/test/output/ru.txt") + if __name__ == "__main__": parser = argparse.ArgumentParser() -- GitLab