
S3 synchronization and CI

4 files changed: +105 −73

Files

+46 −33
@@ -26,7 +26,7 @@ def _preprocess_input(text: str):
 def is_punctuation_rule(rule):
-    lambda rule: rule.category != 'PUNCTUATION' and len(rule.replacements)
+    lambda rule: rule.category != "PUNCTUATION" and len(rule.replacements)
 def _post_process(text: str, tool):
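
The predicate touched in this hunk is presumably what `_post_process` uses to filter LanguageTool matches: as written it keeps matches whose category is not PUNCTUATION and that offer at least one replacement. Below is a minimal sketch of how such a filter is usually wired into language_tool_python; `_post_process` itself is not shown in this diff, so `apply_grammar_fixes` and its body are an assumed shape, not the project's code.

# Sketch only: the real _post_process is outside this hunk; apply_grammar_fixes
# is an illustrative stand-in built on the public language_tool_python API.
import language_tool_python


def is_punctuation_rule(rule):
    # As in the diff: keep matches that are NOT punctuation rules and that
    # actually propose a replacement.
    return rule.category != "PUNCTUATION" and len(rule.replacements)


def apply_grammar_fixes(text: str, tool: language_tool_python.LanguageTool) -> str:
    matches = [m for m in tool.check(text) if is_punctuation_rule(m)]
    return language_tool_python.utils.correct(text, matches)
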
@@ -38,38 +38,58 @@ def _post_process(text: str, tool):
 class Punctuator:
     def __init__(self, config):
         self.config = config
-        self.languagetool_map = {'model_path_pl': 'pl-PL', 'model_path_ru':
-                                 'ru', 'model_path_en': 'en-US'}
+        self.languagetool_map = {
+            "model_path_pl": "pl-PL",
+            "model_path_ru": "ru",
+            "model_path_en": "en-US",
+        }
         self.max_context_size = int(self.config.get("max_context_size", 256))
         self.overlap = int(self.config.get("overlap", 20))
         self.device = self.config.get("device", "cpu")
-        self.languagetool_path = self.config.get("languagetool_path", "/home/worker/models/languagetool")
+        self.languagetool_path = self.config.get(
+            "languagetool_path", "/home/worker/models/languagetool"
+        )
         os.environ["LTP_PATH"] = self.languagetool_path
-        self.model_path_pl = self.config.get("model_path_pl", "/home/worker/models/punctuator/pl")
-        self.model_path_ru = self.config.get("model_path_ru", "/home/worker/models/punctuator/en")
-        self.model_path_en = self.config.get("model_path_en", "/home/worker/models/punctuator/ru")
-        self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl \
-            = self._initialize_model('pl-PL', self.model_path_pl, self.device)
-        self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en \
-            = self._initialize_model('en-US', self.model_path_en, 'cpu')
-        self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru \
-            = self._initialize_model('ru', self.model_path_ru, 'cpu')
+        self.model_path_pl = self.config.get(
+            "model_path_pl", "/home/worker/models/punctuator/pl"
+        )
+        self.model_path_ru = self.config.get(
+            "model_path_ru", "/home/worker/models/punctuator/en"
+        )
+        self.model_path_en = self.config.get(
+            "model_path_en", "/home/worker/models/punctuator/ru"
+        )
+        (
+            self.tool_pl,
+            self.model_pl,
+            self.tokenizer_pl,
+            self.mapping_pl,
+        ) = self._initialize_model("pl-PL", self.model_path_pl, self.device)
+        (
+            self.tool_en,
+            self.model_en,
+            self.tokenizer_en,
+            self.mapping_en,
+        ) = self._initialize_model("en-US", self.model_path_en, "cpu")
+        (
+            self.tool_ru,
+            self.model_ru,
+            self.tokenizer_ru,
+            self.mapping_ru,
+        ) = self._initialize_model("ru", self.model_path_ru, "cpu")
         self.current_model = self.model_path_pl
-    def process(
-        self, input_path: str, task_options: dict, output_path: str
-    ) -> None:
+    def process(self, input_path: str, task_options: dict, output_path: str) -> None:
         language = task_options.get("language", "pl")
-        if language == 'en':
+        if language == "en":
             bpe = True
         else:
             bpe = False
-        tool, model, tokenizer, mapping = self._get_setup_for_language(
-            language)
+        tool, model, tokenizer, mapping = self._get_setup_for_language(language)
         with open(input_path, "r") as f:
             text = f.read()
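
Everything the constructor needs comes from a flat config dict with the defaults visible above, and `process` works file-to-file. The following is a hedged sketch of how the class might be driven from a worker script; the file names, the CUDA device string, and the shape of `task_options` beyond the `language` key are assumptions, while the config keys and the `process` signature come from the diff.

# Illustrative wiring only; paths and device are placeholders.
config = {
    "device": "cuda:0",       # the Polish model is loaded onto this device,
                              # the English and Russian models stay on "cpu"
    "max_context_size": 256,  # tokens per inference window
    "overlap": 20,            # tokens shared between consecutive windows
    "languagetool_path": "/home/worker/models/languagetool",
    "model_path_pl": "/home/worker/models/punctuator/pl",
}

punctuator = Punctuator(config)
punctuator.process(
    input_path="transcript.txt",          # raw, unpunctuated text
    task_options={"language": "pl"},      # "pl", "en" or "ru"
    output_path="transcript.punctuated.txt",
)
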
@@ -87,18 +107,13 @@ class Punctuator:
             *inference_masks(num_tokens, self.max_context_size, self.overlap)
         ):
             result = model(
-                input_ids=tokenized["input_ids"][:, inference_mask].to(
+                input_ids=tokenized["input_ids"][:, inference_mask].to(self.device),
+                attention_mask=tokenized["attention_mask"][:, inference_mask].to(
                     self.device
                 ),
-                attention_mask=tokenized["attention_mask"][:, inference_mask]
-                .to(self.device),
             )
             labels_ids = (
-                result.logits.detach()
-                .cpu()
-                .argmax(dim=-1)
-                .squeeze()
-                .numpy()[mask_mask]
+                result.logits.detach().cpu().argmax(dim=-1).squeeze().numpy()[mask_mask]
             )
             results.append(decode_labels(labels_ids, mapping))
         labels = sum(results, [])
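
The reformatted loop above slides a window of `max_context_size` tokens over the input with `overlap` tokens shared between neighbours; `inference_mask` picks the tokens fed to the model and `mask_mask` picks which of that window's predictions are kept, so every token ends up labelled exactly once. `inference_masks` itself is not part of this diff; the sketch below shows one way such a helper could look, assuming it returns two parallel lists that the caller unpacks with `zip(*...)`. The real implementation may split the overlap differently.

# Assumed implementation of the windowing helper; the project's inference_masks
# is defined elsewhere.
import numpy as np


def inference_masks(num_tokens: int, max_context_size: int, overlap: int):
    """Return (window_masks, keep_masks): boolean masks over all tokens and
    over each window's predictions, meant to be iterated with zip(*...)."""
    window_masks, keep_masks = [], []
    step = max_context_size - overlap
    head, tail = overlap // 2, overlap - overlap // 2
    for start in range(0, max(num_tokens, 1), step):
        end = min(start + max_context_size, num_tokens)
        window = np.zeros(num_tokens, dtype=bool)
        window[start:end] = True
        keep = np.ones(end - start, dtype=bool)
        if start > 0:
            keep[:head] = False       # these positions belong to the previous window
        if end < num_tokens and tail:
            keep[-tail:] = False      # these positions belong to the next window
        window_masks.append(window)
        keep_masks.append(keep)
        if end >= num_tokens:
            break
    return window_masks, keep_masks
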
@@ -116,9 +131,7 @@ class Punctuator:
     def _initialize_model(self, lang, model_path: str, device: str):
         tool = language_tool_python.LanguageTool(lang)
-        model = AutoModelForTokenClassification.from_pretrained(
-            model_path
-        ).to(device)
+        model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
         tokenizer = AutoTokenizer.from_pretrained(model_path)
         mapping = {}
         with open(f"{self.model_path_pl}/classes.json", "r") as f:
@@ -127,15 +140,15 @@ class Punctuator:
         return tool, model, tokenizer, mapping
     def _get_setup_for_language(self, language):
-        if language == 'ru':
+        if language == "ru":
             return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
-        elif language == 'en':
+        elif language == "en":
             return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
         else:
             return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
     def _pass_device(self, new_language):
         _, current_model, _, _ = self._get_setup_for_language(self.current_model)
-        current_model.to('cpu')
+        current_model.to("cpu")
         _, current_model, _, _ = self._get_setup_for_language(new_language)
         current_model.to(self.device)
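
`_pass_device` encodes a single-slot device policy: at any time only one of the three models occupies `self.device` (the Polish one after `__init__`), the other two wait on the CPU. A hedged sketch of the call pattern a language switch implies; the actual call site is not part of this diff, and `activate_language` is an illustrative name.

# Assumed call pattern; where _pass_device is invoked is not shown in this diff.
def activate_language(punctuator, new_language: str) -> None:
    # Send the currently resident model back to the CPU, pull the requested
    # one onto the configured device, and remember which one is active.
    punctuator._pass_device(new_language)
    punctuator.current_model = new_language
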