Commit eef85e31 authored by Jarema Radom

Different approach to device sharing

parent 4dfbf7e1
1 merge request: !14 Support for russian and english models
Pipeline #3198 passed
@@ -52,14 +52,23 @@ class Worker(nlp_ws.NLPWorker):
         self.model_path_pl = self.config["model_path_pl"]
         self.model_path_ru = self.config["model_path_ru"]
         self.model_path_en = self.config["model_path_en"]
-        self.initialize_model(self.model_path_pl)
+        self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl \
+            = self.initialize_model(self.model_path_pl, self.device)
+        self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en \
+            = self.initialize_model(self.model_path_en, 'cpu')
+        self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru \
+            = self.initialize_model(self.model_path_ru, 'cpu')
+        self.current_model = self.model_path_pl
 
     def process(
         self, input_path: str, task_options: dict, output_path: str
     ) -> None:
         if task_options['language'] != self.current_model:
-            self.initialize_model(task_options['language'])
+            self.pass_device(task_options['language'])
+            self.current_model = task_options['language']
+        tool, model, tokenizer, mapping = self.get_setup_for_language(
+            self.current_model)
 
         with open(input_path, "r") as f:
             text = f.read()
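Note: the new constructor loads all three models up front but keeps only the Polish one on self.device; the English and Russian models are parked on the CPU, so switching languages becomes a cheap device move rather than a full reload from disk. A minimal sketch of that pattern, assuming PyTorch; DummyModel and swap_to_device are illustrative names, not from this repo:

    import torch
    import torch.nn as nn

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load every model once; only the default language occupies the GPU.
    models = {
        "pl": DummyModel().to(device),  # hot model
        "en": DummyModel().to("cpu"),   # parked on CPU
        "ru": DummyModel().to("cpu"),   # parked on CPU
    }
    current = "pl"

    def swap_to_device(new_lang):
        """Move the active model off the GPU and the requested one onto it."""
        global current
        if new_lang != current:
            models[current].to("cpu")
            models[new_lang].to(device)
            current = new_lang

    swap_to_device("en")  # only the English model now sits on the GPU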
@@ -67,7 +76,7 @@ class Worker(nlp_ws.NLPWorker):
         # Make sure that the text is lowercase & punctuationless
         text = _preprocess_input(text)
 
-        tokenized = self.tokenizer(text, return_tensors="pt")
+        tokenized = tokenizer(text, return_tensors="pt")
         num_tokens = len(tokenized["input_ids"][0])
@@ -76,7 +85,7 @@ class Worker(nlp_ws.NLPWorker):
         for inference_mask, mask_mask in zip(
             *inference_masks(num_tokens, self.max_context_size, self.overlap)
         ):
-            result = self.model(
+            result = model(
                 input_ids=tokenized["input_ids"][:, inference_mask].to(
                     self.device
                 ),
@@ -91,7 +100,7 @@ class Worker(nlp_ws.NLPWorker):
                 .squeeze()
                 .numpy()[mask_mask]
             )
-            results.append(decode_labels(labels_ids, self.mapping))
+            results.append(decode_labels(labels_ids, mapping))
 
         labels = sum(results, [])
         tokens = []
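Note: the loop above runs the model over overlapping windows of at most self.max_context_size tokens and keeps only the part of each window's predictions not already covered by the previous window. The inference_masks helper itself is outside this diff, so the sketch below is an assumption about its behavior, not the repo's implementation; sliding_masks is a stand-in name:

    import numpy as np

    def sliding_masks(num_tokens, window, overlap):
        # Assumed semantics: one boolean mask per window over all tokens
        # (what to feed the model) plus one per-window mask (which of that
        # window's predictions to keep).
        infer_masks, keep_masks = [], []
        step = window - overlap
        for start in range(0, max(num_tokens - overlap, 1), step):
            end = min(start + window, num_tokens)
            infer = np.zeros(num_tokens, dtype=bool)
            infer[start:end] = True       # token slice fed to the model
            keep = np.ones(end - start, dtype=bool)
            if start > 0:
                keep[:overlap] = False    # labels already taken from the previous window
            infer_masks.append(infer)
            keep_masks.append(keep)
        return infer_masks, keep_masks

    # Mirrors the call site: every token's label comes from exactly one window.
    for infer, keep in zip(*sliding_masks(11, 5, 2)):
        pass  # result = model(input_ids[:, infer]); predictions[keep]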
@@ -100,28 +109,38 @@ class Worker(nlp_ws.NLPWorker):
         ):
             tokens += tokenized["input_ids"][0, combine_mask].numpy().tolist()
 
-        text_out = decode(tokens, labels, self.tokenizer,
+        text_out = decode(tokens, labels, tokenizer,
                           self.current_model != self.model_path_pl)
-        text_out = _post_process(text_out, self.tool)
         if not text_out.endswith('.'):
             text_out += '.'
+        text_out = _post_process(text_out, tool)
 
         with open(output_path, "w") as f:
             f.write(text_out)
 
-    def initialize_model(self, model_path: str):
-        self.tool = language_tool_python.LanguageTool(self.languagetool_map
-                                                      [model_path])
-        if self.model:
-            self.model.to('cpu')
-        self.model = AutoModelForTokenClassification.from_pretrained(
+    def initialize_model(self, model_path: str, device: str):
+        tool = language_tool_python.LanguageTool(self.languagetool_map
+                                                 [model_path])
+        model = AutoModelForTokenClassification.from_pretrained(
             model_path
-        ).to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
         mapping = {}
         with open(f"{model_path}/classes.json", "r") as f:
             mapping = json.load(f)
-        self.mapping = list(mapping.keys())
-        self.current_model = model_path
+        mapping = list(mapping.keys())
+        return tool, model, tokenizer, mapping
+
+    def get_setup_for_language(self, language):
+        if language == 'model_path_ru':
+            return self.tool_ru, self.model_ru, self.tokenizer_ru, self.mapping_ru
+        elif language == 'model_path_en':
+            return self.tool_en, self.model_en, self.tokenizer_en, self.mapping_en
+        else:
+            return self.tool_pl, self.model_pl, self.tokenizer_pl, self.mapping_pl
+
+    def pass_device(self, new_language):
+        _, current_model, _, _ = self.get_setup_for_language(self.current_model)
+        current_model.to('cpu')
+        _, current_model, _, _ = self.get_setup_for_language(new_language)
+        current_model.to(self.device)
 
 
 if __name__ == "__main__":
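Note: pass_device can move the models through a local variable because nn.Module.to() moves parameters in place and returns self, so the object stored on the worker attributes sees the move too; Tensor.to() behaves differently and returns a new tensor. A quick demonstration of that distinction (plain PyTorch, illustrative names):

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)
    alias = m                          # e.g. the value returned by a getter
    alias.to(torch.float64)            # nn.Module.to() mutates in place
    print(next(m.parameters()).dtype)  # torch.float64 -- m changed as well

    t = torch.zeros(4)
    t64 = t.to(torch.float64)          # Tensor.to() returns a new tensor
    print(t.dtype, t64.dtype)          # torch.float32 torch.float64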