diff --git a/.dockerignore b/.dockerignore
index 36d93c891774467abefd4294f401bb9a7c04ccca..f644eaf9914b3d105651814fcf8164d8fa6debb4 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,4 @@
-models
-example_texts
-.tox
\ No newline at end of file
+/data
+/example_texts
+/.tox
+/venv
\ No newline at end of file
diff --git a/.dvc/.gitignore b/.dvc/.gitignore
deleted file mode 100644
index 5ecbd4c37f40676b2d408f98f25b9e7c2a191d05..0000000000000000000000000000000000000000
--- a/.dvc/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-/config.local
-/tmp
diff --git a/.dvc/config b/.dvc/config
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/.dvc/plots/confusion.json b/.dvc/plots/confusion.json
deleted file mode 100644
index af1b48d031a4c19d1b737243fc5e589e83d9fc2e..0000000000000000000000000000000000000000
--- a/.dvc/plots/confusion.json
+++ /dev/null
@@ -1,107 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "facet": {
-        "field": "rev",
-        "type": "nominal"
-    },
-    "spec": {
-        "transform": [
-            {
-                "aggregate": [
-                    {
-                        "op": "count",
-                        "as": "xy_count"
-                    }
-                ],
-                "groupby": [
-                    "<DVC_METRIC_Y>",
-                    "<DVC_METRIC_X>"
-                ]
-            },
-            {
-                "impute": "xy_count",
-                "groupby": [
-                    "rev",
-                    "<DVC_METRIC_Y>"
-                ],
-                "key": "<DVC_METRIC_X>",
-                "value": 0
-            },
-            {
-                "impute": "xy_count",
-                "groupby": [
-                    "rev",
-                    "<DVC_METRIC_X>"
-                ],
-                "key": "<DVC_METRIC_Y>",
-                "value": 0
-            },
-            {
-                "joinaggregate": [
-                    {
-                        "op": "max",
-                        "field": "xy_count",
-                        "as": "max_count"
-                    }
-                ],
-                "groupby": []
-            },
-            {
-                "calculate": "datum.xy_count / datum.max_count",
-                "as": "percent_of_max"
-            }
-        ],
-        "encoding": {
-            "x": {
-                "field": "<DVC_METRIC_X>",
-                "type": "nominal",
-                "sort": "ascending",
-                "title": "<DVC_METRIC_X_LABEL>"
-            },
-            "y": {
-                "field": "<DVC_METRIC_Y>",
-                "type": "nominal",
-                "sort": "ascending",
-                "title": "<DVC_METRIC_Y_LABEL>"
-            }
-        },
-        "layer": [
-            {
-                "mark": "rect",
-                "width": 300,
-                "height": 300,
-                "encoding": {
-                    "color": {
-                        "field": "xy_count",
-                        "type": "quantitative",
-                        "title": "",
-                        "scale": {
-                            "domainMin": 0,
-                            "nice": true
-                        }
-                    }
-                }
-            },
-            {
-                "mark": "text",
-                "encoding": {
-                    "text": {
-                        "field": "xy_count",
-                        "type": "quantitative"
-                    },
-                    "color": {
-                        "condition": {
-                            "test": "datum.percent_of_max > 0.5",
-                            "value": "white"
-                        },
-                        "value": "black"
-                    }
-                }
-            }
-        ]
-    }
-}
diff --git a/.dvc/plots/confusion_normalized.json b/.dvc/plots/confusion_normalized.json
deleted file mode 100644
index 1d38849f48f72b4a30e1f53eabcb519dd96f919a..0000000000000000000000000000000000000000
--- a/.dvc/plots/confusion_normalized.json
+++ /dev/null
@@ -1,112 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "facet": {
-        "field": "rev",
-        "type": "nominal"
-    },
-    "spec": {
-        "transform": [
-            {
-                "aggregate": [
-                    {
-                        "op": "count",
-                        "as": "xy_count"
-                    }
-                ],
-                "groupby": [
-                    "<DVC_METRIC_Y>",
-                    "<DVC_METRIC_X>"
-                ]
-            },
-            {
-                "impute": "xy_count",
-                "groupby": [
-                    "rev",
-                    "<DVC_METRIC_Y>"
-                ],
-                "key": "<DVC_METRIC_X>",
-                "value": 0
-            },
-            {
-                "impute": "xy_count",
-                "groupby": [
-                    "rev",
-                    "<DVC_METRIC_X>"
-                ],
-                "key": "<DVC_METRIC_Y>",
-                "value": 0
-            },
-            {
-                "joinaggregate": [
-                    {
-                        "op": "sum",
-                        "field": "xy_count",
-                        "as": "sum_y"
-                    }
-                ],
-                "groupby": [
-                    "<DVC_METRIC_Y>"
-                ]
-            },
-            {
-                "calculate": "datum.xy_count / datum.sum_y",
-                "as": "percent_of_y"
-            }
-        ],
-        "encoding": {
-            "x": {
-                "field": "<DVC_METRIC_X>",
-                "type": "nominal",
-                "sort": "ascending",
-                "title": "<DVC_METRIC_X_LABEL>"
-            },
-            "y": {
-                "field": "<DVC_METRIC_Y>",
-                "type": "nominal",
-                "sort": "ascending",
-                "title": "<DVC_METRIC_Y_LABEL>"
-            }
-        },
-        "layer": [
-            {
-                "mark": "rect",
-                "width": 300,
-                "height": 300,
-                "encoding": {
-                    "color": {
-                        "field": "percent_of_y",
-                        "type": "quantitative",
-                        "title": "",
-                        "scale": {
-                            "domain": [
-                                0,
-                                1
-                            ]
-                        }
-                    }
-                }
-            },
-            {
-                "mark": "text",
-                "encoding": {
-                    "text": {
-                        "field": "percent_of_y",
-                        "type": "quantitative",
-                        "format": ".2f"
-                    },
-                    "color": {
-                        "condition": {
-                            "test": "datum.percent_of_y > 0.5",
-                            "value": "white"
-                        },
-                        "value": "black"
-                    }
-                }
-            }
-        ]
-    }
-}
diff --git a/.dvc/plots/default.json b/.dvc/plots/default.json
deleted file mode 100644
index 9cf71ce0a2355e1706433dfbb6376c25f34e1b0b..0000000000000000000000000000000000000000
--- a/.dvc/plots/default.json
+++ /dev/null
@@ -1,31 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "width": 300,
-    "height": 300,
-    "mark": {
-        "type": "line"
-    },
-    "encoding": {
-        "x": {
-            "field": "<DVC_METRIC_X>",
-            "type": "quantitative",
-            "title": "<DVC_METRIC_X_LABEL>"
-        },
-        "y": {
-            "field": "<DVC_METRIC_Y>",
-            "type": "quantitative",
-            "title": "<DVC_METRIC_Y_LABEL>",
-            "scale": {
-                "zero": false
-            }
-        },
-        "color": {
-            "field": "rev",
-            "type": "nominal"
-        }
-    }
-}
diff --git a/.dvc/plots/linear.json b/.dvc/plots/linear.json
deleted file mode 100644
index 65549f9e01df5fe4cc7bb3c6fe90268e1656955d..0000000000000000000000000000000000000000
--- a/.dvc/plots/linear.json
+++ /dev/null
@@ -1,116 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "width": 300,
-    "height": 300,
-    "layer": [
-        {
-            "encoding": {
-                "x": {
-                    "field": "<DVC_METRIC_X>",
-                    "type": "quantitative",
-                    "title": "<DVC_METRIC_X_LABEL>"
-                },
-                "y": {
-                    "field": "<DVC_METRIC_Y>",
-                    "type": "quantitative",
-                    "title": "<DVC_METRIC_Y_LABEL>",
-                    "scale": {
-                        "zero": false
-                    }
-                },
-                "color": {
-                    "field": "rev",
-                    "type": "nominal"
-                }
-            },
-            "layer": [
-                {
-                    "mark": "line"
-                },
-                {
-                    "selection": {
-                        "label": {
-                            "type": "single",
-                            "nearest": true,
-                            "on": "mouseover",
-                            "encodings": [
-                                "x"
-                            ],
-                            "empty": "none",
-                            "clear": "mouseout"
-                        }
-                    },
-                    "mark": "point",
-                    "encoding": {
-                        "opacity": {
-                            "condition": {
-                                "selection": "label",
-                                "value": 1
-                            },
-                            "value": 0
-                        }
-                    }
-                }
-            ]
-        },
-        {
-            "transform": [
-                {
-                    "filter": {
-                        "selection": "label"
-                    }
-                }
-            ],
-            "layer": [
-                {
-                    "mark": {
-                        "type": "rule",
-                        "color": "gray"
-                    },
-                    "encoding": {
-                        "x": {
-                            "field": "<DVC_METRIC_X>",
-                            "type": "quantitative"
-                        }
-                    }
-                },
-                {
-                    "encoding": {
-                        "text": {
-                            "type": "quantitative",
-                            "field": "<DVC_METRIC_Y>"
-                        },
-                        "x": {
-                            "field": "<DVC_METRIC_X>",
-                            "type": "quantitative"
-                        },
-                        "y": {
-                            "field": "<DVC_METRIC_Y>",
-                            "type": "quantitative"
-                        }
-                    },
-                    "layer": [
-                        {
-                            "mark": {
-                                "type": "text",
-                                "align": "left",
-                                "dx": 5,
-                                "dy": -5
-                            },
-                            "encoding": {
-                                "color": {
-                                    "type": "nominal",
-                                    "field": "rev"
-                                }
-                            }
-                        }
-                    ]
-                }
-            ]
-        }
-    ]
-}
diff --git a/.dvc/plots/scatter.json b/.dvc/plots/scatter.json
deleted file mode 100644
index 9af9304c6470b376793a82f8196b514bbb4a6bc9..0000000000000000000000000000000000000000
--- a/.dvc/plots/scatter.json
+++ /dev/null
@@ -1,104 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "width": 300,
-    "height": 300,
-    "layer": [
-        {
-            "encoding": {
-                "x": {
-                    "field": "<DVC_METRIC_X>",
-                    "type": "quantitative",
-                    "title": "<DVC_METRIC_X_LABEL>"
-                },
-                "y": {
-                    "field": "<DVC_METRIC_Y>",
-                    "type": "quantitative",
-                    "title": "<DVC_METRIC_Y_LABEL>",
-                    "scale": {
-                        "zero": false
-                    }
-                },
-                "color": {
-                    "field": "rev",
-                    "type": "nominal"
-                }
-            },
-            "layer": [
-                {
-                    "mark": "point"
-                },
-                {
-                    "selection": {
-                        "label": {
-                            "type": "single",
-                            "nearest": true,
-                            "on": "mouseover",
-                            "encodings": [
-                                "x"
-                            ],
-                            "empty": "none",
-                            "clear": "mouseout"
-                        }
-                    },
-                    "mark": "point",
-                    "encoding": {
-                        "opacity": {
-                            "condition": {
-                                "selection": "label",
-                                "value": 1
-                            },
-                            "value": 0
-                        }
-                    }
-                }
-            ]
-        },
-        {
-            "transform": [
-                {
-                    "filter": {
-                        "selection": "label"
-                    }
-                }
-            ],
-            "layer": [
-                {
-                    "encoding": {
-                        "text": {
-                            "type": "quantitative",
-                            "field": "<DVC_METRIC_Y>"
-                        },
-                        "x": {
-                            "field": "<DVC_METRIC_X>",
-                            "type": "quantitative"
-                        },
-                        "y": {
-                            "field": "<DVC_METRIC_Y>",
-                            "type": "quantitative"
-                        }
-                    },
-                    "layer": [
-                        {
-                            "mark": {
-                                "type": "text",
-                                "align": "left",
-                                "dx": 5,
-                                "dy": -5
-                            },
-                            "encoding": {
-                                "color": {
-                                    "type": "nominal",
-                                    "field": "rev"
-                                }
-                            }
-                        }
-                    ]
-                }
-            ]
-        }
-    ]
-}
diff --git a/.dvc/plots/smooth.json b/.dvc/plots/smooth.json
deleted file mode 100644
index d497ce75e9e5375733781bd3c3b8b936b9bdec0b..0000000000000000000000000000000000000000
--- a/.dvc/plots/smooth.json
+++ /dev/null
@@ -1,39 +0,0 @@
-{
-    "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
-    "data": {
-        "values": "<DVC_METRIC_DATA>"
-    },
-    "title": "<DVC_METRIC_TITLE>",
-    "mark": {
-        "type": "line"
-    },
-    "encoding": {
-        "x": {
-            "field": "<DVC_METRIC_X>",
-            "type": "quantitative",
-            "title": "<DVC_METRIC_X_LABEL>"
-        },
-        "y": {
-            "field": "<DVC_METRIC_Y>",
-            "type": "quantitative",
-            "title": "<DVC_METRIC_Y_LABEL>",
-            "scale": {
-                "zero": false
-            }
-        },
-        "color": {
-            "field": "rev",
-            "type": "nominal"
-        }
-    },
-    "transform": [
-        {
-            "loess": "<DVC_METRIC_Y>",
-            "on": "<DVC_METRIC_X>",
-            "groupby": [
-                "rev"
-            ],
-            "bandwidth": 0.3
-        }
-    ]
-}
diff --git a/.gitignore b/.gitignore
index 060439b3c3df8cfcce4571a0d19dcc64781e3158..0d050971bacb0b483a0ff3a5879b06c739181261 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,11 +1,5 @@
-/samba
-/.pytest_cache
-/.tox
-/.vscode
-/.env
-/model
-/config.test.ini
-/wandb
 __pycache__
-/notebook.ipynb
-/en
+.tox
+/data
+/venv
+/test.ipynb
\ No newline at end of file
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a674b422558321cbcf728371af710ed6ae3dc143..7c25d257e438ea1bf0d34b710b9bbe3d3601e2ce 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -6,7 +6,6 @@ cache:
 
 stages:
   - check_style
-  - testing
   - build
   - deploy
 
@@ -18,11 +17,6 @@ pep8:
   script:
     - tox -v -e pep8
 
-unittest:
-  stage: testing
-  script:
-    - tox -v -e unittest
-
 build_image:
   stage: build
   image: 'docker:18.09.7'
diff --git a/Dockerfile b/Dockerfile
index abaecf31c75a9a2cb4fc897bb88084f3940153e6..8d6ea09f823eb005e1f999a14b9dde4d613ac1cf 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,8 +8,6 @@ RUN pip install -r requirements.txt && rm requirements.txt
 RUN mkdir /workspace
 WORKDIR /workspace
 
-RUN pip3 install --index-url https://pypi.clarin-pl.eu/simple/ nlp_ws==0.6
-
 # Install OpenJDK-8
 RUN apt-get update && \
     apt-get install -y openjdk-8-jdk && \
@@ -28,9 +26,9 @@ RUN export JAVA_HOME
 
 WORKDIR /home/worker
 
-COPY punctuator punctuator
-COPY entrypoint.sh entrypoint.sh
+COPY src src
+COPY entrypoint.py entrypoint.py
 COPY worker.py worker.py
 COPY config.ini config.ini
 
-ENTRYPOINT ["bash", "entrypoint.sh"]
+ENTRYPOINT ["python", "entrypoint.py"]
diff --git a/README.md b/README.md
index 28c0174b95c1fd1a7be61ba0b6703574fc16f57b..8cded2bf1554f5d98980a2947a6edbe903e8347c 100644
--- a/README.md
+++ b/README.md
@@ -42,8 +42,9 @@ docker build . -t punctuator
 ```
 
 ```bash
 docker run -it \
-  -e PUNCTUATOR_TEST=TRUE \
-  -v $(pwd)/example_texts:/samba \
-  -v $(pwd)/models:/home/worker/models/punctuator \
-  punctuator
+  -v $(pwd)/example_texts:/test \
+  -v $(pwd)/data/models:/home/worker/models \
+  -v $(pwd)/data/lt:/home/worker/languagetool \
+  punctuator \
+  --test
 ```
\ No newline at end of file
diff --git a/config.ini b/config.ini
index afe9243e0f8e7d9d39507e46e3cb13a9bab67fec..e00cefeb5db0f566b67638fc2f904a42b1f8c282 100644
--- a/config.ini
+++ b/config.ini
@@ -7,17 +7,17 @@ rabbit_password = $RABBIT_PASSWORD
 queue_prefix = nlp_
 
 [tool]
-workers_number = 1
+workers_number=1
 
 [logging]
-port = 9981
-local_log_level = INFO
+port=9981
+local_log_level=INFO
 
 [deployment]
-models_path_pl = /home/worker/models/punctuator/pl
-model_path_en = /home/worker/models/punctuator/en
-model_path_ru = /home/worker/models/punctuator/ru
-languagetool_path = /home/worker/models/languagetool
-max_context_size = 256
-overlap = 20
-device = cpu
\ No newline at end of file
+s3_endpoint=https://s3.clarin-pl.eu
+models_s3_location=s3://workers/punctuator/models_2_0
+models_cache_dir=/home/worker/models
+languagetool_cache_dir=/home/worker/languagetool
+max_context_size=256
+overlap=20
+device=cpu
\ No newline at end of file
diff --git a/entrypoint.py b/entrypoint.py
new file mode 100755
index 0000000000000000000000000000000000000000..923f19c8b1308662e247e53c55a9355b4c8f29d4
--- /dev/null
+++ b/entrypoint.py
@@ -0,0 +1,22 @@
+#!/usr/bin/python3
+from subprocess import run
+import configparser
+
+import sys
+
+parser = configparser.ConfigParser()
+parser.read("config.ini")
+
+s3_endpoint = parser["deployment"].get("s3_endpoint", "https://s3.clarin-pl.eu")
+s3_location = parser["deployment"].get(
+    "models_s3_location", "s3://workers/punctuator/models_2_0"
+)
+local_models_location = parser["deployment"].get("models_cache_dir", "/tmp/models")
+
+cmd = (
+    f'aws --no-sign-request --endpoint-url "{s3_endpoint}" s3 sync --delete'
+    f' "{s3_location}" "{local_models_location}"'
+)
+run(cmd, shell=True)
+
+run(["python", "worker.py"] + sys.argv[1:])
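For reference, the sync step in the new `entrypoint.py` can also be expressed without `shell=True`, which removes the quoting concerns and aborts before the worker starts if the sync fails. This is a sketch under the patch's own defaults, not part of the patch itself:

```python
# Sketch: the same aws-cli sync as an argument list with check=True.
import configparser
import subprocess
import sys

parser = configparser.ConfigParser()
parser.read("config.ini")
deployment = parser["deployment"]

subprocess.run(
    [
        "aws", "--no-sign-request",
        "--endpoint-url", deployment.get("s3_endpoint", "https://s3.clarin-pl.eu"),
        "s3", "sync", "--delete",
        deployment.get("models_s3_location", "s3://workers/punctuator/models_2_0"),
        deployment.get("models_cache_dir", "/tmp/models"),
    ],
    check=True,  # fail fast if the model sync does not complete
)
subprocess.run(["python", "worker.py"] + sys.argv[1:], check=True)
```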
diff --git a/entrypoint.sh b/entrypoint.sh
deleted file mode 100644
index 91b2c18274341e03ff169546798d73b46a0db7c6..0000000000000000000000000000000000000000
--- a/entrypoint.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/bin/bash
-
-aws --no-sign-request --endpoint-url "https://s3.clarin-pl.eu" s3 sync s3://workers/punctuator/models /home/worker/models/punctuator
-
-if [[ "$PUNCTUATOR_TEST" == "TRUE" ]]
-then
-    python worker.py --test
-else
-    python worker.py
-fi
\ No newline at end of file
diff --git a/models/.gitignore b/models/.gitignore
deleted file mode 100644
index bba1fb94e3defa4bb8e6fcf080c6242c86d942b7..0000000000000000000000000000000000000000
--- a/models/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-/en
-/ru
-/pl
diff --git a/models/en.dvc b/models/en.dvc
deleted file mode 100644
index ecb5bd46ae9fd7cb594ae862277ebf1d3a5c70af..0000000000000000000000000000000000000000
--- a/models/en.dvc
+++ /dev/null
@@ -1,12 +0,0 @@
-md5: 73f04ea7e7335101e783317fbe5189c7
-frozen: true
-deps:
-- path: data/models/en
-  repo:
-    url: git@gitlab.clarin-pl.eu:grupa-wieszcz/punctuator/models.git
-    rev_lock: 7d36e8cb5d372008ff7c99158eb91184203ec7f6
-outs:
-- md5: e407614911d85095f0d3286b890b7b3a.dir
-  size: 1990793672
-  nfiles: 24
-  path: en
diff --git a/models/pl.dvc b/models/pl.dvc
deleted file mode 100644
index 83533ded6a57ba57fcb098bdb90203c4aa9813fc..0000000000000000000000000000000000000000
--- a/models/pl.dvc
+++ /dev/null
@@ -1,5 +0,0 @@
-outs:
-- md5: da34f3b71f6e56a526ae1f0191300be5.dir
-  size: 527171457
-  nfiles: 6
-  path: pl
diff --git a/models/ru.dvc b/models/ru.dvc
deleted file mode 100644
index a720271de27b24c8a36a7b76b183842f2282e637..0000000000000000000000000000000000000000
--- a/models/ru.dvc
+++ /dev/null
@@ -1,12 +0,0 @@
-md5: 1f8156c987aaaa71f777fc23bce52d35
-frozen: true
-deps:
-- path: data/models/ru
-  repo:
-    url: git@gitlab.clarin-pl.eu:grupa-wieszcz/punctuator/models.git
-    rev_lock: 7d36e8cb5d372008ff7c99158eb91184203ec7f6
-outs:
-- md5: db59bd054b718b0925e20fe16a12b473.dir
-  size: 2845417138
-  nfiles: 22
-  path: ru
diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py
deleted file mode 100644
index a5fb967e987d2c227678253e5d8d932b28ac951f..0000000000000000000000000000000000000000
--- a/punctuator/punctuator.py
+++ /dev/null
@@ -1,147 +0,0 @@
-"""Implementation of punctuator service"""
-
-import json
-import os
-import string
-
-import language_tool_python
-from transformers import AutoModelForTokenClassification, AutoTokenizer
-
-from punctuator.utils import (combine_masks, decode, decode_labels,
-                              inference_masks)
-
-
-def _preprocess_input(text: str):
-    text = text.translate(str.maketrans("", "", string.punctuation))
-    text = text.lower()
-
-    return text
-
-
-def is_punctuation_rule(rule):
-    lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements)
-
-
-def _post_process(text: str, tool):
-    matches = tool.check(text)
-    matches = [rule for rule in matches if not is_punctuation_rule(rule)]
-    return language_tool_python.utils.correct(text, matches)
-
-
-class Punctuator:
-    def __init__(self, config):
-        self.config = config
-        self.languagetool_map = {
-            "model_path_pl": "pl-PL",
-            "model_path_ru": "ru",
-            "model_path_en": "en-US",
-        }
-        self.max_context_size = int(self.config.get("max_context_size", 256))
-        self.overlap = int(self.config.get("overlap", 20))
-
-        self.device = self.config.get("device", "cpu")
-
-        self.languagetool_path = self.config.get(
-            "languagetool_path", "/home/worker/models/languagetool"
-        )
-        os.environ["LTP_PATH"] = self.languagetool_path
-        self.model_path_pl = self.config.get(
-            "model_path_pl", "/home/worker/models/punctuator/pl"
-        )
-        self.model_path_ru = self.config.get(
-            "model_path_ru", "/home/worker/models/punctuator/en"
-        )
-        self.model_path_en = self.config.get(
-            "model_path_en", "/home/worker/models/punctuator/ru"
-        )
-        (
-            self.tool_pl,
-            self.model_pl,
-            self.tokenizer_pl,
-            self.mapping_pl,
-        ) = self._initialize_model("pl-PL", self.model_path_pl, self.device)
-        (
-            self.tool_en,
-            self.model_en,
-            self.tokenizer_en,
-            self.mapping_en,
-        ) = self._initialize_model("en-US", self.model_path_en, "cpu")
-        (
-            self.tool_ru,
-            self.model_ru,
-            self.tokenizer_ru,
-            self.mapping_ru,
-        ) = self._initialize_model("ru", self.model_path_ru, "cpu")
-        self.current_model = self.model_path_pl
-
-    def process(self, input_path: str, task_options: dict, output_path: str) -> None:
-        language = task_options.get("language", "pl")
-
-        if language == "en":
-            bpe = True
-        else:
-            bpe = False
-
-        tool, model, tokenizer, mapping = self._get_setup_for_language(language)
-
-        with open(input_path, "r") as f:
-            text = f.read()
-
-        # Make sure that the text is lowercase & punctuationless
-        text = _preprocess_input(text)
-
-        tokenized = tokenizer(text, return_tensors="pt")
-
-        num_tokens = len(tokenized["input_ids"][0])
-
-        # TODO: Consider adding batching support
-        results = []
-        for inference_mask, mask_mask in zip(
-            *inference_masks(num_tokens, self.max_context_size, self.overlap)
-        ):
-            result = model(
-                input_ids=tokenized["input_ids"][:, inference_mask].to(self.device),
-                attention_mask=tokenized["attention_mask"][:, inference_mask].to(
-                    self.device
-                ),
-            )
-            labels_ids = (
-                result.logits.detach().cpu().argmax(dim=-1).squeeze().numpy()[mask_mask]
-            )
-            results.append(decode_labels(labels_ids, mapping))
-        labels = sum(results, [])
-
-        tokens = []
-        for combine_mask in combine_masks(
-            num_tokens, self.max_context_size, self.overlap
-        ):
-            tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
-
-        text_out = decode(tokens, labels, tokenizer, bpe)
-        text_out = _post_process(text_out, tool)
-        with open(output_path, "w") as f:
-            f.write(text_out)
-
-    def _initialize_model(self, lang, model_path: str, device: str):
-        tool = language_tool_python.LanguageTool(lang)
-        model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
-        tokenizer = AutoTokenizer.from_pretrained(model_path)
-        mapping = {}
-        with open(f"{self.model_path_pl}/classes.json", "r") as f:
-            mapping = json.load(f)
-        mapping = list(mapping.keys())
-        return tool, model, tokenizer, mapping
-
-    def _get_setup_for_language(self, language):
-        if language == "ru":
-            return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
-        elif language == "en":
-            return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
-        else:
-            return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
-
-    def _pass_device(self, new_language):
-        _, current_model, _, _ = self._get_setup_for_language(self.current_model)
-        current_model.to("cpu")
-        _, current_model, _, _ = self._get_setup_for_language(new_language)
-        current_model.to(self.device)
diff --git a/punctuator/utils.py b/punctuator/utils.py
deleted file mode 100644
index 3d39316ec0eb044d3457fb9d4e97afc078ac9ee6..0000000000000000000000000000000000000000
--- a/punctuator/utils.py
+++ /dev/null
@@ -1,175 +0,0 @@
-from typing import List, Tuple
-import numpy as np
-
-
-def decode_labels(results, labels_map) -> List[str]:
-    """Converts labes from ids to text representations
-
-    Args:
-        results (List[int]): List of ids of labels
-        labels_map (List[str]): List of classnames in order matching list of
-            ids
-
-    Returns:
-        List[str]: List of classnames
-    """
-    labels_decoded = list(map(lambda x: labels_map[x], results))
-
-    return labels_decoded
-
-
-def decode(tokens, labels_decoded, tokenizer, bpe=False):
-    """Applies predictions to text in order to get punctuated text representation
-
-    Args:
-        tokens (List[int]): List of token-ids
-        labels_decoded (List[str]): Per-token classnames
-        tokenizer: Huggingface tokenizer
-
-    Returns:
-        str: Text with punctuation & casing applied
-    """
-    text_recovered = []
-    word = []
-    word_end = ""
-    for label, token in zip(labels_decoded, tokens):
-        if bpe:
-            token_str = tokenizer.decode(token)
-            if token_str.startswith(" "):
-                token_str = token_str[1:]
-        else:
-            token_str = tokenizer.convert_ids_to_tokens([token])[0]
-        if token_str == "[PAD]":
-            break
-        if token_str.startswith("##"):
-            word.append(token_str.replace("##", ""))
-        else:
-            if len(word) > 0:
-                word.append(word_end)
-                text_recovered.append("".join(word))
-                word = []
-            if label.startswith("__ALL_UPPER__"):
-                # TODO: Make all uppercase
-                word.append(token_str[0].upper() + token_str[1:])
-            elif label.startswith("__UPPER__"):
-                word.append(token_str[0].upper() + token_str[1:])
-            else:
-                word.append(token_str)
-
-            label = label.replace("__UPPER__", "")
-            label = label.replace("__ALL_UPPER__", "")
-            word_end = label
-    text_recovered.append("".join(word))
-    if word_end != "":
-        text_recovered += word_end
-    return "".join(text_recovered)
-
-
-def inference_masks(
-    num_tokens: int, max_len: int, overlap: int
-) -> Tuple[List[List[bool]], List[List[bool]]]:
-    """Splits text that is to long for predicting. The function provide list
-    of masks for each prediction chunk
-
-    Args:
-        num_tokens (int): Number of tokens, including CLS & SEP
-        max_len (int): Prediction window (must be less than 512)
-        overlap (int): Ammout of overlapping between chunking windows
-
-    Returns:
-        Tuple[List[List[bool]], List[List[bool]]]: Masks for tokens provided
-            for inference & for result of inference
-    """
-    if max_len >= num_tokens:
-        return (
-            [[True] * num_tokens],
-            [[False] + [True] * (num_tokens - 2) + [False]],
-        )
-
-    # Account for CLS & SEP tokens
-    real_max_len = max_len - 2
-    real_num_tokens = num_tokens - 2
-
-    step_size = real_max_len - 2 * overlap
-
-    masks = []
-    entries = []
-    for start_id in range(0, real_num_tokens, step_size):
-        stop = False
-        if start_id == 0:
-            entry = (
-                [True]
-                + [True] * real_max_len
-                + [False] * (real_num_tokens - real_max_len)
-                + [True]
-            )
-            mask = [False] + [True] * (real_max_len - overlap) + [False] * (overlap + 1)
-        elif start_id + real_max_len >= real_num_tokens:
-            offset_start = real_num_tokens - real_max_len
-            entry = [True] + [False] * (offset_start) + [True] * real_max_len + [True]
-            mask = (
-                [False] * (overlap + 1 + (start_id - offset_start))
-                + [True] * (real_max_len - overlap - (start_id - offset_start))
-                + [False]
-            )
-            stop = True
-        else:
-            entry = (
-                [True]
-                + [False] * start_id
-                + [True] * real_max_len
-                + [False] * (real_num_tokens - (start_id + real_max_len))
-                + [True]
-            )
-            mask = (
-                [False] * (overlap + 1)
-                + [True] * (real_max_len - 2 * overlap)
-                + [False] * (overlap + 1)
-            )
-
-        masks.append(mask)
-        entries.append(entry)
-
-        if stop:
-            break
-
-    return entries, masks
-
-
-def combine_masks(num_tokens: int, max_len: int, overlap: int) -> List[List[bool]]:
-    """Provides mask which tokens to take for each prediction. It makes sure
-    that each token is only taken once & scored by best chunk.
-
-    Args:
-        num_tokens (int): Number of tokens, including CLS & SEP
-        max_len (int): Prediction window (must be less than 512)
-        overlap (int): Ammout of overlapping between chunking windows
-
-    Returns:
-        List[List[bool]]: Token mask
-    """
-    if max_len >= num_tokens:
-        return np.array([[False] + [True] * (num_tokens - 2) + [False]])
-
-    step_size = max_len - 2 - overlap
-
-    entries = []
-    for start in range(0, num_tokens - 2, step_size):
-        stop = False
-
-        if start + max_len - 2 - overlap < num_tokens - 2:
-            entry = [False] + [False] * (start) + [True] * (max_len - 2 - overlap)
-            entry += [False] * (num_tokens - 2 - (start + max_len - 2 - overlap))
-            entry += [False]
-        else:
-            entry = [False] + [False] * (start)
-            entry += [True] * (num_tokens - 2 - start)
-            entry += [False]
-            stop = True
-
-        entries.append(entry)
-
-        if stop:
-            break
-
-    return entries
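The deleted helpers above implement overlapped sliding-window inference: the text is split into windows of at most `max_len` tokens (minus CLS/SEP), neighbouring windows overlap, and only the interior of each window is kept so every token is scored by the chunk where it sits farthest from a boundary. The external `punctuator` package pinned below is presumably expected to handle this internally now. As a hedged illustration of the windowing arithmetic only (not the deleted code itself), here is a minimal sketch whose output matches the three windows the deleted test expected for `inference_masks(11, 8, 2)`:

```python
# Sketch of the window-start arithmetic behind the deleted inference_masks.
def window_starts(num_tokens: int, max_len: int, overlap: int):
    """Start offsets of the inner-token windows (CLS/SEP excluded)."""
    real_max_len = max_len - 2          # room left once CLS and SEP are reserved
    real_num_tokens = num_tokens - 2
    if real_max_len >= real_num_tokens:
        return [0]                      # everything fits in a single window
    step = real_max_len - 2 * overlap   # advance so neighbouring windows overlap
    starts = []
    for start in range(0, real_num_tokens, step):
        # Clamp the last window so it ends exactly at the final token.
        starts.append(min(start, real_num_tokens - real_max_len))
        if start + real_max_len >= real_num_tokens:
            break
    return starts

print(window_starts(11, 8, 2))  # [0, 2, 3] -> three windows, as in the deleted test
```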
diff --git a/requirements.txt b/requirements.txt
index 3e07ff76b67026b919c81881430cceb0eca1af3f..934f534810c1abe32eab84eff3d465ad20dad1ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,7 @@ numpy==1.19.4
 transformers==4.3.2
 torch==1.7.1
 language-tool-python==2.5.4
-awscli==1.20.11
\ No newline at end of file
+awscli==1.20.11
+--index-url https://pypi.clarin-pl.eu/simple/
+nlp_ws
+punctuator==2.0.3
\ No newline at end of file
diff --git a/punctuator/__init__.py b/src/__init__.py
similarity index 100%
rename from punctuator/__init__.py
rename to src/__init__.py
diff --git a/src/language_tool.py b/src/language_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..abb47a44ef6e4a6c977141b5f627b816ab6a756c
--- /dev/null
+++ b/src/language_tool.py
@@ -0,0 +1,28 @@
+import os
+import language_tool_python
+
+
+class LanguageToolFixer:
+    def __init__(self, lt_cache: str):
+        self.lt_cache = lt_cache
+        os.environ["LTP_PATH"] = lt_cache
+
+        self.tools = {
+            "pl": language_tool_python.LanguageTool("pl-PL"),
+            "en": language_tool_python.LanguageTool("en-US"),
+            "ru": language_tool_python.LanguageTool("ru-RU"),
+        }
+
+    def _post_process(self, text: str, tool):
+        matches = tool.check(text)
+        matches = [rule for rule in matches if not self._is_punctuation_rule(rule)]
+        return language_tool_python.utils.correct(text, matches)
+
+    def _is_punctuation_rule(self, rule):
+        return rule.category != "PUNCTUATION" and len(rule.replacements) > 0
+
+    def fix_punctuation(self, text: str, lang: str) -> str:
+        if lang in self.tools:
+            return self._post_process(text, self.tools[lang])
+        else:
+            return text
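A minimal usage sketch for the new `LanguageToolFixer` (not part of the patch; it assumes Java is installed and a LanguageTool distribution is cached under the path exported as `LTP_PATH`, as the Dockerfile sets up). Because the filter keeps essentially only punctuation-category matches, other grammar suggestions are not applied, and languages without a configured tool pass through unchanged:

```python
# Sketch: applying only punctuation corrections via LanguageToolFixer.
from src.language_tool import LanguageToolFixer

fixer = LanguageToolFixer("/home/worker/languagetool")

print(fixer.fix_punctuation("ala ma kota kot ma ale", "pl"))  # punctuation fixes only
print(fixer.fix_punctuation("unchanged text", "de"))          # no "de" tool -> as-is
```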
diff --git a/src/punctuator_worker.py b/src/punctuator_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c69777587b5cdff35268ccd0befdd415271ef3e
--- /dev/null
+++ b/src/punctuator_worker.py
@@ -0,0 +1,62 @@
+"""Implementation of punctuator service"""
+
+from glob import glob
+from pathlib import Path
+from typing import Optional
+
+from punctuator import Punctuator
+
+from src.language_tool import LanguageToolFixer
+import logging
+
+
+class PunctuatorWorker:
+    DEFAULT_MODEL = "pl"
+
+    def __init__(
+        self,
+        models_location: str,
+        languagetool_location: Optional[str],
+        max_context_size: int = 256,
+        overlap: int = 20,
+        device: str = "cpu",
+    ):
+
+        logging.info("Loading models...")
+        self.models = {
+            Path(language_model_dir).stem: Punctuator(
+                language_model_dir, max_context_size, overlap
+            )
+            for language_model_dir in glob(models_location + "/*")
+        }
+
+        if languagetool_location is not None:
+            self.lt = LanguageToolFixer(languagetool_location)
+        else:
+            self.lt = None
+
+        self.device = device
+        self.active_model = self.DEFAULT_MODEL
+
+    def process(self, input_path: str, task_options: dict, output_path: str) -> None:
+        language = task_options.get("language", self.DEFAULT_MODEL)
+
+        if self.active_model != language:
+            self._set_active_model(language)
+
+        with open(input_path, "r") as f:
+            text = f.read()
+
+        punctuated_text = self.models[self.active_model].punctuate(text)
+
+        if self.lt is not None:
+            punctuated_text = self.lt.fix_punctuation(punctuated_text, language)
+
+        with open(output_path, "w") as f:
+            f.write(punctuated_text)
+
+    def _set_active_model(self, model_language):
+        self.models[self.active_model].to("cpu")
+        self.models[model_language].to(self.device)
+
+        self.active_model = model_language
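A usage sketch for `PunctuatorWorker` (not part of the patch). It assumes the models directory contains one subdirectory per language code, e.g. `/home/worker/models/{pl,en,ru}`, each loadable by `punctuator.Punctuator`, plus a LanguageTool cache; the paths here are illustrative:

```python
# Sketch: constructing the worker and punctuating files in two languages.
from src.punctuator_worker import PunctuatorWorker

worker = PunctuatorWorker(
    models_location="/home/worker/models",
    languagetool_location="/home/worker/languagetool",
    max_context_size=256,
    overlap=20,
    device="cpu",
)

# Only one model sits on the target device at a time; switching language moves
# the previously active model back to the CPU (see _set_active_model above).
worker.process("input/pl.txt", {"language": "pl"}, "output/pl.txt")
worker.process("input/en.txt", {"language": "en"}, "output/en.txt")
```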
diff --git a/sync_to_s3.sh b/sync_to_s3.sh
deleted file mode 100755
index 666b79d1ec7b27581fefe576126f136ef4362cbe..0000000000000000000000000000000000000000
--- a/sync_to_s3.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-#!/bin/bash
-
-aws --endpoint-url "https://s3.clarin-pl.eu" s3 sync models/pl s3://workers/punctuator/models/pl
-aws --endpoint-url "https://s3.clarin-pl.eu" s3 sync models/en s3://workers/punctuator/models/en
-aws --endpoint-url "https://s3.clarin-pl.eu" s3 sync models/ru s3://workers/punctuator/models/ru
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/tests/test_chunking.py b/tests/test_chunking.py
deleted file mode 100644
index 751722304d022648a8c372da3c907ad2d770236b..0000000000000000000000000000000000000000
--- a/tests/test_chunking.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import numpy as np
-from punctuator.punctuator import combine_masks, inference_masks
-
-
-def test_inference_mask():
-    T = True
-    F = False
-
-    result, mask = inference_masks(11, 8, 2)
-
-    assert np.all(
-        result
-        == np.array(
-            [
-                [T, T, T, T, T, T, T, F, F, F, T],
-                [T, F, F, T, T, T, T, T, T, F, T],
-                [T, F, F, F, T, T, T, T, T, T, T],
-            ]
-        )
-    )
-    assert np.all(
-        mask
-        == np.array(
-            [
-                [F, T, T, T, T, F, F, F],
-                [F, F, F, T, T, F, F, F],
-                [F, F, F, F, T, T, T, F],
-            ]
-        )
-    )
-
-    result, mask = inference_masks(10, 8, 2)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [T, T, T, T, T, T, T, F, F, T],
-                [T, F, F, T, T, T, T, T, T, T],
-            ]
-        )
-    )
-    assert np.all(
-        mask
-        == np.array(
-            [
-                [F, T, T, T, T, F, F, F],
-                [F, F, F, T, T, T, T, F],
-            ]
-        )
-    )
-
-    result, mask = inference_masks(5, 8, 2)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [T, T, T, T, T],
-            ]
-        )
-    )
-    assert np.all(mask == np.array([[F, T, T, T, F]]))
-
-    result, mask = inference_masks(10, 9, 3)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [T, T, T, T, T, T, T, T, F, T],
-                [T, F, T, T, T, T, T, T, T, T],
-            ]
-        )
-    )
-    assert np.all(
-        mask == np.array([[F, T, T, T, T, F, F, F, F], [F, F, F, F, T, T, T, T, F]])
-    )
-
-
-def test_combine_mask():
-    T = True
-    F = False
-
-    result = combine_masks(11, 8, 2)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [F, T, T, T, T, F, F, F, F, F, F],
-                [F, F, F, F, F, T, T, T, T, F, F],
-                [F, F, F, F, F, F, F, F, F, T, F],
-            ]
-        )
-    )
-
-    result = combine_masks(10, 8, 2)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [F, T, T, T, T, F, F, F, F, F],
-                [F, F, F, F, F, T, T, T, T, F],
-            ]
-        )
-    )
-
-    result = combine_masks(5, 8, 2)
-    assert np.all(
-        result
-        == np.array(
-            [
-                [F, T, T, T, F],
-            ]
-        )
-    )
diff --git a/tox.ini b/tox.ini
index 5de4ec618b54b8d3d04492c1e2d7031974480430..29e0f9793f7cbc750d87774ceb06ff77173ecc0e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,16 +2,10 @@
 envlist = unittest,pep8
 skipsdist = True
 
-[testenv]
-deps = -rrequirements.txt
-    pytest >= 6.0.1
-
-[testenv:unittest]
-commands = pytest
-
 [flake8]
 exclude =
-    .tox,
+    venv,
+    .tox,
     .git,
     __pycache__,
     docs/source/conf.py,
diff --git a/worker.py b/worker.py
index 6bb8c0c63ed286e1c6734ff8d2186b4caee2b20c..33109baaee5e3a873efc7ce883a8984fe7880d97 100644
--- a/worker.py
+++ b/worker.py
@@ -5,7 +5,7 @@ import configparser
 
 import nlp_ws
 
-from punctuator.punctuator import Punctuator
+from src.punctuator_worker import PunctuatorWorker
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -14,28 +14,43 @@
         config.read("config.ini")
         config = config["deployment"]
 
-        self.punctuator = Punctuator(config)
+        models_cache_dir = config.get("models_cache_dir", "/home/worker/models")
+        languagetool_cache_dir = config.get(
+            "languagetool_cache_dir", "/home/worker/languagetool"
+        )
+        max_context_size = int(config.get("max_context_size", "256"))
+        overlap = int(config.get("overlap", "20"))
+        device = config.get("device", "cpu")
+
+        self.punctuator = PunctuatorWorker(
+            models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device
+        )
 
     def process(self, input_path: str, task_options: dict, output_path: str) -> None:
         self.punctuator.process(input_path, task_options, output_path)
 
 
-def perform_test():
+def perform_fast_test():
     config = configparser.ConfigParser()
     config.read("config.ini")
     config = config["deployment"]
 
-    punctuator = Punctuator(config)
-    punctuator.process(
-        "/samba/input/pl.txt", {"language": "pl"}, "/samba/output/pl.txt"
-    )
-    punctuator.process(
-        "/samba/input/en.txt", {"language": "en"}, "/samba/output/en.txt"
+    models_cache_dir = config.get("models_cache_dir", "/home/worker/models")
+    languagetool_cache_dir = config.get(
+        "languagetool_cache_dir", "/home/worker/languagetool"
    )
-    punctuator.process(
-        "/samba/input/ru.txt", {"language": "ru"}, "/samba/output/ru.txt"
+    max_context_size = int(config.get("max_context_size", "256"))
+    overlap = int(config.get("overlap", "20"))
+    device = config.get("device", "cpu")
+
+    punctuator = PunctuatorWorker(
+        models_cache_dir, languagetool_cache_dir, max_context_size, overlap, device
     )
+    punctuator.process("/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt")
+    punctuator.process("/test/input/en.txt", {"language": "en"}, "/test/output/en.txt")
+    punctuator.process("/test/input/ru.txt", {"language": "ru"}, "/test/output/ru.txt")
+
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -43,6 +58,6 @@
     args = parser.parse_args()
 
     if args.test:
-        perform_test()
+        perform_fast_test()
     else:
         nlp_ws.NLPService.main(Worker)
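Putting the pieces together, the new flow mirrors `perform_fast_test()` above: read the `[deployment]` section, construct `PunctuatorWorker`, and process files. A standalone sketch (not part of the patch; inside the container the paths come from the `-v` mounts shown in the README):

```python
# Sketch: end-to-end use of the new worker wiring outside nlp_ws.
import configparser

from src.punctuator_worker import PunctuatorWorker

config = configparser.ConfigParser()
config.read("config.ini")
deployment = config["deployment"]

punctuator = PunctuatorWorker(
    deployment.get("models_cache_dir", "/home/worker/models"),
    deployment.get("languagetool_cache_dir", "/home/worker/languagetool"),
    int(deployment.get("max_context_size", "256")),
    int(deployment.get("overlap", "20")),
    deployment.get("device", "cpu"),
)
punctuator.process("/test/input/pl.txt", {"language": "pl"}, "/test/output/pl.txt")
```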