Skip to content
Snippets Groups Projects
Commit ad097cf0 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Batch sizes as CLI parameters

parent 06c4376b
Branches
1 merge request!46Merge COMBO 3.0 into master
......@@ -48,14 +48,14 @@ flags.DEFINE_string(name="training_data_path", default="", help="Training data p
# Command-line flags for training (abseil-py `flags`).
# NOTE(review): this is a diff hunk from a larger module; the `flags` import and
# the remaining flag definitions are outside this view.
flags.DEFINE_alias(name="training_data", original_name="training_data_path")
# Path(s) to the validation set; presumably comma-separated when multiple
# (the consumer joins FLAGS values with "," elsewhere in this file) — TODO confirm.
flags.DEFINE_string(name="validation_data_path", default="", help="Validation data path(s)")
flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
# Optional path to pretrained token embeddings; empty string means "none".
flags.DEFINE_string(name="pretrained_tokens", default="",
help="Pretrained tokens embeddings path")
# Dimensionality of the lemmatizer's character/token embeddings.
flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
help="Lemmatizer embeddings dim")
# Total number of training epochs.
flags.DEFINE_integer(name="num_epochs", default=400,
help="Epochs num")
# Lower bound on the number of words packed into one batch.
flags.DEFINE_integer(name="word_batch_size", default=2500,
help="Minimum words in batch")
# Training batch size; wired into the data_loader config by _get_ext_vars below.
flags.DEFINE_integer(name="batch_size", default=256,
help="Batch size")
# How many batches constitute one epoch for the data loader.
flags.DEFINE_integer(name="batches_per_epoch", default=16,
help="Number of batches per epoch")
# HuggingFace model identifier for transformer-based embeddings; empty disables.
flags.DEFINE_string(name="pretrained_transformer_name", default="",
help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
"available models) for transformers based embeddings.")
......@@ -90,8 +90,6 @@ flags.DEFINE_string(name="input_file", default=None,
help="File to predict path")
# Command-line flags for prediction mode.
# If True, input is expected in CoNLL-U format; otherwise raw text is parsed.
flags.DEFINE_boolean(name="conllu_format", default=True,
help="Prediction based on conllu format (instead of raw text).")
# NOTE(review): duplicate flag name — "batch_size" is also defined in the
# training-flags section above; abseil raises DuplicateFlagError on duplicate
# definitions, so presumably one of the two is removed in the full file — verify.
flags.DEFINE_integer(name="batch_size", default=1,
help="Prediction batch size.")
# When True, write predictions to file without echoing them to the console.
flags.DEFINE_boolean(name="silent", default=True,
help="Silent prediction to file (without printing to console).")
flags.DEFINE_boolean(name="finetuning", default=False,
......@@ -305,6 +303,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
},
"data_loader": {
"data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
"batch_size": FLAGS.batch_size,
"batches_per_epoch": FLAGS.batches_per_epoch,
"parameters": {
"reader": {
"parameters": {
......@@ -323,6 +323,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
},
"validation_data_loader": {
"data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
"batch_size": FLAGS.batch_size,
"batches_per_epoch": FLAGS.batches_per_epoch,
"parameters": {
"reader": {
"parameters": {
......@@ -351,13 +353,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
}
}
}
},
"vocabulary": {
"parameters": {
"pretrained_files": FLAGS.pretrained_tokens
}
},
"word_batch_size": int(FLAGS.word_batch_size),
}
}
return to_override
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment