From 51c310b13b7cb69a11dc05b8688987b85a5a2b25 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Thu, 24 Nov 2022 15:28:20 +0100
Subject: [PATCH] Bug fix.

---
 src/lambo/learning/train.py  | 2 +-
 src/lambo/segmenter/lambo.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index 13884a1..d3f3a06 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -153,7 +153,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
     if not file_path.exists():
         print("Pretrained model not found, falling back to training from scratch.")
         return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device)
-    pretrained_model = torch.load(file_path)
+    pretrained_model = torch.load(file_path, map_location=torch.device('cpu'))
     dict = {}
     for line in open(pretrained_path / (pretrained_name + '.dict')):
         if line.strip() == '':
diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py
index e2fbc53..a5e46d8 100644
--- a/src/lambo/segmenter/lambo.py
+++ b/src/lambo/segmenter/lambo.py
@@ -38,7 +38,7 @@ class Lambo():
             model_name = Lambo.getDefaultModel(provided_name)
         dict_path, model_path = download_model(model_name)
         dict = Lambo.read_dict(dict_path)
-        model = torch.load(model_path)
+        model = torch.load(model_path, map_location=torch.device('cpu'))
         return cls(model, dict)
     
     @staticmethod
@@ -75,7 +75,7 @@ class Lambo():
         :param model_name: model name
         :return:
         """
-        model = torch.load(model_path / (model_name + '.pth'))
+        model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu'))
         dict = Lambo.read_dict(model_path / (model_name + '.dict'))
         return cls(model, dict)
     
-- 
GitLab