diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index ce95d4d3cce36f8d5b6e9a7d41cb908b64d67a1d..fdbc80557569cce9603bd55d99a596d3fb33a1a0 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -168,7 +168,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat print("Pretrained model not found, falling back to training from scratch.") return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device) pretrained_model = torch.load(file_path, map_location=torch.device('cpu')) - dict = Lambo.read_dict() + dict = Lambo.read_dict(pretrained_path / (pretrained_name + '.dict')) print("Reading data.") train_doc, dev_doc, test_doc = read_treebank(treebank_path, True)