From a6cfe4df5d179e86232039d2ab91caba645786fe Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Wed, 23 Jun 2021 14:18:12 +0200
Subject: [PATCH 01/14] languagetools support added

---
 Dockerfile.worker => Dockerfile |  4 ++--
 config.ini                      |  3 ++-
 docker-compose.yml              | 17 +++++++++++++++++
 worker.py                       | 10 ++++++++++
 4 files changed, 31 insertions(+), 3 deletions(-)
 rename Dockerfile.worker => Dockerfile (85%)
 create mode 100644 docker-compose.yml

diff --git a/Dockerfile.worker b/Dockerfile
similarity index 85%
rename from Dockerfile.worker
rename to Dockerfile
index 1046391..f2ec8a5 100644
--- a/Dockerfile.worker
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM clarinpl/cuda-python:3.7
+FROM clarinpl/cuda-python:3.7 AS base
 
 RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y gcc python3-dev
 
@@ -17,4 +17,4 @@ COPY entrypoint.sh entrypoint.sh
 COPY worker.py worker.py
 COPY config.ini config.ini
 
-ENTRYPOINT ["bash", "entrypoint.sh"]
\ No newline at end of file
+ENTRYPOINT ["bash", "entrypoint.sh"]
diff --git a/config.ini b/config.ini
index de392b1..d71e5ba 100644
--- a/config.ini
+++ b/config.ini
@@ -16,4 +16,5 @@ local_log_level = INFO
 model_path = /home/worker/model/punctuator
 max_context_size = 256
 overlap = 20
-device = cpu
\ No newline at end of file
+device = cpu
+languagetool_port = 8010
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..732706c
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,17 @@
+version: "3"
+
+services:
+  languagetool:
+    image: erikvl87/languagetool
+    container_name: languagetool
+    ports:
+        - 8010:8010  # Using default port from the image
+    environment:
+        - langtool_languageModel=/ngrams  # OPTIONAL: Using ngrams data
+        - Java_Xms=512m  # OPTIONAL: Setting a minimal Java heap size of 512 mib
+        - Java_Xmx=1g  # OPTIONAL: Setting a maximum Java heap size of 1 Gib
+    volumes:
+        - /path/to/ngrams/data:/ngrams
+  punctuator:
+    build: .
+    container_name: punctuator
diff --git a/worker.py b/worker.py
index 7f2d6e7..2af9c4b 100644
--- a/worker.py
+++ b/worker.py
@@ -3,6 +3,7 @@
 import configparser
 import json
 import string
+import requests
 
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
@@ -21,6 +22,13 @@ def _preprocess_input(text: str):
 
     return text
 
+def _post_process(text: str, url: str):
+    resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
+    for match in resp.json()['matches']:
+        if match['rule']['category']['id'] == 'PUNCTUATION':
+            if len(match['replacements']) > 0:
+                text = text.replace(text[match['offset']:match['offset']+match['length']], match['replacements'][0]['value'])
+    return text
 
 class Worker(nlp_ws.NLPWorker):
     def init(self):
@@ -32,6 +40,7 @@ class Worker(nlp_ws.NLPWorker):
         self.overlap = int(self.config["overlap"])
 
         self.device = self.config["device"]
+        self.languagetool_url = "http://languagetool:{}/v2/check".format(self.config["languagetool_port"])
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
@@ -86,6 +95,7 @@ class Worker(nlp_ws.NLPWorker):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
         text_out = decode(tokens, labels, self.tokenizer)
+        text_out = _post_process(text_out, self.languagetool_url)
 
         with open(output_path, "w") as f:
             f.write(text_out)
-- 
GitLab


From 6c7d5b930f3141cec2874e4a00c000c6426b88e2 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:25:44 +0200
Subject: [PATCH 02/14] style fix

---
 worker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 2af9c4b..d81f83d 100644
--- a/worker.py
+++ b/worker.py
@@ -27,7 +27,8 @@ def _post_process(text: str, url: str):
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset']+match['length']], match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset']+match['length']],
+                 match['replacements'][0]['value'])
     return text
 
 class Worker(nlp_ws.NLPWorker):
@@ -40,7 +41,8 @@ class Worker(nlp_ws.NLPWorker):
         self.overlap = int(self.config["overlap"])
 
         self.device = self.config["device"]
-        self.languagetool_url = "http://languagetool:{}/v2/check".format(self.config["languagetool_port"])
+        self.languagetool_url = "http://languagetool:{}/v2/check".format(
+            self.config["languagetool_port"])
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab


