Commit ad097cf0 authored by Maja Jablonska

Batch sizes as CLI parameters

parent 06c4376b
Merge request !46: Merge COMBO 3.0 into master
@@ -48,14 +48,14 @@ flags.DEFINE_string(name="training_data_path", default="", help="Training data p
 flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_string(name="validation_data_path", default="", help="Validation data path(s)")
 flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
-flags.DEFINE_string(name="pretrained_tokens", default="",
-                    help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
                      help="Lemmatizer embeddings dim")
 flags.DEFINE_integer(name="num_epochs", default=400,
                      help="Epochs num")
-flags.DEFINE_integer(name="word_batch_size", default=2500,
-                     help="Minimum words in batch")
+flags.DEFINE_integer(name="batch_size", default=256,
+                     help="Batch size")
+flags.DEFINE_integer(name="batches_per_epoch", default=16,
+                     help="Number of batches per epoch")
 flags.DEFINE_string(name="pretrained_transformer_name", default="",
                     help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
                          "available models) for transformers based embeddings.")
@@ -90,8 +90,6 @@ flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
 flags.DEFINE_boolean(name="conllu_format", default=True,
                      help="Prediction based on conllu format (instead of raw text).")
-flags.DEFINE_integer(name="batch_size", default=1,
-                     help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
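Taken together, the two flag hunks above promote batching to first-class CLI parameters: a single batch_size flag (default 256) replaces both the word-count-based word_batch_size and the separate prediction-time batch_size, whose old definition is deleted in the second hunk (absl would otherwise raise DuplicateFlagError for the reused name). Below is a minimal, self-contained sketch of how the new flags behave once parsed; the two DEFINE_integer calls mirror the diff, while the script name and main entry point are illustrative only, not part of this commit.

from absl import app, flags

FLAGS = flags.FLAGS

# Same definitions as in the diff; defaults mirror the new flags.
flags.DEFINE_integer(name="batch_size", default=256, help="Batch size")
flags.DEFINE_integer(name="batches_per_epoch", default=16,
                     help="Number of batches per epoch")


def main(argv):
    del argv  # absl passes leftover positional args; unused here
    # After app.run() parses the command line, flag values are plain ints.
    print(f"batch_size={FLAGS.batch_size}, "
          f"batches_per_epoch={FLAGS.batches_per_epoch}")


if __name__ == "__main__":
    app.run(main)

Invoked as, say, python flags_demo.py --batch_size=512 --batches_per_epoch=32, the parsed integers override the defaults shown in the diff.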
@@ -305,6 +303,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "data_loader": {
             "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -323,6 +323,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "validation_data_loader": {
             "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -351,13 +353,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
                     }
                 }
             }
-        },
-        "vocabulary": {
-            "parameters": {
-                "pretrained_files": FLAGS.pretrained_tokens
-            }
-        },
-        "word_batch_size": int(FLAGS.word_batch_size),
+        }
     }
     return to_override
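Both _get_ext_vars hunks surface the same two flag values in the override dict, once for the training data_loader and once for the validation_data_loader. How COMBO applies this dict to its base configuration lies outside this diff; the sketch below shows one conventional way such nested overrides are merged, and every name in it (deep_merge, the sample configs) is hypothetical.

from typing import Any, Dict


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    # Recursively merge `overrides` into a copy of `base`:
    # nested dicts are merged key by key, scalars are replaced.
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


base_config = {"data_loader": {"batch_size": 32, "shuffle": True}}
overrides = {"data_loader": {"batch_size": 256, "batches_per_epoch": 16}}

print(deep_merge(base_config, overrides))
# -> {'data_loader': {'batch_size': 256, 'shuffle': True, 'batches_per_epoch': 16}}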