From a6cfe4df5d179e86232039d2ab91caba645786fe Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Wed, 23 Jun 2021 14:18:12 +0200
Subject: [PATCH 01/14] languagetools support added

---
 Dockerfile.worker => Dockerfile |  4 ++--
 config.ini                      |  3 ++-
 docker-compose.yml              | 17 +++++++++++++++++
 worker.py                       | 10 ++++++++++
 4 files changed, 31 insertions(+), 3 deletions(-)
 rename Dockerfile.worker => Dockerfile (85%)
 create mode 100644 docker-compose.yml

diff --git a/Dockerfile.worker b/Dockerfile
similarity index 85%
rename from Dockerfile.worker
rename to Dockerfile
index 1046391..f2ec8a5 100644
--- a/Dockerfile.worker
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM clarinpl/cuda-python:3.7
+FROM clarinpl/cuda-python:3.7 AS base
 
 RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
 
@@ -17,4 +17,4 @@ COPY entrypoint.sh entrypoint.sh
 COPY worker.py worker.py
 COPY config.ini config.ini
 
-ENTRYPOINT ["bash", "entrypoint.sh"]
\ No newline at end of file
+ENTRYPOINT ["bash", "entrypoint.sh"]
diff --git a/config.ini b/config.ini
index de392b1..d71e5ba 100644
--- a/config.ini
+++ b/config.ini
@@ -16,4 +16,5 @@ local_log_level = INFO
 model_path = /home/worker/model/punctuator
 max_context_size = 256
 overlap = 20
-device = cpu
\ No newline at end of file
+device = cpu
+languagetool_port = 8010
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..732706c
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,17 @@
+version: "3"
+
+services:
+  languagetool:
+    image: erikvl87/languagetool
+    container_name: languagetool
+    ports:
+      - 8010:8010 # Using default port from the image
+    environment:
+      - langtool_languageModel=/ngrams # OPTIONAL: Using ngrams data
+      - Java_Xms=512m # OPTIONAL: Setting a minimal Java heap size of 512 MiB
+      - Java_Xmx=1g # OPTIONAL: Setting a maximum Java heap size of 1 GiB
+    volumes:
+      - /path/to/ngrams/data:/ngrams
+  punctuator:
+    build: .
+    container_name: punctuator
diff --git a/worker.py b/worker.py
index 7f2d6e7..2af9c4b 100644
--- a/worker.py
+++ b/worker.py
@@ -3,6 +3,7 @@ import configparser
 import json
 import string
 
+import requests
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
 
@@ -21,6 +22,13 @@ def _preprocess_input(text: str):
     return text
 
 
+def _post_process(text: str, url: str):
+    resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
+    for match in resp.json()['matches']:
+        if match['rule']['category']['id'] == 'PUNCTUATION':
+            if len(match['replacements']) > 0:
+                text = text.replace(text[match['offset']:match['offset']+match['length']], match['replacements'][0]['value'])
+    return text
 
 class Worker(nlp_ws.NLPWorker):
     def init(self):
@@ -32,6 +40,7 @@ class Worker(nlp_ws.NLPWorker):
 
         self.overlap = int(self.config["overlap"])
         self.device = self.config["device"]
+        self.languagetool_url = "http://languagetool:{}/v2/check".format(self.config["languagetool_port"])
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
@@ -86,6 +95,7 @@ class Worker(nlp_ws.NLPWorker):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
         text_out = decode(tokens, labels, self.tokenizer)
+        text_out = _post_process(text_out, self.languagetool_url)
 
         with open(output_path, "w") as f:
             f.write(text_out)
-- 
GitLab
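For reference, the LanguageTool /v2/check endpoint queried by _post_process above returns JSON with a list of matches, each carrying an offset, a length, suggested replacements, and a rule category. A minimal standalone sketch of that exchange (the localhost URL assumes the languagetool service from the compose file is reachable on its published port 8010; the right-to-left offset application is an alternative shown for illustration, since str.replace, as used in the patch, can also hit an identical substring elsewhere in the text):

    import requests

    def apply_punctuation_fixes(text, url="http://localhost:8010/v2/check"):
        resp = requests.get(url, params={"language": "pl-PL", "text": text})
        # Each match looks roughly like:
        # {"offset": 12, "length": 1, "replacements": [{"value": ","}],
        #  "rule": {"category": {"id": "PUNCTUATION", ...}, ...}}
        punct = [m for m in resp.json()["matches"]
                 if m["rule"]["category"]["id"] == "PUNCTUATION" and m["replacements"]]
        # Apply fixes from the end of the string towards the start, so that
        # earlier offsets stay valid after each splice.
        for m in sorted(punct, key=lambda m: m["offset"], reverse=True):
            start, end = m["offset"], m["offset"] + m["length"]
            text = text[:start] + m["replacements"][0]["value"] + text[end:]
        return text
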
From 6c7d5b930f3141cec2874e4a00c000c6426b88e2 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:25:44 +0200
Subject: [PATCH 02/14] style fix

---
 worker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 2af9c4b..d81f83d 100644
--- a/worker.py
+++ b/worker.py
@@ -27,7 +27,8 @@ def _post_process(text: str, url: str):
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset']+match['length']], match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset']+match['length']],
+                                    match['replacements'][0]['value'])
     return text
 
 class Worker(nlp_ws.NLPWorker):
@@ -40,7 +41,8 @@ class Worker(nlp_ws.NLPWorker):
 
         self.overlap = int(self.config["overlap"])
         self.device = self.config["device"]
-        self.languagetool_url = "http://languagetool:{}/v2/check".format(self.config["languagetool_port"])
+        self.languagetool_url = "http://languagetool:{}/v2/check".format(
+            self.config["languagetool_port"])
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab

From b095d168d14859f39b3ef806121b4b7f8ee17c34 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:33:28 +0200
Subject: [PATCH 03/14] Further work on style

---
 worker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index d81f83d..06e51f4 100644
--- a/worker.py
+++ b/worker.py
@@ -22,15 +22,17 @@ def _preprocess_input(text: str):
     return text
 
 
+
 def _post_process(text: str, url: str):
     resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset']+match['length']],
-                                    match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset'] +
+                                    match['length']],match['replacements'][0]['value'])
     return text
 
+
 class Worker(nlp_ws.NLPWorker):
     def init(self):
-- 
GitLab
From 2812d88b4b077dee0bc5451613e990835883dee9 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:38:22 +0200
Subject: [PATCH 04/14] More work on style

---
 worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 06e51f4..8228cab 100644
--- a/worker.py
+++ b/worker.py
@@ -28,8 +28,8 @@ def _post_process(text: str, url: str):
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset'] +
-                                    match['length']],match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset'] +
+                                         match['length']],match['replacements'][0]['value'])
     return text
 
-- 
GitLab

From f72fc1df50f5a7baf17bc212e757338ae421c3b6 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 11:42:35 +0200
Subject: [PATCH 05/14] Additional container removed

---
 Dockerfile => Dockerfile.worker | 16 ++++++++++++++++
 docker-compose.yml              | 17 -----------------
 requirements.txt                |  3 ++-
 worker.py                       | 24 +++++++++++++-----------
 4 files changed, 31 insertions(+), 29 deletions(-)
 rename Dockerfile => Dockerfile.worker (55%)
 delete mode 100644 docker-compose.yml

diff --git a/Dockerfile b/Dockerfile.worker
similarity index 55%
rename from Dockerfile
rename to Dockerfile.worker
index f2ec8a5..54cbf34 100644
--- a/Dockerfile
+++ b/Dockerfile.worker
@@ -10,6 +10,22 @@ WORKDIR /workspace
 
 RUN pip3 install --index-url https://pypi.clarin-pl.eu/simple/ nlp_ws==0.6
 
+# Install OpenJDK-8
+RUN apt-get update && \
+    apt-get install -y openjdk-8-jdk && \
+    apt-get install -y ant && \
+    apt-get clean;
+
+# Fix certificate issues
+RUN apt-get update && \
+    apt-get install ca-certificates-java && \
+    apt-get clean && \
+    update-ca-certificates -f;
+
+# Setup JAVA_HOME -- useful for docker commandline
+ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
+RUN export JAVA_HOME
+
 WORKDIR /home/worker
 
 COPY punctuator punctuator
diff --git a/docker-compose.yml b/docker-compose.yml
deleted file mode 100644
index 732706c..0000000
--- a/docker-compose.yml
+++ /dev/null
@@ -1,17 +0,0 @@
-version: "3"
-
-services:
-  languagetool:
-    image: erikvl87/languagetool
-    container_name: languagetool
-    ports:
-      - 8010:8010 # Using default port from the image
-    environment:
-      - langtool_languageModel=/ngrams # OPTIONAL: Using ngrams data
-      - Java_Xms=512m # OPTIONAL: Setting a minimal Java heap size of 512 MiB
-      - Java_Xmx=1g # OPTIONAL: Setting a maximum Java heap size of 1 GiB
-    volumes:
-      - /path/to/ngrams/data:/ngrams
-  punctuator:
-    build: .
-    container_name: punctuator
diff --git a/requirements.txt b/requirements.txt
index 1de9709..9df4a0c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy==1.19.4
 transformers==4.3.2
-torch==1.7.1
\ No newline at end of file
+torch==1.7.1
+language-tool-python==2.5.4
\ No newline at end of file
diff --git a/worker.py b/worker.py
index 8228cab..2fb52fd 100644
--- a/worker.py
+++ b/worker.py
@@ -7,6 +7,8 @@ import requests
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
 
+import language_tool_python
+
 from punctuator.punctuator import (
     combine_masks,
@@ -23,14 +25,11 @@ def _preprocess_input(text: str):
     return text
 
 
-def _post_process(text: str, url: str):
-    resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
-    for match in resp.json()['matches']:
-        if match['rule']['category']['id'] == 'PUNCTUATION':
-            if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset'] +
-                                         match['length']],match['replacements'][0]['value'])
-    return text
+def _post_process(text: str, tool):
+    is_punctuation_rule = lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements)
+    matches = tool.check(text)
+    matches = [rule for rule in matches if not is_punctuation_rule(rule)]
+    return language_tool_python.utils.correct(text, matches)
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -43,8 +42,11 @@ class Worker(nlp_ws.NLPWorker):
 
         self.overlap = int(self.config["overlap"])
         self.device = self.config["device"]
-        self.languagetool_url = "http://languagetool:{}/v2/check".format(
-            self.config["languagetool_port"])
+        self.tool = language_tool_python.LanguageTool('pl-PL')
+
+        #
+        print(_post_process('Ile dałbym osiem dziewięc korzyk, dwa razy, kamera, dwa', self.tool))
+        #
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab
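The switch above embeds LanguageTool in the worker itself through language-tool-python, which spawns a local Java server (hence the OpenJDK 8 install in Dockerfile.worker). A minimal sketch of the check, filter, correct flow that the new _post_process relies on (the sample sentence is arbitrary):

    import language_tool_python

    tool = language_tool_python.LanguageTool('pl-PL')  # starts the bundled LT server
    text = 'ile to kosztuje'
    matches = tool.check(text)                         # list of Match objects
    matches = [m for m in matches if m.replacements]   # keep only actionable matches
    fixed = language_tool_python.utils.correct(text, matches)
    tool.close()  # shut the background Java process down
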
From 93a97127212c011a9750c16c470cdbc16574dd25 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 24 Jun 2021 14:53:32 +0200
Subject: [PATCH 06/14] Added LT download path configuration

---
 config.ini | 4 ++--
 worker.py  | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/config.ini b/config.ini
index d71e5ba..c6cff3c 100644
--- a/config.ini
+++ b/config.ini
@@ -14,7 +14,7 @@ local_log_level = INFO
 
 [deployment]
 model_path = /home/worker/model/punctuator
+languagetool_path = /home/worker/model/languagetool
 max_context_size = 256
 overlap = 20
-device = cpu
-languagetool_port = 8010
\ No newline at end of file
+device = cpu
\ No newline at end of file
diff --git a/worker.py b/worker.py
index 2fb52fd..42ff957 100644
--- a/worker.py
+++ b/worker.py
@@ -4,6 +4,7 @@ import configparser
 import json
 import string
 import requests
+import os
 
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
@@ -42,6 +43,9 @@ class Worker(nlp_ws.NLPWorker):
 
         self.overlap = int(self.config["overlap"])
         self.device = self.config["device"]
+
+        self.languagetool_path = self.config["languagetool_path"]
+        os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
 
-- 
GitLab
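The LTP_PATH variable set above is how language-tool-python decides where to download and look up the LanguageTool distribution; pointing it at a directory baked into the image avoids a download at container start. A sketch of the same wiring outside the worker (the path is the value from config.ini):

    import os

    # Set before the tool is instantiated; language_tool_python reads
    # LTP_PATH when it resolves its LanguageTool installation.
    os.environ["LTP_PATH"] = "/home/worker/model/languagetool"

    import language_tool_python
    tool = language_tool_python.LanguageTool('pl-PL')
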
From 6154e873a25d79904452cf198b06e51d023637bb Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 24 Jun 2021 14:55:06 +0200
Subject: [PATCH 07/14] Updated readme

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 480c09c..8c04942 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ A service that automatically adds punctuation to raw word-stream (eg. from speec
 [deployment]
 device = cpu ; Device on which inference will be made (e.g. cpu, cuda:0, etc.)
 model_path = /model/punctuator ; Path where the model will be placed
+languagetool_path = /model/languagetool ; Path where the languagetool server will be placed
 max_context_size = 256 ; Number of tokens that will be considered in prediction at once. Must be in range 2*overlap+1 to 512
 overlap = 20 ; The number of surrounding tokens that will be included at inference for each text fragment
 ```
-- 
GitLab

From f75b15f1a1499b80fbf5cf44cead1f7206c3804c Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 15:06:50 +0200
Subject: [PATCH 08/14] Final version

---
 worker.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/worker.py b/worker.py
index 42ff957..45147ea 100644
--- a/worker.py
+++ b/worker.py
@@ -3,7 +3,6 @@
 import configparser
 import json
 import string
-import requests
 import os
 
 import nlp_ws
@@ -26,8 +25,11 @@ def _preprocess_input(text: str):
     return text
 
 
+def is_punctuation_rule(rule):
+    return rule.category != 'PUNCTUATION' and len(rule.replacements)
+
+
 def _post_process(text: str, tool):
-    is_punctuation_rule = lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements)
     matches = tool.check(text)
     matches = [rule for rule in matches if not is_punctuation_rule(rule)]
     return language_tool_python.utils.correct(text, matches)
@@ -47,11 +49,7 @@ class Worker(nlp_ws.NLPWorker):
 
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
-
-        #
-        print(_post_process('Ile dałbym osiem dziewięc korzyk, dwa razy, kamera, dwa', self.tool))
-        #
-
+
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
             model_path
-- 
GitLab

From e637912bc3f809b0f3b4d5c78773205d059e5acd Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 15:07:17 +0200
Subject: [PATCH 09/14] Final fix

---
 worker.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/worker.py b/worker.py
index 45147ea..f72481f 100644
--- a/worker.py
+++ b/worker.py
@@ -49,7 +49,6 @@ class Worker(nlp_ws.NLPWorker):
 
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
-
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
             model_path
-- 
GitLab
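The max_context_size and overlap options documented in the README above describe a sliding-window scheme: the token stream is cut into windows of at most max_context_size tokens, neighboring windows share overlap tokens on each side, and predictions from a shared region are kept from only one window. A rough illustration of the window arithmetic under those assumptions (this is not the punctuator.inference_masks implementation):

    def windows(num_tokens, context=256, overlap=20):
        # Consecutive windows share 2*overlap tokens, so the effective stride
        # is context - 2*overlap (hence the README's requirement that
        # max_context_size be at least 2*overlap + 1).
        stride = context - 2 * overlap
        starts = range(0, max(num_tokens - 2 * overlap, 1), stride)
        return [(s, min(s + context, num_tokens)) for s in starts]

    print(windows(600))  # [(0, 256), (216, 472), (432, 600)]
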
From 1ee3ad667a1482ac001ff9e0c82d7f4007ea3875 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Fri, 9 Jul 2021 10:36:38 +0200
Subject: [PATCH 10/14] worker redone, new entries in config

---
 config.ini               |  4 +++-
 punctuator/punctuator.py | 18 +++++++++---------
 worker.py                | 39 +++++++++++++++++++++++++++------------
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/config.ini b/config.ini
index c6cff3c..f0af353 100644
--- a/config.ini
+++ b/config.ini
@@ -13,7 +13,9 @@ port = 9981
 local_log_level = INFO
 
 [deployment]
-model_path = /home/worker/model/punctuator
+model_path_pl = /home/worker/model/punctuator
+model_path_en = /home/worker/model/punctuator_en
+model_path_ru = /home/worker/model/punctuator_ru
 languagetool_path = /home/worker/model/languagetool
 max_context_size = 256
 overlap = 20
diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py
index 45f8efa..02ee56b 100644
--- a/punctuator/punctuator.py
+++ b/punctuator/punctuator.py
@@ -17,7 +17,7 @@ def decode_labels(results, labels_map) -> List[str]:
     return labels_decoded
 
 
-def decode(tokens, labels_decoded, tokenizer):
+def decode(tokens, labels_decoded, tokenizer, bpe=False):
     """Applies predictions to text in order to get punctuated text representation
 
     Args:
@@ -31,21 +31,21 @@ def decode(tokens, labels_decoded, tokenizer):
     text_recovered = []
     word = []
     word_end = ""
-
     for label, token in zip(labels_decoded, tokens):
-        token_str = tokenizer.convert_ids_to_tokens([token])[0]
-
+        if bpe:
+            token_str = tokenizer.decode(token)
+        else:
+            token_str = tokenizer.convert_ids_to_tokens([token])[0]
         if token_str == "[PAD]":
             break
-
         if token_str.startswith("##"):
             word.append(token_str.replace("##", ""))
         else:
             if len(word) > 0:
-                word.append(word_end)
+                if not bpe or word_end != ' ':
+                    word.append(word_end)
                 text_recovered.append("".join(word))
                 word = []
-
             if label.startswith("__ALL_UPPER__"):
                 # TODO: Make all uppercase
                 word.append(token_str[0].upper() + token_str[1:])
@@ -57,9 +57,9 @@ def decode(tokens, labels_decoded, tokenizer):
             label = label.replace("__UPPER__", "")
             label = label.replace("__ALL_UPPER__", "")
             word_end = label
-
     text_recovered.append("".join(word))
-
+    if word_end != '':
+        text_recovered += word_end
     return "".join(text_recovered)
diff --git a/worker.py b/worker.py
index f72481f..e2706ca 100644
--- a/worker.py
+++ b/worker.py
@@ -40,7 +40,7 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
         self.config = self.config["deployment"]
-
+        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru': 'ru', 'model_path_en': 'en-US'}
         self.max_context_size = int(self.config["max_context_size"])
 
         self.overlap = int(self.config["overlap"])
@@ -48,20 +48,19 @@ class Worker(nlp_ws.NLPWorker):
 
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
-        self.tool = language_tool_python.LanguageTool('pl-PL')
-        model_path = self.config["model_path"]
-        self.model = AutoModelForTokenClassification.from_pretrained(
-            model_path
-        ).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model_path_pl = self.config["model_path_pl"]
+        self.model_path_ru = self.config["model_path_ru"]
+        self.model_path_en = self.config["model_path_en"]
+        self.initialize_model(self.model_path_pl)
 
-        with open(f"{model_path}/classes.json", "r") as f:
-            mapping = json.load(f)
-        self.mapping = list(mapping.keys())
-
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
+
+        if task_options['language'] != self.current_model:
+            self.initialize_model(task_options['language'])
+
         with open(input_path, "r") as f:
             text = f.read()
@@ -101,12 +100,28 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer)
+        text_out = decode(tokens, labels, self.tokenizer, self.current_model != self.model_path_pl)
         text_out = _post_process(text_out, self.tool)
-
+        if not text_out.endswith('.'):
+            text_out += '.'
         with open(output_path, "w") as f:
             f.write(text_out)
 
+    def initialize_model(self, model_path: str):
+        self.tool = language_tool_python.LanguageTool(self.languagetool_map[model_path])
+        if self.model:
+            self.model.to('cpu')
+        self.model = AutoModelForTokenClassification.from_pretrained(
+            model_path
+        ).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        with open(f"{model_path}/classes.json", "r") as f:
+            mapping = json.load(f)
+        self.mapping = list(mapping.keys())
+        self.current_model = model_path
+
+
 if __name__ == "__main__":
     nlp_ws.NLPService.main(Worker)
-- 
GitLab
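The new bpe flag in decode() above exists because the models do not tokenize alike: WordPiece vocabularies mark word-internal pieces with a "##" prefix, while BPE-style vocabularies fold the spacing into the token itself, so decoding token by token already yields readable fragments. A quick illustration with two stock Hugging Face tokenizers (the model names are illustrative, not the ones shipped with this service):

    from transformers import AutoTokenizer

    wp = AutoTokenizer.from_pretrained("bert-base-cased")  # WordPiece
    ids = wp("punctuator", add_special_tokens=False)["input_ids"]
    print(wp.convert_ids_to_tokens(ids))  # e.g. ['pun', '##ct', '##uator']

    bpe = AutoTokenizer.from_pretrained("roberta-base")    # byte-level BPE
    ids = bpe(" punctuator", add_special_tokens=False)["input_ids"]
    print([bpe.decode(i) for i in ids])   # e.g. [' punct', 'uator'], spacing survives
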
From 5955d8015432b7773adad92dd655a66aafba12f9 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 09:39:18 +0200
Subject: [PATCH 11/14] style fix

---
 punctuator/punctuator.py |  3 ++-
 worker.py                | 11 ++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py
index 02ee56b..1bea097 100644
--- a/punctuator/punctuator.py
+++ b/punctuator/punctuator.py
@@ -7,7 +7,8 @@ def decode_labels(results, labels_map) -> List[str]:
 
     Args:
         results (List[int]): List of ids of labels
-        labels_map (List[str]): List of classnames in order matching list of ids
+        labels_map (List[str]): List of classnames in order matching list of
+            ids
 
     Returns:
         List[str]: List of classnames
diff --git a/worker.py b/worker.py
index e2706ca..239824a 100644
--- a/worker.py
+++ b/worker.py
@@ -40,7 +40,8 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
         self.config = self.config["deployment"]
-        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru': 'ru', 'model_path_en': 'en-US'}
+        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru':
+                                 'ru', 'model_path_en': 'en-US'}
         self.max_context_size = int(self.config["max_context_size"])
 
         self.overlap = int(self.config["overlap"])
@@ -53,7 +54,6 @@ class Worker(nlp_ws.NLPWorker):
         self.model_path_en = self.config["model_path_en"]
         self.initialize_model(self.model_path_pl)
 
-
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
@@ -100,16 +100,17 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer, self.current_model != self.model_path_pl)
+        text_out = decode(tokens, labels, self.tokenizer,
+                          self.current_model != self.model_path_pl)
         text_out = _post_process(text_out, self.tool)
         if not text_out.endswith('.'):
             text_out += '.'
 
         with open(output_path, "w") as f:
             f.write(text_out)
 
     def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(self.languagetool_map[model_path])
+        self.tool = language_tool_python.LanguageTool(
+            self.languagetool_map[model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab
From 645b78826b31c7a250d6c7aefe36d9100aa3ddb1 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 09:53:12 +0200
Subject: [PATCH 12/14] Further work on style

---
 worker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/worker.py b/worker.py
index 239824a..718d524 100644
--- a/worker.py
+++ b/worker.py
@@ -110,7 +110,7 @@ class Worker(nlp_ws.NLPWorker):
 
     def initialize_model(self, model_path: str):
         self.tool = language_tool_python.LanguageTool(
-            self.languagetool_map[model_path])
+                self.languagetool_map[model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab

From 4dfbf7e10c8b6806f88d937b72fac675dd713676 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 10:00:55 +0200
Subject: [PATCH 13/14] style fix

---
 worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 718d524..ff4b96f 100644
--- a/worker.py
+++ b/worker.py
@@ -109,8 +109,8 @@ class Worker(nlp_ws.NLPWorker):
             f.write(text_out)
 
     def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(
-                self.languagetool_map[model_path])
+        self.tool = language_tool_python.LanguageTool(self.languagetool_map
+                                                      [model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab
From eef85e3160142d4b1512b8439de9840a3a1b4aeb Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Mon, 12 Jul 2021 09:52:58 +0200
Subject: [PATCH 14/14] Different approach to device sharing

---
 worker.py | 59 ++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/worker.py b/worker.py
index ff4b96f..4275942 100644
--- a/worker.py
+++ b/worker.py
@@ -52,14 +52,23 @@ class Worker(nlp_ws.NLPWorker):
 
         self.model_path_pl = self.config["model_path_pl"]
         self.model_path_ru = self.config["model_path_ru"]
         self.model_path_en = self.config["model_path_en"]
-        self.initialize_model(self.model_path_pl)
+        self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl \
+            = self.initialize_model(self.model_path_pl, self.device)
+        self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en \
+            = self.initialize_model(self.model_path_en, 'cpu')
+        self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru \
+            = self.initialize_model(self.model_path_ru, 'cpu')
+        self.current_model = self.model_path_pl
 
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
 
         if task_options['language'] != self.current_model:
-            self.initialize_model(task_options['language'])
+            self.pass_device(task_options['language'])
+            self.current_model = task_options['language']
+        tool, model, tokenizer, mapping = self.get_setup_for_language(
+            self.current_model)
 
         with open(input_path, "r") as f:
             text = f.read()
@@ -67,7 +76,7 @@ class Worker(nlp_ws.NLPWorker):
         # Make sure that the text is lowercase & punctuationless
         text = _preprocess_input(text)
 
-        tokenized = self.tokenizer(text, return_tensors="pt")
+        tokenized = tokenizer(text, return_tensors="pt")
 
         num_tokens = len(tokenized["input_ids"][0])
 
@@ -76,7 +85,7 @@ class Worker(nlp_ws.NLPWorker):
         for inference_mask, mask_mask in zip(
             *inference_masks(num_tokens, self.max_context_size, self.overlap)
         ):
-            result = self.model(
+            result = model(
                 input_ids=tokenized["input_ids"][:, inference_mask].to(
                     self.device
                 ),
@@ -91,7 +100,7 @@ class Worker(nlp_ws.NLPWorker):
                 .squeeze()
                 .numpy()[mask_mask]
             )
-            results.append(decode_labels(labels_ids, self.mapping))
+            results.append(decode_labels(labels_ids, mapping))
 
         labels = sum(results, [])
         tokens = []
@@ -100,28 +109,38 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer,
+        text_out = decode(tokens, labels, tokenizer,
                           self.current_model != self.model_path_pl)
-        text_out = _post_process(text_out, self.tool)
-        if not text_out.endswith('.'):
-            text_out += '.'
+        text_out = _post_process(text_out, tool)
 
         with open(output_path, "w") as f:
             f.write(text_out)
 
-    def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(self.languagetool_map
-                                                      [model_path])
-        if self.model:
-            self.model.to('cpu')
-        self.model = AutoModelForTokenClassification.from_pretrained(
+    def initialize_model(self, model_path: str, device: str):
+        tool = language_tool_python.LanguageTool(self.languagetool_map
+                                                 [model_path])
+        model = AutoModelForTokenClassification.from_pretrained(
             model_path
-        ).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        mapping = {}
         with open(f"{model_path}/classes.json", "r") as f:
             mapping = json.load(f)
-        self.mapping = list(mapping.keys())
-        self.current_model = model_path
+        mapping = list(mapping.keys())
+        return tool, model, tokenizer, mapping
+
+    def get_setup_for_language(self, language):
+        if language == 'model_path_ru':
+            return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
+        elif language == 'model_path_en':
+            return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
+        else:
+            return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
+
+    def pass_device(self, new_language):
+        _, current_model, _, _ = self.get_setup_for_language(self.current_model)
+        current_model.to('cpu')
+        _, current_model, _, _ = self.get_setup_for_language(new_language)
+        current_model.to(self.device)
 
 
 if __name__ == "__main__":
     nlp_ws.NLPService.main(Worker)
-- 
GitLab
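The last patch keeps all three models resident in CPU memory and shuttles only the one serving the current request onto the GPU. The same pattern in a condensed, self-contained sketch (the class and names are illustrative, not part of the worker):

    import torch

    class ModelPool:
        # Keep every model on the CPU; move only the active one to the GPU.
        def __init__(self, models, device="cuda:0"):
            self.models = {name: m.to("cpu") for name, m in models.items()}
            self.device = device
            self.active = None

        def get(self, name):
            if name != self.active:
                if self.active is not None:
                    self.models[self.active].to("cpu")
                    torch.cuda.empty_cache()  # release cached blocks of the evicted model
                self.models[name].to(self.device)
                self.active = name
            return self.models[name]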