From b095d168d14859f39b3ef806121b4b7f8ee17c34 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:33:28 +0200
Subject: [PATCH 03/14] Further work on style

---
 worker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index d81f83d..06e51f4 100644
--- a/worker.py
+++ b/worker.py
@@ -22,15 +22,17 @@ def _preprocess_input(text: str):
 
     return text
 
+
 def _post_process(text: str, url: str):
     resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset']+match['length']],
-                 match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset'] +
+                                    match['length']],match['replacements'][0]['value'])
     return text
 
+
 class Worker(nlp_ws.NLPWorker):
     def init(self):
         self.config = configparser.ConfigParser()
-- 
GitLab


From 2812d88b4b077dee0bc5451613e990835883dee9 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 08:38:22 +0200
Subject: [PATCH 04/14] More work on style

---
 worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 06e51f4..8228cab 100644
--- a/worker.py
+++ b/worker.py
@@ -28,8 +28,8 @@ def _post_process(text: str, url: str):
     for match in resp.json()['matches']:
         if match['rule']['category']['id'] == 'PUNCTUATION':
             if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset'] +
-                                    match['length']],match['replacements'][0]['value'])
+                text = text.replace(text[match['offset']:match['offset'] 
+                                    + match['length']],match['replacements'][0]['value'])
     return text
 
 
-- 
GitLab


From f72fc1df50f5a7baf17bc212e757338ae421c3b6 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 11:42:35 +0200
Subject: [PATCH 05/14] Additional container removed

---
 Dockerfile => Dockerfile.worker | 16 ++++++++++++++++
 docker-compose.yml              | 17 -----------------
 requirements.txt                |  3 ++-
 worker.py                       | 24 +++++++++++++-----------
 4 files changed, 31 insertions(+), 29 deletions(-)
 rename Dockerfile => Dockerfile.worker (55%)
 delete mode 100644 docker-compose.yml

diff --git a/Dockerfile b/Dockerfile.worker
similarity index 55%
rename from Dockerfile
rename to Dockerfile.worker
index f2ec8a5..54cbf34 100644
--- a/Dockerfile
+++ b/Dockerfile.worker
@@ -10,6 +10,22 @@ WORKDIR /workspace
 
 RUN pip3 install --index-url https://pypi.clarin-pl.eu/simple/ nlp_ws==0.6
 
+# Install OpenJDK-8
+RUN apt-get update && \
+    apt-get install -y openjdk-8-jdk && \
+    apt-get install -y ant && \
+    apt-get clean;
+
+# Fix certificate issues
+RUN apt-get update && \
+    apt-get install ca-certificates-java && \
+    apt-get clean && \
+    update-ca-certificates -f;
+
+# Setup JAVA_HOME -- useful for docker commandline
+ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
+RUN export JAVA_HOME
+
 WORKDIR /home/worker
 
 COPY punctuator punctuator
diff --git a/docker-compose.yml b/docker-compose.yml
deleted file mode 100644
index 732706c..0000000
--- a/docker-compose.yml
+++ /dev/null
@@ -1,17 +0,0 @@
-version: "3"
-
-services:
-  languagetool:
-    image: erikvl87/languagetool
-    container_name: languagetool
-    ports:
-        - 8010:8010  # Using default port from the image
-    environment:
-        - langtool_languageModel=/ngrams  # OPTIONAL: Using ngrams data
-        - Java_Xms=512m  # OPTIONAL: Setting a minimal Java heap size of 512 mib
-        - Java_Xmx=1g  # OPTIONAL: Setting a maximum Java heap size of 1 Gib
-    volumes:
-        - /path/to/ngrams/data:/ngrams
-  punctuator:
-    build: .
-    container_name: punctuator
diff --git a/requirements.txt b/requirements.txt
index 1de9709..9df4a0c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy==1.19.4
 transformers==4.3.2
