From b3fb984ee5bb01ba642d4199e40ab3ced7592f12 Mon Sep 17 00:00:00 2001 From: piotrmp <piotr.m.przybyla@gmail.com> Date: Thu, 24 Nov 2022 09:07:51 +0100 Subject: [PATCH] Added fallback for no pretrained model case. --- src/lambo/learning/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index 203eafb..13884a1 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -149,7 +149,11 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat """ print("Loading pretrained model") pretrained_name = 'oscar_' + language - pretrained_model = torch.load(pretrained_path / (pretrained_name + '.pth')) + file_path = pretrained_path / (pretrained_name + '.pth') + if not file_path.exists(): + print("Pretrained model not found, falling back to training from scratch.") + return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device) + pretrained_model = torch.load(file_path) dict = {} for line in open(pretrained_path / (pretrained_name + '.dict')): if line.strip() == '': -- GitLab