diff --git a/combo/predict.py b/combo/predict.py index 310f48fd5f8c78804344926b655df29b3a3d630b..1f48e22df3ced5f4cbed70e169ce21fb7814c6e2 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -225,7 +225,9 @@ class SemanticMultitaskPredictor(predictor.Predictor): return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) @classmethod - def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500): + def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), + batch_size: int = 500, + cuda_device: int = -1): util.import_module_and_submodules("combo.commands") util.import_module_and_submodules("combo.models") util.import_module_and_submodules("combo.training") @@ -240,7 +242,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): logger.error(e) raise e - archive = models.load_archive(model_path) + archive = models.load_archive(model_path, cuda_device=cuda_device) model = archive.model dataset_reader = allen_data.DatasetReader.from_params( archive.config["dataset_reader"])