Skip to content
Snippets Groups Projects
Commit 51c310b1 authored by piotrmp's avatar piotrmp
Browse files

Bug fix.

parent 1b098707
1 merge request!1Migration to UD 2.11
......@@ -153,7 +153,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
if not file_path.exists():
print("Pretrained model not found, falling back to training from scratch.")
return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device)
pretrained_model = torch.load(file_path)
pretrained_model = torch.load(file_path, map_location=torch.device('cpu'))
dict = {}
for line in open(pretrained_path / (pretrained_name + '.dict')):
if line.strip() == '':
......
......@@ -38,7 +38,7 @@ class Lambo():
model_name = Lambo.getDefaultModel(provided_name)
dict_path, model_path = download_model(model_name)
dict = Lambo.read_dict(dict_path)
model = torch.load(model_path)
model = torch.load(model_path, map_location=torch.device('cpu'))
return cls(model, dict)
@staticmethod
......@@ -75,7 +75,7 @@ class Lambo():
:param model_name: model name
:return:
"""
model = torch.load(model_path / (model_name + '.pth'))
model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu'))
dict = Lambo.read_dict(model_path / (model_name + '.dict'))
return cls(model, dict)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment