diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 6ad25590e3f29bcde42266b8ee9cc720787b4388..49f1ab9d6fa114532927897e93148a793038fdf4 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -113,12 +113,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
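
The embeddings.py change above replaces a hand-rolled freeze (eval() plus requires_grad = False in the constructor, backed by the train()/eval() overrides removed in the second hunk) with the embedder's own train_parameters flag, passed as train_parameters=not freeze_transformer. A minimal plain-PyTorch sketch of what the deleted constructor code was doing by hand:

    import torch

    def freeze(module: torch.nn.Module) -> None:
        # Roughly the manual freeze the old constructor performed.
        module.eval()                      # put dropout etc. into inference mode
        for param in module.parameters():
            param.requires_grad = False    # exclude the weights from gradient updates
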
diff --git a/combo/models/parser.py b/combo/models/parser.py
index dfb53ab8ded369b01eae4851dd1d7a9936c05bbe..511edffc2f8d17edbc3fd0702e6425a4ec645e4e 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
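
On the parser.py change: relation_prediction[:, 1:] is a view that shares storage with relation_prediction, so any in-place root-masking applied to it further down (outside this hunk) would write back into the original tensor and can trip autograd's in-place checks. .clone() yields an independent copy while keeping the autograd history. A small sketch of the difference, with illustrative shapes only:

    import torch

    x = torch.zeros(2, 3)
    view = x[:, 1:]                   # slicing returns a view of x's storage
    view[0, 0] = 1.0
    assert x[0, 1].item() == 1.0      # the write leaked back into x

    copy = x[:, 1:].clone()           # clone() copies the data; gradients still flow
    copy[0, 0] = 5.0
    assert x[0, 1].item() == 1.0      # x is unaffected this time
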
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index aeb9f097369f515b3860d961f638870ec17f6786..26bd75f7fbe6917f144b820bbbb1c7e14c3c8e9d 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             patience: int = None,
             validation_metric: str = "-loss",
             num_epochs: int = 20,
-            cuda_device: int = -1,
+            cuda_device: Optional[Union[int, torch.device]] = -1,
             grad_norm: float = None,
             grad_clipping: float = None,
             distributed: bool = None,
             world_size: int = 1,
             num_gradient_accumulation_steps: int = 1,
-            opt_level: Optional[str] = None,
             use_amp: bool = False,
-            optimizer: common.Lazy[optimizers.Optimizer] = None,
+            no_grad: List[str] = None,
+            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
             learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
             momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
             tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
             moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
+            end_callbacks: List[training.EpochCallback] = None,
+            trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
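
The trainer.py hunks mostly re-align from_partial_objects with the newer AllenNLP GradientDescentTrainer signature: opt_level (the old apex-style AMP option, if I recall correctly) is dropped in favor of use_amp alone, no_grad, end_callbacks and trainer_callbacks are accepted and forwarded, and optimizer / checkpointer get Lazy defaults instead of None. A hedged sketch of how such a Lazy default gets resolved later, assuming the AllenNLP 1.x Lazy API (the directory value is illustrative):

    from allennlp.common import Lazy
    from allennlp.training import Checkpointer

    lazy_checkpointer = Lazy(Checkpointer)      # remember how to build it, not an instance
    # ...later, once the serialization directory is known:
    checkpointer = lazy_checkpointer.construct(
        serialization_dir="/tmp/combo-run",     # illustrative path
    )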