From b3fb984ee5bb01ba642d4199e40ab3ced7592f12 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Thu, 24 Nov 2022 09:07:51 +0100
Subject: [PATCH] Added a fallback for the case when no pretrained model is
 available.

---
 src/lambo/learning/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)
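
Note: a minimal usage sketch of the fallback path, assuming
train_pretrained_and_save also accepts epochs and device as keyword arguments
(both are forwarded to train_new_and_save in the added lines) and that
pretrained_path is a pathlib.Path. The language code, directory names and
parameter values below are placeholders, chosen only so that oscar_en.pth is
absent and the fallback branch runs.

    import torch
    from pathlib import Path
    from lambo.learning.train import train_pretrained_and_save

    # With no oscar_en.pth under pretrained/, the patched function prints a
    # notice and trains a 'LAMBO-BILSTM' model from scratch via
    # train_new_and_save instead of failing inside torch.load.
    train_pretrained_and_save('en', Path('data/UD_English-EWT'), Path('models/'),
                              Path('pretrained/'), epochs=10,
                              device=torch.device('cpu'))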

diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index 203eafb..13884a1 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -149,7 +149,11 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
     """
     print("Loading pretrained model")
     pretrained_name = 'oscar_' + language
-    pretrained_model = torch.load(pretrained_path / (pretrained_name + '.pth'))
+    file_path = pretrained_path / (pretrained_name + '.pth')
+    if not file_path.exists():
+        print("Pretrained model not found, falling back to training from scratch.")
+        return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device)
+    pretrained_model = torch.load(file_path)
     dict = {}
     for line in open(pretrained_path / (pretrained_name + '.dict')):
         if line.strip() == '':
-- 
GitLab