-torch==1.7.1
\ No newline at end of file
+torch==1.7.1
+language-tool-python==2.5.4
\ No newline at end of file
diff --git a/worker.py b/worker.py
index 8228cab..2fb52fd 100644
--- a/worker.py
+++ b/worker.py
@@ -7,6 +7,8 @@ import requests
 
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
+import language_tool_python
+
 
 from punctuator.punctuator import (
     combine_masks,
@@ -23,14 +25,11 @@ def _preprocess_input(text: str):
     return text
 
 
-def _post_process(text: str, url: str):
-    resp = requests.get(url, params={'language': 'pl-PL', 'text': text})
-    for match in resp.json()['matches']:
-        if match['rule']['category']['id'] == 'PUNCTUATION':
-            if len(match['replacements']) > 0:
-                text = text.replace(text[match['offset']:match['offset'] 
-                                    + match['length']],match['replacements'][0]['value'])
-    return text
+def _post_process(text: str, tool):
+    is_punctuation_rule = lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements) 
+    matches = tool.check(text)
+    matches = [rule for rule in matches if not is_punctuation_rule(rule)]
+    return language_tool_python.utils.correct(text, matches)
 
 
 class Worker(nlp_ws.NLPWorker):
@@ -43,8 +42,11 @@ class Worker(nlp_ws.NLPWorker):
         self.overlap = int(self.config["overlap"])
 
         self.device = self.config["device"]
-        self.languagetool_url = "http://languagetool:{}/v2/check".format(
-            self.config["languagetool_port"])
+        self.tool = language_tool_python.LanguageTool('pl-PL')
+
+        #
+        print(_post_process('Ile dałbym osiem dziewięc korzyk, dwa razy, kamera, dwa', self.tool))
+        #
 
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
@@ -99,7 +101,7 @@ class Worker(nlp_ws.NLPWorker):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
         text_out = decode(tokens, labels, self.tokenizer)
-        text_out = _post_process(text_out, self.languagetool_url)
+        text_out = _post_process(text_out, self.tool)
 
         with open(output_path, "w") as f:
             f.write(text_out)
-- 
GitLab


From 93a97127212c011a9750c16c470cdbc16574dd25 Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 24 Jun 2021 14:53:32 +0200
Subject: [PATCH 06/14] Added LT download path configuration

---
 config.ini | 4 ++--
 worker.py  | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/config.ini b/config.ini
index d71e5ba..c6cff3c 100644
--- a/config.ini
+++ b/config.ini
@@ -14,7 +14,7 @@ local_log_level = INFO
 
 [deployment]
 model_path = /home/worker/model/punctuator
+languagetool_path = /home/worker/model/languagetool
 max_context_size = 256
 overlap = 20
-device = cpu
-languagetool_port = 8010
\ No newline at end of file
+device = cpu
\ No newline at end of file
diff --git a/worker.py b/worker.py
index 2fb52fd..42ff957 100644
--- a/worker.py
+++ b/worker.py
@@ -4,6 +4,7 @@ import configparser
 import json
 import string
 import requests
+import os
 
 import nlp_ws
 from transformers import AutoModelForTokenClassification, AutoTokenizer
@@ -42,6 +43,9 @@ class Worker(nlp_ws.NLPWorker):
         self.overlap = int(self.config["overlap"])
 
         self.device = self.config["device"]
+
+        self.languagetool_path = self.config["languagetool_path"]
+        os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
 
         #
-- 
GitLab


