Skip to content
Snippets Groups Projects
Commit a106d5f0 authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

added separate validation batches per epoch

parent aff8d3c7
Branches
Tags
2 merge requests!49Multiword fix transformer encoder,!47Fixed multiword prediction + bug that made the code write empty predictions
......@@ -63,6 +63,8 @@ flags.DEFINE_integer(name="batch_size", default=256,
help="Batch size")
flags.DEFINE_integer(name="batches_per_epoch", default=16,
help="Number of batches per epoch")
flags.DEFINE_integer(name="validation_batches_per_epoch", default=4,
help="Number of batches per epoch")
flags.DEFINE_string(name="pretrained_transformer_name", default="bert-base-cased",
help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
"available models) for transformers based embeddings.")
......@@ -171,7 +173,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
str(training_data_loader.data_path), prefix=prefix)
if FLAGS.validation_data_path and not validation_data_loader:
validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.batches_per_epoch)
validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size, FLAGS.validation_batches_per_epoch)
elif FLAGS.validation_data_path and validation_data_loader:
if validation_data_path:
validation_data_loader.data_path = validation_data_path
......@@ -514,7 +516,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
"parameters": {
"data_path": FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path,
"batch_size": FLAGS.batch_size,
"batches_per_epoch": FLAGS.batches_per_epoch,
"batches_per_epoch": FLAGS.validation_batches_per_epoch,
"reader": {
"parameters": {
"features": FLAGS.features,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment