diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index 09e74639fa22765cc887bacf4453a6733d717aff..b88c436a54f04796204cc5c086b0657d7e2e313c 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -392,31 +392,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             ],
         },
     }),
-    trainer: std.prune({
-        checkpointer: {
-            type: "finishing_only_checkpointer",
-        },
-        type: "gradient_descent_validate_n",
-        cuda_device: cuda_device,
-        grad_clipping: 5.0,
-        num_epochs: num_epochs,
-        optimizer: {
-            type: "adam",
-            lr: learning_rate,
-            betas: [0.9, 0.9],
-        },
-        patience: 1, # it will be overwriten by callback
-        callbacks: [
-            { type: "transfer_patience" }
-            { type: "track_epoch_callback" },
-            if use_tensorboard then
-                { type: "tensorboard", should_log_parameter_statistics: false},
-        ],
-        learning_rate_scheduler: {
-            type: "combo_scheduler",
-        },
-        validation_metric: "+EM",
-    }),
+    trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
     random_seed: 8787,
     pytorch_seed: 8787,
     numpy_seed: 8787,
diff --git a/combo/config.shared.libsonnet b/combo/config.shared.libsonnet
new file mode 100644
index 0000000000000000000000000000000000000000..23d3f9a8717dc35bf3a33bd112d13de6daab13dc
--- /dev/null
+++ b/combo/config.shared.libsonnet
@@ -0,0 +1,29 @@
+{
+    local trainer(cuda_device, num_epochs, learning_rate, use_tensorboard) =
+        std.prune({
+            checkpointer: {
+                type: "finishing_only_checkpointer",
+            },
+            type: "gradient_descent_validate_n",
+            cuda_device: cuda_device,
+            grad_clipping: 5.0,
+            num_epochs: num_epochs,
+            optimizer: {
+                type: "adam",
+                lr: learning_rate,
+                betas: [0.9, 0.9],
+            },
+            patience: 1, # it will be overwriten by callback
+            callbacks: [
+                { type: "transfer_patience" }
+                { type: "track_epoch_callback" },
+                if use_tensorboard then
+                    { type: "tensorboard", should_log_parameter_statistics: false},
+            ],
+            learning_rate_scheduler: {
+                type: "combo_scheduler",
+            },
+            validation_metric: "+EM",
+        }),
+    Trainer: trainer
+}
\ No newline at end of file
diff --git a/combo/config.template.jsonnet b/combo/config.template.jsonnet
index 9049eb91157e9d5d1fdb257dd5e8a84d15169c62..2019d7b05bcdcd7ef8180663206e3367b04119de 100644
--- a/combo/config.template.jsonnet
+++ b/combo/config.template.jsonnet
@@ -1,3 +1,4 @@
+local shared_config = import "config.shared.libsonnet";
 ########################################################################################
 #                                 BASIC configuration                                  #
 ########################################################################################
@@ -359,31 +360,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             ],
         },
     }),
-    trainer: std.prune({
-        checkpointer: {
-            type: "finishing_only_checkpointer",
-        },
-        type: "gradient_descent_validate_n",
-        cuda_device: cuda_device,
-        grad_clipping: 5.0,
-        num_epochs: num_epochs,
-        optimizer: {
-            type: "adam",
-            lr: learning_rate,
-            betas: [0.9, 0.9],
-        },
-        patience: 1, # it will be overwriten by callback
-        callbacks: [
-            { type: "transfer_patience" }
-            { type: "track_epoch_callback" },
-            if use_tensorboard then
-                { type: "tensorboard", should_log_parameter_statistics: false},
-        ],
-        learning_rate_scheduler: {
-            type: "combo_scheduler",
-        },
-        validation_metric: "+EM",
-    }),
+    trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
     random_seed: 8787,
     pytorch_seed: 8787,
     numpy_seed: 8787,