From 51c310b13b7cb69a11dc05b8688987b85a5a2b25 Mon Sep 17 00:00:00 2001 From: piotrmp <piotr.m.przybyla@gmail.com> Date: Thu, 24 Nov 2022 15:28:20 +0100 Subject: [PATCH] Bug fix. --- src/lambo/learning/train.py | 2 +- src/lambo/segmenter/lambo.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index 13884a1..d3f3a06 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -153,7 +153,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat if not file_path.exists(): print("Pretrained model not found, falling back to training from scratch.") return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device) - pretrained_model = torch.load(file_path) + pretrained_model = torch.load(file_path, map_location=torch.device('cpu')) dict = {} for line in open(pretrained_path / (pretrained_name + '.dict')): if line.strip() == '': diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py index e2fbc53..a5e46d8 100644 --- a/src/lambo/segmenter/lambo.py +++ b/src/lambo/segmenter/lambo.py @@ -38,7 +38,7 @@ class Lambo(): model_name = Lambo.getDefaultModel(provided_name) dict_path, model_path = download_model(model_name) dict = Lambo.read_dict(dict_path) - model = torch.load(model_path) + model = torch.load(model_path, map_location=torch.device('cpu')) return cls(model, dict) @staticmethod @@ -75,7 +75,7 @@ class Lambo(): :param model_name: model name :return: """ - model = torch.load(model_path / (model_name + '.pth')) + model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu')) dict = Lambo.read_dict(model_path / (model_name + '.dict')) return cls(model, dict) -- GitLab