From fa0c1555f16ff60f9f986b0c3780513ccee6aaa8 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Tue, 16 Jun 2020 11:40:35 +0200 Subject: [PATCH] Add cuda configuration from python. --- combo/predict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/combo/predict.py b/combo/predict.py index 310f48f..1f48e22 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -225,7 +225,9 @@ class SemanticMultitaskPredictor(predictor.Predictor): return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) @classmethod - def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500): + def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), + batch_size: int = 500, + cuda_device: int = -1): util.import_module_and_submodules("combo.commands") util.import_module_and_submodules("combo.models") util.import_module_and_submodules("combo.training") @@ -240,7 +242,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): logger.error(e) raise e - archive = models.load_archive(model_path) + archive = models.load_archive(model_path, cuda_device=cuda_device) model = archive.model dataset_reader = allen_data.DatasetReader.from_params( archive.config["dataset_reader"]) -- GitLab