Skip to content
Snippets Groups Projects

S3 synchronization and CI

5 files
+ 112
92
Compare changes
  • Side-by-side
  • Inline

Files

+ 50
44
"""Implementation of punctuator service"""
"""Implementation of punctuator service"""
import configparser
import json
import json
import string
import os
import os
 
import string
import nlp_ws
from transformers import AutoModelForTokenClassification, AutoTokenizer
import language_tool_python
import language_tool_python
 
from transformers import AutoModelForTokenClassification, AutoTokenizer
from punctuator.utils import (combine_masks, decode, decode_labels,
from punctuator.utils import (
inference_masks)
combine_masks,
decode,
decode_labels,
inference_masks,
)
def _preprocess_input(text: str):
def _preprocess_input(text: str):
@@ -26,7 +19,7 @@ def _preprocess_input(text: str):
@@ -26,7 +19,7 @@ def _preprocess_input(text: str):
def is_punctuation_rule(rule):
def is_punctuation_rule(rule):
lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements)
lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements)
def _post_process(text: str, tool):
def _post_process(text: str, tool):
@@ -38,38 +31,58 @@ def _post_process(text: str, tool):
@@ -38,38 +31,58 @@ def _post_process(text: str, tool):
class Punctuator:
class Punctuator:
def __init__(self, config):
    """Create the punctuator service worker.

    Reads tuning parameters from *config*, points language_tool_python at
    the bundled LanguageTool installation, and loads one
    (tool, model, tokenizer, mapping) quadruple per supported language.

    Args:
        config: mapping with optional keys ``max_context_size``,
            ``overlap``, ``device``, ``languagetool_path`` and
            ``model_path_{pl,ru,en}``; sensible defaults are used for
            any missing key.
    """
    self.config = config
    # Maps model-path config keys to LanguageTool locale codes.
    self.languagetool_map = {
        "model_path_pl": "pl-PL",
        "model_path_ru": "ru",
        "model_path_en": "en-US",
    }
    self.max_context_size = int(self.config.get("max_context_size", 256))
    self.overlap = int(self.config.get("overlap", 20))
    self.device = self.config.get("device", "cpu")
    self.languagetool_path = self.config.get(
        "languagetool_path", "/home/worker/models/languagetool"
    )
    # language_tool_python resolves its installation from this env var.
    os.environ["LTP_PATH"] = self.languagetool_path
    self.model_path_pl = self.config.get(
        "model_path_pl", "/home/worker/models/punctuator/pl"
    )
    # BUG FIX: the ru/en default paths were swapped (Russian defaulted to
    # .../punctuator/en and English to .../punctuator/ru); each language
    # now defaults to its matching model directory.
    self.model_path_ru = self.config.get(
        "model_path_ru", "/home/worker/models/punctuator/ru"
    )
    self.model_path_en = self.config.get(
        "model_path_en", "/home/worker/models/punctuator/en"
    )
    # Only the Polish model starts on the configured device; English and
    # Russian stay on CPU until _pass_device promotes them.
    (
        self.tool_pl,
        self.model_pl,
        self.tokenizer_pl,
        self.mapping_pl,
    ) = self._initialize_model("pl-PL", self.model_path_pl, self.device)
    (
        self.tool_en,
        self.model_en,
        self.tokenizer_en,
        self.mapping_en,
    ) = self._initialize_model("en-US", self.model_path_en, "cpu")
    (
        self.tool_ru,
        self.model_ru,
        self.tokenizer_ru,
        self.mapping_ru,
    ) = self._initialize_model("ru", self.model_path_ru, "cpu")
    # Language whose model currently occupies self.device.
    # FIX: store a language code, as _get_setup_for_language expects,
    # instead of a filesystem path (a path silently fell through to the
    # Polish branch anyway, so this is behavior-preserving).
    self.current_model = "pl"
def process(
def process(self, input_path: str, task_options: dict, output_path: str) -> None:
self, input_path: str, task_options: dict, output_path: str
) -> None:
language = task_options.get("language", "pl")
language = task_options.get("language", "pl")
if language == 'en':
if language == "en":
bpe = True
bpe = True
else:
else:
bpe = False
bpe = False
tool, model, tokenizer, mapping = self._get_setup_for_language(
tool, model, tokenizer, mapping = self._get_setup_for_language(language)
language)
with open(input_path, "r") as f:
with open(input_path, "r") as f:
text = f.read()
text = f.read()
@@ -87,18 +100,13 @@ class Punctuator:
@@ -87,18 +100,13 @@ class Punctuator:
*inference_masks(num_tokens, self.max_context_size, self.overlap)
*inference_masks(num_tokens, self.max_context_size, self.overlap)
):
):
result = model(
result = model(
input_ids=tokenized["input_ids"][:, inference_mask].to(
input_ids=tokenized["input_ids"][:, inference_mask].to(self.device),
 
attention_mask=tokenized["attention_mask"][:, inference_mask].to(
self.device
self.device
),
),
attention_mask=tokenized["attention_mask"][:, inference_mask]
.to(self.device),
)
)
labels_ids = (
labels_ids = (
result.logits.detach()
result.logits.detach().cpu().argmax(dim=-1).squeeze().numpy()[mask_mask]
.cpu()
.argmax(dim=-1)
.squeeze()
.numpy()[mask_mask]
)
)
results.append(decode_labels(labels_ids, mapping))
results.append(decode_labels(labels_ids, mapping))
labels = sum(results, [])
labels = sum(results, [])
@@ -116,9 +124,7 @@ class Punctuator:
@@ -116,9 +124,7 @@ class Punctuator:
def _initialize_model(self, lang, model_path: str, device: str):
def _initialize_model(self, lang, model_path: str, device: str):
tool = language_tool_python.LanguageTool(lang)
tool = language_tool_python.LanguageTool(lang)
model = AutoModelForTokenClassification.from_pretrained(
model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
model_path
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
mapping = {}
mapping = {}
with open(f"{self.model_path_pl}/classes.json", "r") as f:
with open(f"{self.model_path_pl}/classes.json", "r") as f:
@@ -127,15 +133,15 @@ class Punctuator:
@@ -127,15 +133,15 @@ class Punctuator:
return tool, model, tokenizer, mapping
return tool, model, tokenizer, mapping
def _get_setup_for_language(self, language):
def _get_setup_for_language(self, language):
if language == 'ru':
if language == "ru":
return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
elif language == 'en':
elif language == "en":
return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
else:
else:
return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
def _pass_device(self, new_language):
def _pass_device(self, new_language):
_, current_model, _, _ = self._get_setup_for_language(self.current_model)
_, current_model, _, _ = self._get_setup_for_language(self.current_model)
current_model.to('cpu')
current_model.to("cpu")
_, current_model, _, _ = self._get_setup_for_language(new_language)
_, current_model, _, _ = self._get_setup_for_language(new_language)
current_model.to(self.device)
current_model.to(self.device)
Loading