From 11b2e8a3dd9a19d1ecaf5b741aba0c5a1d5166d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl>
Date: Sun, 4 Feb 2024 17:44:14 +0100
Subject: [PATCH] fix main to take into consideration value of
 batches_per_epoch argument + removed batches_per_epoch from template since it
 is not used

---
 combo/config.template.json | 1 -
 combo/main.py              | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/combo/config.template.json b/combo/config.template.json
index 7baf654..93332c5 100644
--- a/combo/config.template.json
+++ b/combo/config.template.json
@@ -252,7 +252,6 @@
       },
       "batch_size": 1,
       "shuffle": true,
-      "batches_per_epoch": 64,
       "quiet": false
     }
   },
diff --git a/combo/main.py b/combo/main.py
index 3884f05..256e75a 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -161,7 +161,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
                                                    )
 
     if not training_data_loader:
-        training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size)
+        training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch)
     else:
         if training_data_path:
             training_data_loader.data_path = training_data_path
@@ -170,7 +170,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
                            str(training_data_loader.data_path), prefix=prefix)
 
     if FLAGS.validation_data_path and not validation_data_loader:
-        validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size)
+        validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch)
     elif FLAGS.validation_data_path and validation_data_loader:
         if validation_data_path:
             validation_data_loader.data_path = validation_data_path
-- 
GitLab