From fa0c1555f16ff60f9f986b0c3780513ccee6aaa8 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 16 Jun 2020 11:40:35 +0200
Subject: [PATCH] Add cuda configuration from python.

---
 combo/predict.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index 310f48f..1f48e22 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -225,7 +225,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
     @classmethod
-    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500):
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+                        batch_size: int = 500,
+                        cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -240,7 +242,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                 logger.error(e)
                 raise e
 
-        archive = models.load_archive(model_path)
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
             archive.config["dataset_reader"])
-- 
GitLab