diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index c0c469674f4a74a6d46953e65d9715e9eabf0a2f..eb64270921e4d7041cd87385859aba0d833d7e1e 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         },
         token: if use_transformer then {
             type: "transformers_word_embeddings",
+            freeze_transformer: false,
             model_name: pretrained_transformer_name,
             projection_dim: projected_embedding_dim,
             tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
@@ -401,9 +402,14 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         grad_clipping: 5.0,
         num_epochs: num_epochs,
         optimizer: {
-            type: "adam",
+            type: "adamw",
             lr: learning_rate,
-            betas: [0.9, 0.9],
+            weight_decay: 0.0,
+            parameter_groups: [
+                [
+                    ['_embedder'], { lr: 5e-5, weight_decay: 0.01, finetune: true, },
+                ],
+            ],
         },
         patience: 1, # it will be overwriten by callback
         epoch_callbacks: [
diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py
index 8752d739029f16289db94ee493d00bbd92158ce7..8b21d155d7d865a0dd8c00610f7185a241482348 100644
--- a/combo/training/scheduler.py
+++ b/combo/training/scheduler.py
@@ -7,7 +7,8 @@ from overrides import overrides
 class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper):
 
     def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3):
-        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda]))
+        super().__init__(lr_scheduler.LambdaLR(optimizer,
+                                               lr_lambda=[self._lr_lambda] * len(optimizer.param_groups)))
         self.threshold = threshold
         self.decreases = decreases
         self.patience = patience
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index c6310eae7621c9745248e368bd0782cd053613b8..4c1f54a1a748ecf2eab2bcefc052b659fb49085c 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -114,8 +114,8 @@ def run(_):
         --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
-        --word_batch_size 2500
-        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --word_batch_size 1000
+        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
         --notensorboard
         """
 
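
Why the scheduler.py change accompanies the config change: the new `parameter_groups` entry makes AllenNLP split the optimizer's parameters into two groups (parameters matching the `_embedder` regex are fine-tuned at lr 5e-5 with weight decay 0.01; everything else keeps the base lr), and `torch.optim.lr_scheduler.LambdaLR` requires exactly one lambda per parameter group when `lr_lambda` is passed as a list. With the old one-element list, constructing the scheduler against a two-group optimizer raises a `ValueError`. A minimal standalone sketch of the failure and the fix, outside COMBO (the model, group values, and `decay` function below are made up for illustration):

```python
import torch
from torch.optim import AdamW, lr_scheduler

model = torch.nn.Linear(4, 2)

# Two parameter groups, mirroring the effect of the `parameter_groups`
# entry added to the config (the values here are arbitrary).
optimizer = AdamW([
    {"params": [model.weight], "lr": 5e-5, "weight_decay": 0.01},
    {"params": [model.bias], "lr": 2e-3},
])

def decay(epoch: int) -> float:
    # Stand-in for Scheduler._lr_lambda.
    return 0.5 ** epoch

# Old behaviour: a one-element list fails against a two-group optimizer
# with "ValueError: Expected 2 lr_lambdas, but got 1".
try:
    lr_scheduler.LambdaLR(optimizer, lr_lambda=[decay])
except ValueError as err:
    print(err)

# Patched behaviour: replicate the lambda across all parameter groups.
scheduler = lr_scheduler.LambdaLR(
    optimizer, lr_lambda=[decay] * len(optimizer.param_groups))
scheduler.step()
```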