diff --git a/combo/main.py b/combo/main.py
index 055ebeb13a5d6e11c3e4540b3e68e5067ff47902..572570bcff9376bcd817dbaf4ff1aa97f6fadba3 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -48,14 +48,14 @@ flags.DEFINE_string(name="training_data_path", default="", help="Training data p
 flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_string(name="validation_data_path", default="", help="Validation data path(s)")
 flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
-flags.DEFINE_string(name="pretrained_tokens", default="",
-                    help="Pretrained tokens embeddings path")
 flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
                      help="Lemmatizer embeddings dim")
 flags.DEFINE_integer(name="num_epochs", default=400,
                      help="Epochs num")
-flags.DEFINE_integer(name="word_batch_size", default=2500,
-                     help="Minimum words in batch")
+flags.DEFINE_integer(name="batch_size", default=256,
+                     help="Batch size")
+flags.DEFINE_integer(name="batches_per_epoch", default=16,
+                     help="Number of batches per epoch")
 flags.DEFINE_string(name="pretrained_transformer_name", default="",
                     help="Pretrained transformer model name (see transformers from HuggingFace library for list of "
                          "available models) for transformers based embeddings.")
@@ -90,8 +90,6 @@ flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
 flags.DEFINE_boolean(name="conllu_format", default=True,
                      help="Prediction based on conllu format (instead of raw text).")
-flags.DEFINE_integer(name="batch_size", default=1,
-                     help="Prediction batch size.")
 flags.DEFINE_boolean(name="silent", default=True,
                      help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
@@ -305,6 +303,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "data_loader": {
             "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -323,6 +323,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         },
         "validation_data_loader": {
             "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
+            "batch_size": FLAGS.batch_size,
+            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
                 "reader": {
                     "parameters": {
@@ -351,13 +353,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
                     }
                 }
             }
-        },
-        "vocabulary": {
-            "parameters": {
-                "pretrained_files": FLAGS.pretrained_tokens
-            }
-        },
-        "word_batch_size": int(FLAGS.word_batch_size),
+        }
     }
 
     return to_override