Commit d9dd119c authored by Mateusz Klimaszewski

Fix configuration options and reduce the number of created checkpoints.

parent f3d37c16
@@ -20,30 +20,34 @@ logger = logging.getLogger(__name__)
FLAGS = flags.FLAGS
flags.DEFINE_enum(name='mode', default=None, enum_values=['train', 'predict'],
-                  help="Specify COMBO mode: train or precit")
+                  help='Specify COMBO mode: train or predict')
# Common flags
flags.DEFINE_integer(name='cuda_device', default=-1,
                     help="Cuda device id (default -1 cpu)")
# Training flags
-flags.DEFINE_string(name='training_data_path', default="./tests/fixtures/example.conllu",
-                    help='Training data path')
-flags.DEFINE_string(name='validation_data_path', default='',
-                    help='Validation data path')
+flags.DEFINE_list(name='training_data_path', default="./tests/fixtures/example.conllu",
+                  help='Training data path')
+flags.DEFINE_list(name='validation_data_path', default='',
+                  help='Validation data path')
flags.DEFINE_string(name='pretrained_tokens', default='',
                    help='Pretrained tokens embeddings path')
flags.DEFINE_integer(name='embedding_dim', default=300,
                     help='Embeddings dim')
+flags.DEFINE_integer(name='num_epochs', default=400,
+                     help='Epochs num')
+flags.DEFINE_integer(name='word_batch_size', default=2500,
+                     help='Minimum words in batch')
flags.DEFINE_string(name='pretrained_transformer_name', default='',
                    help='Pretrained transformer model name (see transformers from HuggingFace library for list of'
                         'available models) for transformers based embeddings.')
flags.DEFINE_multi_enum(name='features', default=['token', 'char'],
                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma'],
-                        help='Features used to train model (required \'token\' and \'char\')')
+                        help='Features used to train model (required `token` and `char`)')
flags.DEFINE_multi_enum(name='targets', default=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag'],
                        enum_values=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag', 'semrel', 'sent'],
-                        help='Targets of the model (required \'deprel\' and \'head\')')
+                        help='Targets of the model (required `deprel` and `head`)')
flags.DEFINE_string(name='serialization_dir', default=None,
                    help='Model serialization directory (default - system temp dir).')
@@ -56,7 +60,7 @@ flags.DEFINE_string(name='config_path', default='config.template.jsonnet',
                    help='Config file path.')
# Test after training flags
-flags.DEFINE_string(name='result', default='result.conll',
+flags.DEFINE_string(name='result', default='result.conllu',
                    help='Test result path file')
flags.DEFINE_string(name='test_path', default=None,
                    help='Test path file.')
@@ -124,7 +128,7 @@ def run(_):
    if FLAGS.test_path and FLAGS.result:
        checks.file_exists(FLAGS.test_path)
-        params = common.Params.from_file(FLAGS.config_path)['dataset_reader']
+        params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())['dataset_reader']
        params.pop('type')
        dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params)
        predictor = predict.SemanticMultitaskPredictor(
@@ -167,21 +171,28 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
        return {}
    else:
        return {
-            'training_data_path': FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path,
+            'training_data_path': (
+                ':'.join(FLAGS.training_data_path) if not finetuning else FLAGS.finetuning_training_data_path),
            'validation_data_path': (
-                FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path),
+                ':'.join(FLAGS.validation_data_path) if not finetuning else FLAGS.finetuning_validation_data_path),
            'pretrained_tokens': FLAGS.pretrained_tokens,
            'pretrained_transformer_name': FLAGS.pretrained_transformer_name,
            'features': ' '.join(FLAGS.features),
            'targets': ' '.join(FLAGS.targets),
            'type': 'finetuning' if finetuning else 'default',
            'embedding_dim': str(FLAGS.embedding_dim),
+            'cuda_device': str(FLAGS.cuda_device),
+            'num_epochs': str(FLAGS.num_epochs),
+            'word_batch_size': str(FLAGS.word_batch_size)
        }


def main():
    """Parse flags."""
-    flags.mark_flag_as_required('mode')
+    flags.register_validator(
+        'mode',
+        lambda value: value is not None,
+        message='Flag --mode must be set with either `predict` or `train` value')
    app.run(run)
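
For illustration only, not part of the commit: a minimal sketch of what the switch to list-valued path flags does. absl parses a comma-separated --training_data_path into a Python list, and _get_ext_vars joins it with ':' so a single string can be handed to the jsonnet template as an external variable. The file names below are hypothetical.

    # Hypothetical value after absl parses --training_data_path=a.conllu,b.conllu
    flag_value = ['a.conllu', 'b.conllu']      # what FLAGS.training_data_path would hold
    ext_var = ':'.join(flag_value)             # 'a.conllu:b.conllu', passed as a jsonnet ext var
    assert ext_var.split(':') == flag_value    # the consumer can recover the list by splitting on ':'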
......
"""Training tools."""
from .checkpointer import FinishingTrainingCheckpointer
from .scheduler import Scheduler
from .trainer import GradientDescentTrainer
from typing import Union, Any, Dict, Tuple

from allennlp import training
from allennlp.training import trainer as allen_trainer


@training.Checkpointer.register('finishing_only_checkpointer')
class FinishingTrainingCheckpointer(training.Checkpointer):
    """Checkpointer that disables restoring interrupted training and saves weights only
    when this is the last epoch or the learning rate scheduler is on its last decrease.

    Remove the checkpointer configuration from the config template to get regular,
    on-best-score saving."""

    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: "allen_trainer.Trainer",
        is_best_so_far: bool = False,
    ) -> None:
        # Save only on the last learning rate decrease or on the final epoch.
        if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
            super().save_checkpoint(epoch, trainer, is_best_so_far)

    def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Never restore an interrupted training run.
        return {}, {}

    def maybe_save_checkpoint(
        self, trainer: "allen_trainer.Trainer", epoch: int, batches_this_epoch: int
    ) -> None:
        # Never checkpoint mid-epoch.
        pass
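
For illustration only, not part of the commit: a minimal sketch, assuming AllenNLP's Registrable API, of how the finishing_only_checkpointer name used in the config template resolves back to the class above; the combo.training import path is an assumption.

    from allennlp.training import Checkpointer

    import combo.training  # hypothetical package path; importing runs the @register decorator

    # by_name() looks the string up in the registry populated by the decorator.
    checkpointer_cls = Checkpointer.by_name('finishing_only_checkpointer')
    print(checkpointer_cls.__name__)  # FinishingTrainingCheckpointer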
@@ -15,17 +15,17 @@ local pretrained_transformer_name = if std.length(std.extVar("pretrained_transfo
# Learning rate value, float
local learning_rate = 0.002;
# Number of epochs, int
-local num_epochs = 1;
+local num_epochs = std.parseInt(std.extVar("num_epochs"));
# Cuda device id, -1 for cpu, int
-local cuda_device = -1;
+local cuda_device = std.parseInt(std.extVar("cuda_device"));
# Minimum number of words in batch, int
-local word_batch_size = 1;
+local word_batch_size = std.parseInt(std.extVar("word_batch_size"));
# Features used as input, list of str
# Choice "upostag", "xpostag", "lemma"
# Required "token", "char"
local features = std.split(std.extVar("features"), " ");
# Targets of the model, list of str
-# Choice "feats", "lemma", "upostag", "xpostag", "semrel"
+# Choice "feats", "lemma", "upostag", "xpostag", "semrel", "sent"
# Required "deprel", "head"
local targets = std.split(std.extVar("targets"), " ");
# Path for tensorboard metrics, str
@@ -343,6 +343,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
        },
    }),
    trainer: {
+        checkpointer: {
+            type: "finishing_only_checkpointer",
+        },
        type: "gradient_descent_validate_n",
        cuda_device: cuda_device,
        grad_clipping: 5.0,
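
For illustration only, not part of the commit: jsonnet external variables are strings, which is why _get_ext_vars and the test below pass numeric flags through str(...) and the template converts them back with std.parseInt. A minimal standalone sketch, assuming the jsonnet Python bindings are installed:

    import _jsonnet  # jsonnet Python bindings (the `jsonnet` pip package)

    snippet = 'std.parseInt(std.extVar("num_epochs"))'
    result = _jsonnet.evaluate_snippet('example', snippet, ext_vars={'num_epochs': str(400)})
    print(result)  # 400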
......
@@ -34,13 +34,13 @@ class TrainingEndToEndTest(unittest.TestCase):
            'pretrained_tokens': os.path.join(self.FIXTURES_ROOT, 'example.vec'),
            'pretrained_transformer_name': '',
            'embedding_dim': '300',
+            'cuda_device': '-1',
+            'num_epochs': '1',
+            'word_batch_size': '1',
        }
        params = Params.from_file(os.path.join(self.PROJECT_ROOT, 'config.template.jsonnet'),
                                  ext_vars=ext_vars)
        params['trainer']['tensorboard_writer']['serialization_dir'] = os.path.join(self.TEST_DIR, 'metrics')
-        params['trainer']['num_epochs'] = 1
-        params['data_loader']['batch_sampler']['word_batch_size'] = 1
        # when
        model = train.train_model(params, serialization_dir=self.TEST_DIR)