From 11b5840073a827351f19050c8f70a634c9e77be1 Mon Sep 17 00:00:00 2001 From: piotrmp <piotr.m.przybyla@gmail.com> Date: Wed, 23 Nov 2022 10:13:10 +0100 Subject: [PATCH] Bug fix. --- src/lambo/learning/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index a2cfd2c..dfab088 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -223,7 +223,7 @@ def pretrain(model, train_dataloader, test_dataloader, epochs, device='cpu'): optimizer = Adam(model.parameters(), lr=learning_rate) print("Pretraining") - test_loop_pretraining(test_dataloader, model) + test_loop_pretraining(test_dataloader, model, device) for t in range(epochs): print(f"Epoch {t + 1}\n-------------------------------") train_loop(train_dataloader, model, optimizer, device) -- GitLab