From 11b2e8a3dd9a19d1ecaf5b741aba0c5a1d5166d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Sun, 4 Feb 2024 17:44:14 +0100 Subject: [PATCH] fix main to take into consideration value of batches_per_epoch argument + removed batches_per_epoch from template since it is not used --- combo/config.template.json | 1 - combo/main.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/combo/config.template.json b/combo/config.template.json index 7baf654..93332c5 100644 --- a/combo/config.template.json +++ b/combo/config.template.json @@ -252,7 +252,6 @@ }, "batch_size": 1, "shuffle": true, - "batches_per_epoch": 64, "quiet": false } }, diff --git a/combo/main.py b/combo/main.py index 3884f05..256e75a 100755 --- a/combo/main.py +++ b/combo/main.py @@ -161,7 +161,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader], ) if not training_data_loader: - training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size) + training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch) else: if training_data_path: training_data_loader.data_path = training_data_path @@ -170,7 +170,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader], str(training_data_loader.data_path), prefix=prefix) if FLAGS.validation_data_path and not validation_data_loader: - validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size) + validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch) elif FLAGS.validation_data_path and validation_data_loader: if validation_data_path: validation_data_loader.data_path = validation_data_path -- GitLab