diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index 13884a1ab59c3af08d5291c4a530a56d3473bd1f..d3f3a06ccfce1d047d5f8d86a3169d5a2d707377 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -153,7 +153,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat if not file_path.exists(): print("Pretrained model not found, falling back to training from scratch.") return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device) - pretrained_model = torch.load(file_path) + pretrained_model = torch.load(file_path, map_location=torch.device('cpu')) dict = {} for line in open(pretrained_path / (pretrained_name + '.dict')): if line.strip() == '': diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py index e2fbc538b5438324d0bf20726b77bfc5485ff8bd..a5e46d80fb95182f972d877d1ed4e903aacc79cd 100644 --- a/src/lambo/segmenter/lambo.py +++ b/src/lambo/segmenter/lambo.py @@ -38,7 +38,7 @@ class Lambo(): model_name = Lambo.getDefaultModel(provided_name) dict_path, model_path = download_model(model_name) dict = Lambo.read_dict(dict_path) - model = torch.load(model_path) + model = torch.load(model_path, map_location=torch.device('cpu')) return cls(model, dict) @staticmethod @@ -75,7 +75,7 @@ class Lambo(): :param model_name: model name :return: """ - model = torch.load(model_path / (model_name + '.pth')) + model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu')) dict = Lambo.read_dict(model_path / (model_name + '.dict')) return cls(model, dict)