Commit ad097cf0 authored by Maja Jablonska

Batch sizes as CLI parameters

parent 06c4376b
1 merge request: !46 Merge COMBO 3.0 into master
@@ -48,14 +48,14 @@ flags.DEFINE_string(name="training_data_path", default="", help="Training data p
 flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_string(name="validation_data_path", default="", help="Validation data path(s)")
 flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
-flags.DEFINE_string(name="pretrained_tokens", default="",
-                    help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
                      help="Lemmatizer embeddings dim")
 flags.DEFINE_integer(name="num_epochs", default=400,
                      help="Epochs num")
-flags.DEFINE_integer(name="word_batch_size", default=2500,
-                     help="Minimum words in batch")
+flags.DEFINE_integer(name="batch_size", default=256,
+                     help="Batch size")
+flags.DEFINE_integer(name="batches_per_epoch", default=16,
+                     help="Number of batches per epoch")
 flags.DEFINE_string(name="pretrained_transformer_name", default="",
                     help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
                          "available models) for transformers based embeddings.")
@@ -90,8 +90,6 @@ flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
 flags.DEFINE_boolean(name="conllu_format", default=True,
                      help="Prediction based on conllu format (instead of raw text).")
-flags.DEFINE_integer(name="batch_size", default=1,
-                     help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
@@ -305,6 +303,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "data_loader": {
             "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -323,6 +323,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "validation_data_loader": {
             "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -351,13 +353,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
                     }
                 }
             }
-        },
-        "vocabulary": {
-            "parameters": {
-                "pretrained_files": FLAGS.pretrained_tokens
-            }
-        },
-        "word_batch_size": int(FLAGS.word_batch_size),
+        }
     }
     return to_override
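Net effect of the commit: the word-count based batching flag (word_batch_size) and the prediction-time batch_size flag are replaced by one pair of training flags, batch_size and batches_per_epoch, which are forwarded unchanged into both the training and validation data loader configs. Below is a minimal runnable sketch of that flow using the absl flags from the diff; the standalone script scaffolding (main, app.run) and the trimmed-down override dict are illustrative only, not the project's full _get_ext_vars.

# Sketch only: shows how the new CLI flags reach the data loader config.
# The real _get_ext_vars builds a much larger override dict (see diff above).
from typing import Dict

from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(name="batch_size", default=256, help="Batch size")
flags.DEFINE_integer(name="batches_per_epoch", default=16,
                     help="Number of batches per epoch")


def _get_ext_vars() -> Dict:
    # Both the training and validation loaders receive the same
    # CLI-controlled values, as in the commit.
    return {
        "data_loader": {
            "batch_size": FLAGS.batch_size,
            "batches_per_epoch": FLAGS.batches_per_epoch,
        },
        "validation_data_loader": {
            "batch_size": FLAGS.batch_size,
            "batches_per_epoch": FLAGS.batches_per_epoch,
        },
    }


def main(argv):
    del argv  # unused by this sketch
    print(_get_ext_vars())


if __name__ == "__main__":
    app.run(main)

With a hypothetical entry point sketch.py, the flags would be set as, e.g., python sketch.py --batch_size=128 --batches_per_epoch=32; with no flags given, the defaults from the diff (256 and 16) apply.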