diff --git a/combo/predict.py b/combo/predict.py
index 310f48fd5f8c78804344926b655df29b3a3d630b..1f48e22df3ced5f4cbed70e169ce21fb7814c6e2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -225,7 +225,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
     @classmethod
-    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500):
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+                        batch_size: int = 500,
+                        cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -240,7 +242,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                 logger.error(e)
                 raise e
 
-        archive = models.load_archive(model_path)
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
             archive.config["dataset_reader"])