diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 6ad25590e3f29bcde42266b8ee9cc720787b4388..d8e9d7a28a7fa36b60d108a30a3286026a327e51 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -107,18 +107,16 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
 
     def __init__(self,
                  model_name: str,
-                 projection_dim: int,
+                 projection_dim: int = 0,  # 0 disables the optional projection layer
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
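+        # When freeze_transformer is True, train_parameters=False keeps the transformer weights fixed (no gradient updates).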
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
diff --git a/combo/models/parser.py b/combo/models/parser.py
index dfb53ab8ded369b01eae4851dd1d7a9936c05bbe..511edffc2f8d17edbc3fd0702e6425a4ec645e4e 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
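+            # relation_prediction[:, 1:] is a view; clone it so the root-label masking below does not modify relation_prediction in place.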
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
diff --git a/combo/predict.py b/combo/predict.py
index a5c99fd883b4ee40c5e9c76af44a7c7dbad85bdc..bd9f5d4637bf410bace029604afb915ed05311c2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 32,
+                 batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        for idx, token in enumerate(tree_tokens):
-            for field_name in field_names:
-                if field_name in predictions:
-                    if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
-                        value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + "_labels")
-                        token[field_name] = value
-                    elif field_name in ["head"]:
-                        token[field_name] = int(predictions[field_name][idx])
-                    elif field_name == "deps":
-                        # Handled after every other decoding
-                        continue
-
-                    elif field_name in ["feats"]:
-                        slices = self._model.morphological_feat.slices
-                        features = []
-                        prediction = predictions[field_name][idx]
-                        for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
-                            if cat not in ["__PAD__", "_"]:
-                                value = self.vocab.get_token_from_index(cat_indices[pred_idx],
-                                                                        field_name + "_labels")
-                                # Exclude auxiliary values
-                                if "=None" not in value:
-                                    features.append(value)
-                        if len(features) == 0:
-                            field_value = "_"
-                        else:
-                            lowercase_features = [f.lower() for f in features]
-                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
-                            field_value = "|".join(np.array(features)[arg_indices].tolist())
-
-                        token[field_name] = field_value
-                    elif field_name == "lemma":
-                        prediction = predictions[field_name][idx]
-                        word_chars = []
-                        for char_idx in prediction[1:-1]:
-                            pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
-
-                            if pred_char == "__END__":
-                                break
-                            elif pred_char == "__PAD__":
-                                continue
-                            elif "_" in pred_char:
-                                pred_char = "?"
-
-                            word_chars.append(pred_char)
-                        token[field_name] = "".join(word_chars)
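+        # Decode field by field: look up each field's predictions once, then fill in every token.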
+        for field_name in field_names:
+            if field_name not in predictions:
+                continue
+            if field_name == "deps":
+                # Handled after all other fields are decoded.
+                continue
+            field_predictions = predictions[field_name]
+            for idx, token in enumerate(tree_tokens):
+                if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
+                    value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
+                    token[field_name] = value
+                elif field_name == "head":
+                    token[field_name] = int(field_predictions[idx])
+
+                elif field_name == "feats":
+                    slices = self._model.morphological_feat.slices
+                    features = []
+                    prediction = field_predictions[idx]
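+                    # One predicted index per morphological category; map each back to its label string.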
+                    for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
+                        if cat not in ["__PAD__", "_"]:
+                            value = self.vocab.get_token_from_index(cat_indices[pred_idx],
+                                                                    field_name + "_labels")
+                            # Exclude auxiliary values
+                            if "=None" not in value:
+                                features.append(value)
+                    if len(features) == 0:
+                        field_value = "_"
                     else:
-                        raise NotImplementedError(f"Unknown field name {field_name}!")
+                        lowercase_features = [f.lower() for f in features]
+                        arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                        field_value = "|".join(np.array(features)[arg_indices].tolist())
+
+                    token[field_name] = field_value
+                elif field_name == "lemma":
+                    prediction = field_predictions[idx]
+                    word_chars = []
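+                    # Rebuild the lemma from predicted character indices, stopping at the end marker.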
+                    for char_idx in prediction[1:-1]:
+                        pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
+
+                        if pred_char == "__END__":
+                            break
+                        elif pred_char == "__PAD__":
+                            continue
+                        elif "_" in pred_char:
+                            pred_char = "?"
+
+                        word_chars.append(pred_char)
+                    token[field_name] = "".join(word_chars)
+                else:
+                    raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 32,
+                        batch_size: int = 1024,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index aeb9f097369f515b3860d961f638870ec17f6786..26bd75f7fbe6917f144b820bbbb1c7e14c3c8e9d 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             patience: int = None,
             validation_metric: str = "-loss",
             num_epochs: int = 20,
-            cuda_device: int = -1,
+            cuda_device: Optional[Union[int, torch.device]] = -1,
             grad_norm: float = None,
             grad_clipping: float = None,
             distributed: bool = None,
             world_size: int = 1,
             num_gradient_accumulation_steps: int = 1,
-            opt_level: Optional[str] = None,
             use_amp: bool = False,
-            optimizer: common.Lazy[optimizers.Optimizer] = None,
+            no_grad: List[str] = None,  # regexes of parameter names whose gradients are disabled
+            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
             learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
             momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
             tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
             moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
+            end_callbacks: List[training.EpochCallback] = None,
+            trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 651c14a7d79b7ea3c277b9466f5e050435a7a01b..3352625e6665ca1cd3196506ed5e50183fedfbb0 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -88,6 +88,7 @@ def _dfs(graph, start, end):
 
 
 def restore_collapse_edges(tree_tokens):
+    # https://gist.github.com/hankcs/776e7d95c19e5ff5da8469fe4e9ab050
     empty_tokens = []
     for token in tree_tokens:
         deps = token["deps"].split("|")