diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index 203eafb8dfe8b4bdf6c1cc868241133b6d977b30..13884a1ab59c3af08d5291c4a530a56d3473bd1f 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -149,7 +149,11 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat """ print("Loading pretrained model") pretrained_name = 'oscar_' + language - pretrained_model = torch.load(pretrained_path / (pretrained_name + '.pth')) + file_path = pretrained_path / (pretrained_name + '.pth') + if not file_path.exists(): + print("Pretrained model not found, falling back to training from scratch.") + return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device) + pretrained_model = torch.load(file_path) dict = {} for line in open(pretrained_path / (pretrained_name + '.dict')): if line.strip() == '':