From 11b5840073a827351f19050c8f70a634c9e77be1 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Wed, 23 Nov 2022 10:13:10 +0100
Subject: [PATCH] Bug fix.

---
 src/lambo/learning/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index a2cfd2c..dfab088 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -223,7 +223,7 @@ def pretrain(model, train_dataloader, test_dataloader, epochs, device='cpu'):
     optimizer = Adam(model.parameters(), lr=learning_rate)
     
     print("Pretraining")
-    test_loop_pretraining(test_dataloader, model)
+    test_loop_pretraining(test_dataloader, model, device)
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
         train_loop(train_dataloader, model, optimizer, device)
-- 
GitLab