From 6154e873a25d79904452cf198b06e51d023637bb Mon Sep 17 00:00:00 2001
From: Michal Pogoda <michalpogoda@hotmail.com>
Date: Thu, 24 Jun 2021 14:55:06 +0200
Subject: [PATCH 07/14] Updated readme

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 480c09c..8c04942 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ A service that automatically adds punctuation to raw word-stream (eg. from speec
 [deployment]
 device = cpu ; Device on which inference will be made (eg. cpu, cuda:0 etc)
 model_path = /model/punctuator ; Path where the model will be placed
+languagetool_path = /model/languagetool ; Path where languagetool server will be placed
 max_context_size = 256 ; Number of tokens that will be oonsidered in prediciton at once. Must be between in range 2*overlap+1 to 512
 overlap = 20 ; The number of tokens from the environment that will be taken at inference for a text fragment
 ```
-- 
GitLab


From f75b15f1a1499b80fbf5cf44cead1f7206c3804c Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 15:06:50 +0200
Subject: [PATCH 08/14] Final version

---
 worker.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/worker.py b/worker.py
index 42ff957..45147ea 100644
--- a/worker.py
+++ b/worker.py
@@ -3,7 +3,6 @@
 import configparser
 import json
 import string
-import requests
 import os
 
 import nlp_ws
@@ -26,8 +25,11 @@ def _preprocess_input(text: str):
     return text
 
 
+def is_punctuation_rule(rule):
+    lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements)
+
+
 def _post_process(text: str, tool):
-    is_punctuation_rule = lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements) 
     matches = tool.check(text)
     matches = [rule for rule in matches if not is_punctuation_rule(rule)]
     return language_tool_python.utils.correct(text, matches)
@@ -47,11 +49,7 @@ class Worker(nlp_ws.NLPWorker):
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
-
-        #
-        print(_post_process('Ile dałbym osiem dziewięc korzyk, dwa razy, kamera, dwa', self.tool))
-        #
-
+        
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
             model_path
-- 
GitLab


From e637912bc3f809b0f3b4d5c78773205d059e5acd Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Thu, 24 Jun 2021 15:07:17 +0200
Subject: [PATCH 09/14] Final fix

---
 worker.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/worker.py b/worker.py
index 45147ea..f72481f 100644
--- a/worker.py
+++ b/worker.py
@@ -49,7 +49,6 @@ class Worker(nlp_ws.NLPWorker):
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
         self.tool = language_tool_python.LanguageTool('pl-PL')
-        
         model_path = self.config["model_path"]
         self.model = AutoModelForTokenClassification.from_pretrained(
             model_path
-- 
GitLab


From 1ee3ad667a1482ac001ff9e0c82d7f4007ea3875 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Fri, 9 Jul 2021 10:36:38 +0200
Subject: [PATCH 10/14] worker redone, new entries in config

---
 config.ini               |  4 +++-
 punctuator/punctuator.py | 18 +++++++++---------
 worker.py                | 39 +++++++++++++++++++++++++++------------
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/config.ini b/config.ini
index c6cff3c..f0af353 100644
--- a/config.ini
+++ b/config.ini
@@ -13,7 +13,9 @@ port = 9981
 local_log_level = INFO
 
 [deployment]
-model_path = /home/worker/model/punctuator
+model_path_pl = /home/worker/model/punctuator
+model_path_en = /home/worker/model/punctuator_en
+model_path_ru = /home/worker/model/punctuator_ru
 languagetool_path = /home/worker/model/languagetool
 max_context_size = 256
 overlap = 20
diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py
index 45f8efa..02ee56b 100644
--- a/punctuator/punctuator.py
+++ b/punctuator/punctuator.py
@@ -17,7 +17,7 @@ def decode_labels(results, labels_map) -> List[str]:
     return labels_decoded
 
 
-def decode(tokens, labels_decoded, tokenizer):
+def decode(tokens, labels_decoded, tokenizer, bpe=False):
     """Applies predictions to text in order to get punctuated text representation
 
     Args:
@@ -31,21 +31,21 @@ def decode(tokens, labels_decoded, tokenizer):
     text_recovered = []
     word = []
     word_end = ""
-
     for label, token in zip(labels_decoded, tokens):
-        token_str = tokenizer.convert_ids_to_tokens([token])[0]
-
+        if bpe:
+            token_str = tokenizer.decode(token)
+        else:
+            token_str = tokenizer.convert_ids_to_tokens([token])[0]
         if token_str == "[PAD]":
             break
-
         if token_str.startswith("##"):
             word.append(token_str.replace("##", ""))
         else:
             if len(word) > 0:
-                word.append(word_end)
+                if not bpe or word_end != ' ':
+                    word.append(word_end)
                 text_recovered.append("".join(word))
                 word = []
-
             if label.startswith("__ALL_UPPER__"):
                 # TODO: Make all uppercase
                 word.append(token_str[0].upper() + token_str[1:])
@@ -57,9 +57,9 @@ def decode(tokens, labels_decoded, tokenizer):
             label = label.replace("__UPPER__", "")
             label = label.replace("__ALL_UPPER__", "")
             word_end = label
-
     text_recovered.append("".join(word))
-
+    if word_end != '':
+        text_recovered += word_end
     return "".join(text_recovered)
 
 
diff --git a/worker.py b/worker.py
index f72481f..e2706ca 100644
--- a/worker.py
+++ b/worker.py
@@ -40,7 +40,7 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
         self.config = self.config["deployment"]
-
+        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru':'ru', 'model_path_en':'en-US'}
         self.max_context_size = int(self.config["max_context_size"])
         self.overlap = int(self.config["overlap"])
 
@@ -48,20 +48,19 @@ class Worker(nlp_ws.NLPWorker):
 
         self.languagetool_path = self.config["languagetool_path"]
         os.environ["LTP_PATH"] = self.languagetool_path
-        self.tool = language_tool_python.LanguageTool('pl-PL')
-        model_path = self.config["model_path"]
-        self.model = AutoModelForTokenClassification.from_pretrained(
-            model_path
-        ).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model_path_pl = self.config["model_path_pl"]
+        self.model_path_ru = self.config["model_path_ru"]
+        self.model_path_en = self.config["model_path_en"]
+        self.initialize_model(self.model_path_pl)
 
-        with open(f"{model_path}/classes.json", "r") as f:
-            mapping = json.load(f)
-            self.mapping = list(mapping.keys())
 
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
+
+        if task_options['language'] != self.current_model:
+            self.initialize_model(task_options['language'])
+
         with open(input_path, "r") as f:
             text = f.read()
 
@@ -101,12 +100,28 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer)
+        text_out = decode(tokens, labels, self.tokenizer, self.current_model != self.model_path_pl)
         text_out = _post_process(text_out, self.tool)
-
+        if not text_out.endswith('.'):
+            text_out += '.'
         with open(output_path, "w") as f:
             f.write(text_out)
 
 
+    def initialize_model(self, model_path: str):
+        self.tool = language_tool_python.LanguageTool(self.languagetool_map[model_path])
+        if self.model:
+            self.model.to('cpu')
+        self.model = AutoModelForTokenClassification.from_pretrained(
+            model_path
+        ).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        with open(f"{model_path}/classes.json", "r") as f:
+            mapping = json.load(f)
+            self.mapping = list(mapping.keys())
+        self.current_model = model_path
+
+
 if __name__ == "__main__":
     nlp_ws.NLPService.main(Worker)
-- 
GitLab


From 5955d8015432b7773adad92dd655a66aafba12f9 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 09:39:18 +0200
Subject: [PATCH 11/14] style fix

---
 punctuator/punctuator.py |  3 ++-
 worker.py                | 11 ++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py
index 02ee56b..1bea097 100644
--- a/punctuator/punctuator.py
+++ b/punctuator/punctuator.py
@@ -7,7 +7,8 @@ def decode_labels(results, labels_map) -> List[str]:
 
     Args:
         results (List[int]): List of ids of labels
-        labels_map (List[str]): List of classnames in order matching list of ids
+        labels_map (List[str]): List of classnames in order matching list of
+        ids
 
     Returns:
         List[str]: List of classnames
diff --git a/worker.py b/worker.py
index e2706ca..239824a 100644
--- a/worker.py
+++ b/worker.py
@@ -40,7 +40,8 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
         self.config = self.config["deployment"]
-        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru':'ru', 'model_path_en':'en-US'}
+        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru':
+                                 'ru', 'model_path_en': 'en-US'}
         self.max_context_size = int(self.config["max_context_size"])
         self.overlap = int(self.config["overlap"])
 
@@ -53,7 +54,6 @@ class Worker(nlp_ws.NLPWorker):
         self.model_path_en = self.config["model_path_en"]
         self.initialize_model(self.model_path_pl)
 
-
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
@@ -100,16 +100,17 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer, self.current_model != self.model_path_pl)
+        text_out = decode(tokens, labels, self.tokenizer,
+                          self.current_model != self.model_path_pl)
         text_out = _post_process(text_out, self.tool)
         if not text_out.endswith('.'):
             text_out += '.'
         with open(output_path, "w") as f:
             f.write(text_out)
 
-
     def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(self.languagetool_map[model_path])
+        self.tool = language_tool_python.LanguageTool(
+                    self.languagetool_map[model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab


From 645b78826b31c7a250d6c7aefe36d9100aa3ddb1 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 09:53:12 +0200
Subject: [PATCH 12/14] Further work on style

---
 worker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/worker.py b/worker.py
index 239824a..718d524 100644
--- a/worker.py
+++ b/worker.py
@@ -110,7 +110,7 @@ class Worker(nlp_ws.NLPWorker):
 
     def initialize_model(self, model_path: str):
         self.tool = language_tool_python.LanguageTool(
-                    self.languagetool_map[model_path])
+                  self.languagetool_map[model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab


From 4dfbf7e10c8b6806f88d937b72fac675dd713676 Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Sat, 10 Jul 2021 10:00:55 +0200
Subject: [PATCH 13/14] style fix

---
 worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/worker.py b/worker.py
index 718d524..ff4b96f 100644
--- a/worker.py
+++ b/worker.py
@@ -109,8 +109,8 @@ class Worker(nlp_ws.NLPWorker):
             f.write(text_out)
 
     def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(
-                  self.languagetool_map[model_path])
+        self.tool = language_tool_python.LanguageTool(self.languagetool_map
+                                                      [model_path])
         if self.model:
             self.model.to('cpu')
         self.model = AutoModelForTokenClassification.from_pretrained(
-- 
GitLab


From eef85e3160142d4b1512b8439de9840a3a1b4aeb Mon Sep 17 00:00:00 2001
From: Jarema Radom <jaremaradom@gmail.com>
Date: Mon, 12 Jul 2021 09:52:58 +0200
Subject: [PATCH 14/14] Different approach to device sharing

---
 worker.py | 59 ++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/worker.py b/worker.py
index ff4b96f..4275942 100644
--- a/worker.py
+++ b/worker.py
@@ -52,14 +52,23 @@ class Worker(nlp_ws.NLPWorker):
         self.model_path_pl = self.config["model_path_pl"]
         self.model_path_ru = self.config["model_path_ru"]
         self.model_path_en = self.config["model_path_en"]
-        self.initialize_model(self.model_path_pl)
+        self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl \
+            = self.initialize_model(self.model_path_pl, self.device)
+        self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en \
+            = self.initialize_model(self.model_path_en, 'cpu')
+        self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru \
+            = self.initialize_model(self.model_path_ru, 'cpu')
+        self.current_model = self.model_path_pl
 
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
 
         if task_options['language'] != self.current_model:
-            self.initialize_model(task_options['language'])
+            self.pass_device(task_options['language'])
+            self.current_model = task_options['language']
+        tool, model, tokenizer, mapping = self.get_setup_for_language(
+            self.current_model)
 
         with open(input_path, "r") as f:
             text = f.read()
@@ -67,7 +76,7 @@ class Worker(nlp_ws.NLPWorker):
         # Make sure that the text is lowercase & punctuationless
         text = _preprocess_input(text)
 
-        tokenized = self.tokenizer(text, return_tensors="pt")
+        tokenized = tokenizer(text, return_tensors="pt")
 
         num_tokens = len(tokenized["input_ids"][0])
 
@@ -76,7 +85,7 @@ class Worker(nlp_ws.NLPWorker):
         for inference_mask, mask_mask in zip(
             *inference_masks(num_tokens, self.max_context_size, self.overlap)
         ):
-            result = self.model(
+            result = model(
                 input_ids=tokenized["input_ids"][:, inference_mask].to(
                     self.device
                 ),
@@ -91,7 +100,7 @@ class Worker(nlp_ws.NLPWorker):
                 .squeeze()
                 .numpy()[mask_mask]
             )
-            results.append(decode_labels(labels_ids, self.mapping))
+            results.append(decode_labels(labels_ids, mapping))
         labels = sum(results, [])
 
         tokens = []
@@ -100,28 +109,38 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer,
+        text_out = decode(tokens, labels, tokenizer,
                           self.current_model != self.model_path_pl)
-        text_out = _post_process(text_out, self.tool)
-        if not text_out.endswith('.'):
-            text_out += '.'
+        text_out = _post_process(text_out, tool)
         with open(output_path, "w") as f:
             f.write(text_out)
 
-    def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(self.languagetool_map
-                                                      [model_path])
-        if self.model:
-            self.model.to('cpu')
-        self.model = AutoModelForTokenClassification.from_pretrained(
+    def initialize_model(self, model_path: str, device: str):
+        tool = language_tool_python.LanguageTool(self.languagetool_map
+                                                 [model_path])
+        model = AutoModelForTokenClassification.from_pretrained(
             model_path
-        ).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        mapping = {}
         with open(f"{model_path}/classes.json", "r") as f:
             mapping = json.load(f)
-            self.mapping = list(mapping.keys())
-        self.current_model = model_path
+            mapping = list(mapping.keys())
+        return tool, model, tokenizer, mapping
+
+    def get_setup_for_language(self, language):
+        if language == 'model_path_ru':
+            return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
+        elif language == 'model_path_en':
+            return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
+        else:
+            return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
+
+    def pass_device(self, new_language):
+        _, current_model, _, _ = self.get_setup_for_language(self.current_model)
+        current_model.to('cpu')
+        _, current_model, _, _ = self.get_setup_for_language(new_language)
+        current_model.to(self.device)
 
 
 if __name__ == "__main__":
-- 
GitLab