diff --git a/README.md b/README.md
index e8f9b06b25b944f8fd77838a1f8dea9170e1d206..ffead895613e0694001356eac13142709621cb0b 100644
--- a/README.md
+++ b/README.md
@@ -46,5 +46,23 @@ Eg. if you place your model named "production" at `punctuator/checkpoints/action
 python3 punctuate.py -a mixed -d /deploy/actions_mixed -i test_data/text.txt -m production -dv cuda:0
 ```
 
+## Config
+```ini
+[deployment]
+device = cpu ; Device on which inference will run (e.g. cpu, cuda:0, etc.)
+models_dir = deploy ; Relative path to the directory where models will be placed
+models_enabled = actions_base,actions_mixed,actions_restricted ; Which models are available
+```
+
+## LPMN
+```
+filedir(/users/michal.pogoda)|any2txt|punctuator_test
+```
+or
+```
+filedir(/users/michal.pogoda)|any2txt|punctuator_test({"model":"model_name"})
+```
+where `model_name` is one of the models specified in `models_enabled`. If no model is provided, or the requested model is unavailable, `actions_base` will be used.
+
 ## Mountpoints
 Directory where the model will be downloaded (~500Mb) needs to be mounted at /punctuator/deploy
diff --git a/requirements.txt b/requirements.txt
index 4a63c402fb8008dddeef0e44772538e6b9fa9e08..068c99e062e0f3457aed03cb4be1c6338c17f29e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 --index-url https://pypi.clarin-pl.eu/simple/
+--find-links https://download.pytorch.org/whl/torch_stable.html
 attrs==19.3.0
 bokeh==2.1.1
 certifi==2020.6.20
@@ -50,7 +51,7 @@ tblib==1.7.0
 tokenizers==0.8.1rc1
 toml==0.10.1
 toolz==0.10.0
-torch==1.5.1
+torch==1.4.0+cu100
 tornado==6.0.4
 tqdm==4.48.2
 transformers==3.0.2
diff --git a/src/models/actions_model_base.py b/src/models/actions_model_base.py
index d503f088d01c8e1b2670b139c077fb66202ac792..339b9814b6e6b71c56d4a4622c6ef9cc291a0a6e 100644
--- a/src/models/actions_model_base.py
+++ b/src/models/actions_model_base.py
@@ -20,7 +20,13 @@ from src.pipelines.actions_based.processing import (
     token_labels_to_word_labels,
 )
 from src.pipelines.actions_based.utils import max_suppression
-from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
+from src.utils import (
+    get_device,
+    pickle_read,
+    pickle_save,
+    prepare_folder,
+    yaml_serializable,
+)
 
 
 @dataclass
@@ -111,8 +117,10 @@
     def predict(self, text: str) -> str:
         text = text.strip()
 
+        device = get_device(self)
+
         tokenizer = self.tokenizer()
-        tokens = tokenizer(text, return_tensors="pt")["input_ids"]
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
 
         output = None
         index_start = 0
@@ -120,12 +128,10 @@
             index_end = min(index_start + self.runtime.chunksize, len(tokens[0]))
 
             tokens_chunk = tokens[:, index_start:index_end]
 
+            attention_mask = torch.ones_like(tokens_chunk).to(device)
             actions = (
-                self.predict_raw(tokens_chunk, torch.ones_like(tokens_chunk))
-                .detach()
-                .cpu()
-                .numpy()
+                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
             )
 
             actions_suppresed = max_suppression(actions, self.runtime.threshold)[0]
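Both `predict` hunks above implement the same pattern: read the device off the model's parameters once, move the token IDs there, and keep each chunk's attention mask on that same device so the forward pass never mixes CPU and GPU tensors. Below is a minimal sketch of that pattern; `model`, `tokenizer`, and `chunksize` are hypothetical stand-ins, not the repository's actual objects.

```python
import torch


def predict_in_chunks(model, tokenizer, text: str, chunksize: int = 512) -> torch.Tensor:
    # Derive the device from the parameters instead of tracking it manually.
    device = next(model.parameters()).device
    tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)

    outputs = []
    index_start = 0
    while index_start < tokens.shape[1]:
        index_end = min(index_start + chunksize, tokens.shape[1])
        tokens_chunk = tokens[:, index_start:index_end]

        # ones_like inherits tokens_chunk's device and dtype, so the mask
        # is already co-located with the inputs.
        attention_mask = torch.ones_like(tokens_chunk)
        with torch.no_grad():
            outputs.append(model(tokens_chunk, attention_mask).cpu())

        index_start = index_end

    return torch.cat(outputs, dim=1)
```

Note that `torch.ones_like(tokens_chunk)` already inherits the chunk's device, so the explicit `.to(device)` on the mask in the patch is a harmless no-op.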
diff --git a/src/models/actions_model_mixed.py b/src/models/actions_model_mixed.py
index e8f9a508cff61bdeac8e0c6a8679293b1a8414ad..e09c0fabe4ee895c22275e0741f4cf4dfea51433 100644
--- a/src/models/actions_model_mixed.py
+++ b/src/models/actions_model_mixed.py
@@ -16,7 +16,13 @@ from src.pipelines.actions_based.processing import (
     recover_text,
     token_labels_to_word_labels,
 )
-from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
+from src.utils import (
+    get_device,
+    pickle_read,
+    pickle_save,
+    prepare_folder,
+    yaml_serializable,
+)
 
 
 @dataclass
@@ -83,7 +89,6 @@ class ActionsModelMixed(PunctuationModel):
         self._tokenizer = None
         self.num_labels = params.num_labels
 
-        self.device = "cpu"
 
         # Word embedder
         self.word_embedding = nn.Embedding(params.vocab_size, params.embedding_size)
@@ -160,11 +165,6 @@
 
         return self.to_labels(z)
 
-    def to(self, device):
-        self.device = device
-
-        super(ActionsModelMixed, self).to(device)
-
     def tokenizer(self) -> BertTokenizerFast:
         if self._tokenizer is None:
             self._tokenizer = BertTokenizerFast.from_pretrained(
@@ -173,12 +173,14 @@
         return self._tokenizer
 
     def predict(self, text: str) -> str:
+        # TODO: Optimize for speed
+
         inputs = [action_vector(["upper_case"])]
 
         tokenizer = self.tokenizer()
         text_tokenized = tokenizer(text, return_tensors="pt")
 
-        target_device = self.device
+        target_device = get_device(self)
 
         max_cond_len = self.runtime.max_cond_len
         if max_cond_len is None:
diff --git a/src/models/actions_model_restricted.py b/src/models/actions_model_restricted.py
index 9239e667baa93c1ff91a954aa019775768ec4301..eb7f859120d7b1f694a59df3547a17551d2c0ded 100644
--- a/src/models/actions_model_restricted.py
+++ b/src/models/actions_model_restricted.py
@@ -19,7 +19,13 @@ from src.pipelines.actions_based.processing import (
     token_labels_to_word_labels,
 )
 from src.pipelines.actions_based.utils import max_suppression
-from src.utils import pickle_read, pickle_save, prepare_folder, yaml_serializable
+from src.utils import (
+    get_device,
+    pickle_read,
+    pickle_save,
+    prepare_folder,
+    yaml_serializable,
+)
 
 
 @dataclass
@@ -131,10 +137,12 @@
         chunk_size = self.runtime.chunksize
         threshold = self.runtime.threshold
 
+        device = get_device(self)
+
         text = text.strip()
 
         tokenizer = self.tokenizer()
-        tokens = tokenizer(text, return_tensors="pt")["input_ids"]
+        tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
 
         output = None
         index_start = 0
@@ -143,11 +151,10 @@
 
             tokens_chunk = tokens[:, index_start:index_end]
 
+            attention_mask = torch.ones_like(tokens_chunk).to(device)
+
             actions = (
-                self.predict_raw(tokens_chunk, torch.ones_like(tokens_chunk))
-                .detach()
-                .cpu()
-                .numpy()
+                self.predict_raw(tokens_chunk, attention_mask).detach().cpu().numpy()
             )
 
             actions_suppresed = max_suppression(actions, threshold)[0]
diff --git a/src/models/model_factory.py b/src/models/model_factory.py
index 3d4abcc4af843f1781aeba981fb6f16be2843565..5e4a9fc554fff5c1d284a0ae47b02b684f554ff8 100644
--- a/src/models/model_factory.py
+++ b/src/models/model_factory.py
@@ -1,7 +1,6 @@
+from src.models.actions_model_base import ActionsModelBase
 from src.models.actions_model_mixed import ActionsModelMixed
 from src.models.actions_model_restricted import ActionsModelRestricted
-from src.models.actions_model_base import ActionsModelBase
-
 
 MODELS_MAP = {
     "actions_base": ActionsModelBase,
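`MODELS_MAP` is what ties the `models_enabled` names in `config.ini` to model classes. A sketch of the lookup-with-fallback behaviour the README's LPMN section describes; the `resolve_model_class` helper is illustrative only, not part of this patch:

```python
# Illustrative helper, not part of the patch: map a requested model name
# to its class, falling back to actions_base as the README describes.
from src.models.model_factory import MODELS_MAP


def resolve_model_class(name: str, default: str = "actions_base"):
    return MODELS_MAP.get(name, MODELS_MAP[default])
```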
diff --git a/src/utils.py b/src/utils.py
index 90a69f5f1ed13cd2f0a933a7ed79134d64ffa4fc..0bb292f6fa74e073ae5e8629171be2fbbd091c7f 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -604,3 +604,17 @@ def yaml_serializable(cls):
     setattr(cls, "load_yaml", load_yaml)
 
     return cls
+
+
+def get_device(model: nn.Module) -> torch.device:
+    """Get the device on which the module resides. Works only if all
+    parameters reside on a single device.
+
+    Args:
+        model (nn.Module): Module to check
+
+    Returns:
+        torch.device: Device on which the module's parameters reside
+    """
+
+    return next(model.parameters()).device
diff --git a/worker.py b/worker.py
index 92d4fe07922f91a719b5df306d5f3ab74b7c0dae..98e5a75385c110f3f236afe9f504b35c7a10616a 100755
--- a/worker.py
+++ b/worker.py
@@ -1,10 +1,12 @@
 #!/usr/bin/python
 
 import configparser
+import logging
 from src.models.model_factory import MODELS_MAP
 from typing import List
 
 import nlp_ws
+import torch
 
 from src.utils import input_preprocess, output_preprocess
 
@@ -16,10 +18,12 @@ class Worker(nlp_ws.NLPWorker):
         self.config = configparser.ConfigParser()
         self.config.read("config.ini")
 
-        self.device = self.config["deployment"]["device"]
+        self.device = torch.device(self.config["deployment"]["device"])
         self.models_dir = self.config["deployment"]["models_dir"]
         self.models = {}
 
+        self._log = logging.getLogger(__name__)
+
         models_enabled = self.config["deployment"]["models_enabled"]
         models_enabled = models_enabled.split(",")
 
@@ -51,6 +55,9 @@ class Worker(nlp_ws.NLPWorker):
         with open(output_file, "w") as f:
             f.write(result)
 
+        if self.device.type != "cpu":
+            torch.cuda.empty_cache()
+
 
 if __name__ == "__main__":
     nlp_ws.NLPService.main(Worker)
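With `get_device` in place, and `ActionsModelMixed`'s hand-tracked `self.device` and custom `to` override removed, the device is always derived from where the parameters actually live, so a plain `nn.Module.to()` keeps model and inputs consistent. A toy illustration, assuming only that `torch` and `src.utils` are importable; this snippet is not part of the patch:

```python
import torch
from torch import nn

from src.utils import get_device

model = nn.Linear(4, 2)
print(get_device(model))  # device(type='cpu')

if torch.cuda.is_available():
    model = model.to("cuda:0")
    print(get_device(model))  # device(type='cuda', index=0)
```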