diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index a2cfd2c77418a517a20635501ffe6906e027063d..dfab088181bf50077974cd13057d7dcd4db5c8fa 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -223,7 +223,7 @@ def pretrain(model, train_dataloader, test_dataloader, epochs, device='cpu'): optimizer = Adam(model.parameters(), lr=learning_rate) print("Pretraining") - test_loop_pretraining(test_dataloader, model) + test_loop_pretraining(test_dataloader, model, device) for t in range(epochs): print(f"Epoch {t + 1}\n-------------------------------") train_loop(train_dataloader, model, optimizer, device)