From a106d5f067ac36fe553db2d5c0e70b76df08622e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Fri, 9 Feb 2024 17:39:40 +0100 Subject: [PATCH] added separate validation batches per epoch --- combo/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/combo/main.py b/combo/main.py index ba135d4..79e9820 100755 --- a/combo/main.py +++ b/combo/main.py @@ -63,6 +63,8 @@ flags.DEFINE_integer(name="batch_size", default=256, help="Batch size") flags.DEFINE_integer(name="batches_per_epoch", default=16, help="Number of batches per epoch") +flags.DEFINE_integer(name="validation_batches_per_epoch", default=4, + help="Number of batches per epoch") flags.DEFINE_string(name="pretrained_transformer_name", default="bert-base-cased", help="Pretrained transformer model name (see transformers from HuggingFace library for list of " "available models) for transformers based embeddings.") @@ -171,7 +173,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader], str(training_data_loader.data_path), prefix=prefix) if FLAGS.validation_data_path and not validation_data_loader: - validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch) + validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.validation_batches_per_epoch) elif FLAGS.validation_data_path and validation_data_loader: if validation_data_path: validation_data_loader.data_path = validation_data_path @@ -514,7 +516,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict: "parameters": { "data_path": FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path, "batch_size": FLAGS.batch_size, - "batches_per_epoch": FLAGS.batches_per_epoch, + "batches_per_epoch": FLAGS.validation_batches_per_epoch, "reader": { "parameters": { "features": FLAGS.features, -- GitLab