diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index c0c469674f4a74a6d46953e65d9715e9eabf0a2f..eb64270921e4d7041cd87385859aba0d833d7e1e 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 },
                 token: if use_transformer then {
                     type: "transformers_word_embeddings",
+                    freeze_transformer: false,
                     model_name: pretrained_transformer_name,
                     projection_dim: projected_embedding_dim,
                     tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
@@ -401,9 +402,14 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         grad_clipping: 5.0,
         num_epochs: num_epochs,
         optimizer: {
-            type: "adam",
+            type: "adamw",
             lr: learning_rate,
-            betas: [0.9, 0.9],
+            weight_decay: 0.0,
+            parameter_groups: [
+                [
+                    ['_embedder'], { lr: 5e-5, weight_decay: 0.01, finetune: true, },
+                ],
+            ],
         },
         patience: 1, # it will be overwritten by callback
         epoch_callbacks: [
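The new `parameter_groups` entry asks the trainer to split the model's parameters into two optimizer groups: anything whose name matches `_embedder` is fine-tuned with a small learning rate of 5e-5 and weight decay 0.01, while every other parameter keeps the top-level `lr` and `weight_decay: 0.0`. A minimal PyTorch sketch of that grouping, assuming names containing `_embedder` identify the transformer embedder; the model and the 2e-3 base learning rate are placeholders, and the extra `finetune` flag is not modelled here:

```python
import re
import torch

def build_adamw(model: torch.nn.Module, base_lr: float = 2e-3) -> torch.optim.AdamW:
    """Sketch of the grouping requested by parameter_groups: parameters whose
    names match '_embedder' get lr=5e-5 / weight_decay=0.01, everything else
    keeps the base learning rate with no weight decay."""
    embedder_params, other_params = [], []
    for name, param in model.named_parameters():
        (embedder_params if re.search("_embedder", name) else other_params).append(param)
    return torch.optim.AdamW([
        {"params": embedder_params, "lr": 5e-5, "weight_decay": 0.01},
        {"params": other_params, "lr": base_lr, "weight_decay": 0.0},
    ])
```

Splitting the parameters this way is also what makes the scheduler change below necessary: the LR scheduler now has to cover more than one parameter group.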
diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py
index 8752d739029f16289db94ee493d00bbd92158ce7..8b21d155d7d865a0dd8c00610f7185a241482348 100644
--- a/combo/training/scheduler.py
+++ b/combo/training/scheduler.py
@@ -7,7 +7,8 @@ from overrides import overrides
 class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper):
 
     def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3):
-        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda]))
+        super().__init__(lr_scheduler.LambdaLR(optimizer,
+                                               lr_lambda=[self._lr_lambda] * len(optimizer.param_groups)))
         self.threshold = threshold
         self.decreases = decreases
         self.patience = patience
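The scheduler fix follows directly from the optimizer change: `torch.optim.lr_scheduler.LambdaLR` accepts either a single callable or a list with exactly one lambda per parameter group, so a one-element list fails as soon as `parameter_groups` creates a second group. A small self-contained check of that behaviour, with a plain decay function standing in for `Scheduler._lr_lambda`:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW([
    {"params": [model.weight], "lr": 5e-5, "weight_decay": 0.01},  # '_embedder'-like group
    {"params": [model.bias], "lr": 2e-3, "weight_decay": 0.0},     # remaining parameters
])

def decay(epoch: int) -> float:
    return 0.95 ** epoch  # stand-in for Scheduler._lr_lambda

# LambdaLR(optimizer, lr_lambda=[decay]) would raise ValueError here: a list-valued
# lr_lambda must have exactly one entry per param group.
scheduler = LambdaLR(optimizer, lr_lambda=[decay] * len(optimizer.param_groups))
```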
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
index c6310eae7621c9745248e368bd0782cd053613b8..4c1f54a1a748ecf2eab2bcefc052b659fb49085c 100644
--- a/scripts/train_iwpt21.py
+++ b/scripts/train_iwpt21.py
@@ -114,8 +114,8 @@ def run(_):
         --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
-        --word_batch_size 2500
-        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --word_batch_size 1000
+        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
         --notensorboard
         """