From cd660089520eed40dd039930c0828a65221a517f Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 19 Apr 2021 14:36:49 +0200 Subject: [PATCH] Enable BERT fine-tuning. --- combo/config.graph.template.jsonnet | 10 ++++++++-- combo/training/scheduler.py | 3 ++- scripts/train_iwpt21.py | 4 ++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index c0c4696..eb64270 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, token: if use_transformer then { type: "transformers_word_embeddings", + freeze_transformer: false, model_name: pretrained_transformer_name, projection_dim: projected_embedding_dim, tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") @@ -401,9 +402,14 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't grad_clipping: 5.0, num_epochs: num_epochs, optimizer: { - type: "adam", + type: "adamw", lr: learning_rate, - betas: [0.9, 0.9], + weight_decay: 0.0, + parameter_groups: [ + [ + ['_embedder'], { lr: 5e-5, weight_decay: 0.01, finetune: true, }, + ], + ], }, patience: 1, # it will be overwriten by callback epoch_callbacks: [ diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py index 8752d73..8b21d15 100644 --- a/combo/training/scheduler.py +++ b/combo/training/scheduler.py @@ -7,7 +7,8 @@ from overrides import overrides class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper): def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3): - super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda])) + super().__init__(lr_scheduler.LambdaLR(optimizer, + lr_lambda=[self._lr_lambda] * len(optimizer.param_groups))) self.threshold = threshold self.decreases = decreases self.patience = patience diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index c6310ea..4c1f54a 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -114,8 +114,8 @@ def run(_): --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]} --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} - --word_batch_size 2500 - --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'} + --word_batch_size 1000 + --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """ -- GitLab