Skip to content
Snippets Groups Projects
Commit 2b72300b authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Start to extract shared config for DP and EDP.

parent 04414d22
Branches
No related merge requests found
Pipeline #2874 passed with stage
in 4 minutes and 16 seconds
......@@ -392,31 +392,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
],
},
}),
trainer: std.prune({
checkpointer: {
type: "finishing_only_checkpointer",
},
type: "gradient_descent_validate_n",
cuda_device: cuda_device,
grad_clipping: 5.0,
num_epochs: num_epochs,
optimizer: {
type: "adam",
lr: learning_rate,
betas: [0.9, 0.9],
},
patience: 1, # it will be overwriten by callback
callbacks: [
{ type: "transfer_patience" }
{ type: "track_epoch_callback" },
if use_tensorboard then
{ type: "tensorboard", should_log_parameter_statistics: false},
],
learning_rate_scheduler: {
type: "combo_scheduler",
},
validation_metric: "+EM",
}),
trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
random_seed: 8787,
pytorch_seed: 8787,
numpy_seed: 8787,
......
{
  # Shared trainer configuration, extracted so the DP and EDP configs can
  # reuse it via:  shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard)
  #
  # std.prune() strips null entries, so when use_tensorboard is false the
  # `if ... then` expression (which yields null without an `else`) is removed
  # from the callbacks array entirely.
  local trainer(cuda_device, num_epochs, learning_rate, use_tensorboard) =
    std.prune({
      checkpointer: {
        type: "finishing_only_checkpointer",
      },
      type: "gradient_descent_validate_n",
      cuda_device: cuda_device,
      grad_clipping: 5.0,
      num_epochs: num_epochs,
      optimizer: {
        type: "adam",
        lr: learning_rate,
        betas: [0.9, 0.9],
      },
      patience: 1, # it will be overwritten by callback
      callbacks: [
        # FIX: a comma was missing after this element — two adjacent object
        # literals in an array are a Jsonnet syntax error.
        { type: "transfer_patience" },
        { type: "track_epoch_callback" },
        if use_tensorboard then
          { type: "tensorboard", should_log_parameter_statistics: false },
      ],
      learning_rate_scheduler: {
        type: "combo_scheduler",
      },
      validation_metric: "+EM",
    }),
  Trainer: trainer,
}
\ No newline at end of file
local shared_config = import "config.shared.libsonnet";
########################################################################################
# BASIC configuration #
########################################################################################
......@@ -359,31 +360,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
],
},
}),
trainer: std.prune({
checkpointer: {
type: "finishing_only_checkpointer",
},
type: "gradient_descent_validate_n",
cuda_device: cuda_device,
grad_clipping: 5.0,
num_epochs: num_epochs,
optimizer: {
type: "adam",
lr: learning_rate,
betas: [0.9, 0.9],
},
patience: 1, # it will be overwriten by callback
callbacks: [
{ type: "transfer_patience" }
{ type: "track_epoch_callback" },
if use_tensorboard then
{ type: "tensorboard", should_log_parameter_statistics: false},
],
learning_rate_scheduler: {
type: "combo_scheduler",
},
validation_metric: "+EM",
}),
trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
random_seed: 8787,
pytorch_seed: 8787,
numpy_seed: 8787,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment