diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index a4725606e3cbc1917c46966ee3bf38833de969bc..f3f979b20ce0817b3c4645175f10b8bfa7abe619 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -403,17 +403,15 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             betas: [0.9, 0.9],
         },
         patience: 1, # it will  be overwriten by callback
-        epoch_callbacks: [
-            { type: "transfer_patience" },
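+        # AllenNLP 2.x: "epoch_callbacks" and "tensorboard_writer" are replaced by a single "callbacks" list.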
+        callbacks: [
+            { type: "transfer_patience" },
+            { type: "track_epoch_callback" },
+            if use_tensorboard then
+            { type: "tensorboard", should_log_parameter_statistics: false},
         ],
         learning_rate_scheduler: {
             type: "combo_scheduler",
         },
-        tensorboard_writer: if use_tensorboard then {
-            should_log_learning_rate: false,
-            should_log_parameter_statistics: false,
-            summary_interval: 100,
-        },
         validation_metric: "+EM",
     }),
     random_seed: 8787,
diff --git a/combo/config.template.jsonnet b/combo/config.template.jsonnet
index 4e44f42bbac89d4bf852d31fee29d5d2ef4d4671..bfa39e0629282db8a0cf767a08e843ce45e1b14a 100644
--- a/combo/config.template.jsonnet
+++ b/combo/config.template.jsonnet
@@ -165,6 +165,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     },
     # Data loader configuration
     data_loader: {
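+        # AllenNLP 2.x data loader (MultiProcessDataLoader, registered as "multiprocess").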
+        type: "multiprocess",
         batch_sampler: {
             type: "token_count",
             word_batch_size: word_batch_size,
@@ -366,6 +367,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         type: "gradient_descent_validate_n",
         cuda_device: cuda_device,
         grad_clipping: 5.0,
+        enable_default_callbacks: false,
         num_epochs: num_epochs,
         optimizer: {
             type: "adam",
@@ -373,17 +375,15 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             betas: [0.9, 0.9],
         },
         patience: 1, # it will  be overwriten by callback
-        epoch_callbacks: [
+        callbacks: [
             { type: "transfer_patience" },
+            { type: "track_epoch_callback" },
+            if use_tensorboard then
+            { type: "tensorboard", should_log_parameter_statistics: false},
         ],
         learning_rate_scheduler: {
             type: "combo_scheduler",
         },
-        tensorboard_writer: if use_tensorboard then {
-            should_log_learning_rate: false,
-            should_log_parameter_statistics: false,
-            summary_interval: 100,
-        },
         validation_metric: "+EM",
     }),
     random_seed: 8787,
diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py
index 5db74d93d4fccc5de3fa72e033f152c86fb981bf..dcb83ee4725fe177e21b39e8a803783086fce5a5 100644
--- a/combo/data/samplers/samplers.py
+++ b/combo/data/samplers/samplers.py
@@ -1,23 +1,20 @@
-from typing import List
+from typing import List, Sequence, Iterable
 
 import numpy as np
 
-from allennlp import data as allen_data
+from allennlp import data as allen_data
 
 
 @allen_data.BatchSampler.register("token_count")
 class TokenCountBatchSampler(allen_data.BatchSampler):
 
-    def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True):
+    def __init__(self, word_batch_size: int = 2500, shuffle_dataset: bool = True):
         self._index = 0
-        self.shuffle_dataset = shuffle_dataset
-        self.batch_dataset = self._batchify(dataset, word_batch_size)
-        if shuffle_dataset:
-            self._shuffle()
-
-    @staticmethod
-    def _batchify(dataset, word_batch_size) -> List[List[int]]:
-        dataset = list(dataset)
+        self._word_batch_size = word_batch_size
+        self._shuffle = shuffle_dataset
+
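+    # Group instances of similar length so that every batch holds roughly
+    # word_batch_size tokens; the order of the resulting batches is optionally shuffled.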
+    def get_batch_indices(self, instances: Sequence[allen_data.Instance]) -> Iterable[List[int]]:
+        dataset = list(instances)
         batches = []
         batch = []
         words_count = 0
@@ -26,29 +23,29 @@ class TokenCountBatchSampler(allen_data.BatchSampler):
         for idx in argsorted_lengths:
             words_count += lengths[idx]
             batch.append(idx)
-            if words_count > word_batch_size:
+            if words_count > self._word_batch_size:
                 batches.append(batch)
                 words_count = 0
                 batch = []
-        return batches
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self._index >= len(self.batch_dataset):
-            if self.shuffle_dataset:
-                self._index = 0
-                self._shuffle()
-            raise StopIteration()
 
-        batch = self.batch_dataset[self._index]
-        self._index += 1
-        return batch
+        if self._shuffle:
+            indices = np.random.permutation(range(len(batches)))
+            batches = np.array(batches)[indices].tolist()
 
-    def _shuffle(self):
-        indices = np.random.permutation(range(len(self.batch_dataset)))
-        self.batch_dataset = np.array(self.batch_dataset)[indices].tolist()
+        return batches
 
-    def __len__(self):
-        return len(self.batch_dataset)
+    def get_num_batches(self, instances: Sequence[allen_data.Instance]) -> int:
+        # Mirrors the grouping rule of get_batch_indices (a batch is complete once the
+        # accumulated token count exceeds word_batch_size) without building the batches.
+        lengths = [len(instance.fields["sentence"].tokens) for instance in instances]
+        num_batches = 0
+        words_count = 0
+        for length in sorted(lengths):
+            words_count += length
+            if words_count > self._word_batch_size:
+                num_batches += 1
+                words_count = 0
+        return num_batches
diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py
index bae4403f6a5cfa0bc3482f2670a221478cd2cae3..3109b27f79de6f3ce2286a5ce315786c11a6ff33 100644
--- a/combo/training/checkpointer.py
+++ b/combo/training/checkpointer.py
@@ -2,6 +2,7 @@ from typing import Union, Any, Dict, Tuple
 
 from allennlp import training
 from allennlp.training import trainer as allen_trainer
+from overrides import overrides
 
 
 @training.Checkpointer.register("finishing_only_checkpointer")
@@ -14,17 +15,15 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
     def save_checkpoint(
             self,
             epoch: Union[int, str],
-            trainer: "allen_trainer.Trainer",
-            is_best_so_far: bool = False,
-            save_model_only: bool = False,
+            trainer: "allen_trainer.Trainer"
     ) -> None:
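+        # "Finishing only": write a checkpoint when the LR schedule is about to end or on the last epoch.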
         if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
-            super().save_checkpoint(epoch, trainer, is_best_so_far)
+            super().save_checkpoint(trainer)
 
     def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         return {}, {}
 
     def maybe_save_checkpoint(
-            self, trainer: "allen_trainer.Trainer", epoch: int, batches_this_epoch: int
+            self, trainer: "allen_trainer.Trainer", num_epochs_completed: int, num_batches_in_epoch_completed: int,
     ) -> None:
         pass
diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py
deleted file mode 100644
index d7ab2c5d1bbc1255d2505b833a2497606a99862e..0000000000000000000000000000000000000000
--- a/combo/training/tensorboard_writer.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import logging
-from typing import Dict, Optional, List
-
-import torch
-from allennlp import models, common
-from allennlp.data import dataloader
-from allennlp.training import optimizers
-
-logger = logging.getLogger(__name__)
-
-
-class NullTensorboardWriter(common.FromParams):
-
-    def log_batch(
-            self,
-            model: models.Model,
-            optimizer: optimizers.Optimizer,
-            batch_grad_norm: Optional[float],
-            metrics: Dict[str, float],
-            batch_group: List[List[dataloader.TensorDict]],
-            param_updates: Optional[Dict[str, torch.Tensor]],
-    ) -> None:
-        pass
-
-    def reset_epoch(self) -> None:
-        pass
-
-    def should_log_this_batch(self) -> bool:
-        return False
-
-    def should_log_histograms_this_batch(self) -> bool:
-        return False
-
-    def add_train_scalar(self, name: str, value: float, timestep: int = None) -> None:
-        pass
-
-    def add_train_histogram(self, name: str, values: torch.Tensor) -> None:
-        pass
-
-    def add_validation_scalar(self, name: str, value: float, timestep: int = None) -> None:
-        pass
-
-    def log_parameter_and_gradient_statistics(self, model: models.Model, batch_grad_norm: float) -> None:
-        pass
-
-    def log_learning_rates(self, model: models.Model, optimizer: torch.optim.Optimizer):
-        pass
-
-    def log_histograms(self, model: models.Model) -> None:
-        pass
-
-    def log_gradient_updates(self, model: models.Model, param_updates: Dict[str, torch.Tensor]) -> None:
-        pass
-
-    def log_metrics(
-        self,
-        train_metrics: dict,
-        val_metrics: dict = None,
-        epoch: int = None,
-        log_to_console: bool = False,
-    ) -> None:
-        metric_names = set(train_metrics.keys())
-        if val_metrics is not None:
-            metric_names.update(val_metrics.keys())
-        val_metrics = val_metrics or {}
-
-        if log_to_console:
-            dual_message_template = "%s |  %8.3f  |  %8.3f"
-            no_val_message_template = "%s |  %8.3f  |  %8s"
-            no_train_message_template = "%s |  %8s  |  %8.3f"
-            header_template = "%s |  %-10s"
-            name_length = max(len(x) for x in metric_names)
-            logger.info(header_template, "Training".rjust(name_length + 13), "Validation")
-
-            for name in metric_names:
-                train_metric = train_metrics.get(name)
-                val_metric = val_metrics.get(name)
-                if val_metric is not None and train_metric is not None:
-                    logger.info(
-                        dual_message_template, name.ljust(name_length), train_metric, val_metric
-                    )
-                elif val_metric is not None:
-                    logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
-                elif train_metric is not None:
-                    logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")
-
-    def enable_activation_logging(self, model: models.Model) -> None:
-        pass
-
-    def log_activation_histogram(self, outputs, log_prefix: str) -> None:
-        pass
-
-    def close(self) -> None:
-        pass
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 3bee8fcaaf1c29189583a551a1100ac4f0215a65..e2876c02b67c4e59bfe6af5a106d12a4561a56bf 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -3,35 +3,39 @@ import logging
 import os
 import time
 import traceback
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Tuple
 
 import torch
 import torch.distributed as dist
-import torch.optim as optim
 import torch.optim.lr_scheduler
 import torch.utils.data as data
-from allennlp import training, common
+from allennlp import training
 from allennlp.common import checks
 from allennlp.common import util as common_util
+from allennlp.common.file_utils import hardlink_or_copy
 from allennlp.models import model
-from allennlp.training import checkpointer, optimizers
+from allennlp.training import checkpointer
 from allennlp.training import learning_rate_schedulers
 from allennlp.training import momentum_schedulers
 from allennlp.training import moving_average
-from allennlp.training import tensorboard_writer as allen_tensorboard_writer
 from allennlp.training import util as training_util
+from allennlp.nn.parallel import DdpWrappedModel
 from overrides import overrides
 
-from combo.training import tensorboard_writer as combo_tensorboard_writer
-
 logger = logging.getLogger(__name__)
 
 
-@training.EpochCallback.register("transfer_patience")
-class TransferPatienceEpochCallback(training.EpochCallback):
+@training.TrainerCallback.register("transfer_patience")
+class TransferPatienceEpochCallback(training.TrainerCallback):
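+    """Keep early stopping in sync with the LR scheduler: after every epoch, copy the
+    scheduler's patience into the metric tracker and reset its no-improvement counter."""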
 
-    def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
-                 is_master: bool) -> None:
+    def on_epoch(
+        self,
+        trainer: "GradientDescentTrainer",
+        metrics: Dict[str, Any],
+        epoch: int,
+        is_primary: bool = True,
+        **kwargs,
+    ) -> None:
         if trainer._learning_rate_scheduler and trainer._learning_rate_scheduler.patience is not None:
             trainer._metric_tracker._patience = trainer._learning_rate_scheduler.patience
             trainer._metric_tracker._epochs_with_no_improvement = 0
@@ -42,35 +46,34 @@ class TransferPatienceEpochCallback(training.EpochCallback):
 @training.Trainer.register("gradient_descent_validate_n", constructor="from_partial_objects")
 class GradientDescentTrainer(training.GradientDescentTrainer):
 
-    def __init__(self, model: model.Model, optimizer: optim.Optimizer, data_loader: data.DataLoader,
-                 patience: Optional[int] = None, validation_metric: str = "-loss",
+    def __init__(self, model: model.Model, optimizer: torch.optim.Optimizer, data_loader: data.DataLoader,
+                 patience: Optional[int] = None, validation_metric: Union[str, List[str]] = "-loss",
                  validation_data_loader: data.DataLoader = None, num_epochs: int = 20,
                  serialization_dir: Optional[str] = None, checkpointer: checkpointer.Checkpointer = None,
-                 cuda_device: Optional[Union[int, torch.device]] = -1,
-                 grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None,
+                 cuda_device: Optional[Union[int, torch.device]] = None, grad_norm: Optional[float] = None,
+                 grad_clipping: Optional[float] = None,
                  learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None,
                  momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None,
-                 tensorboard_writer: allen_tensorboard_writer.TensorboardWriter = None,
                  moving_average: Optional[moving_average.MovingAverage] = None,
-                 batch_callbacks: List[training.BatchCallback] = None,
-                 epoch_callbacks: List[training.EpochCallback] = None,
-                 end_callbacks: List[training.EpochCallback] = None,
-                 trainer_callbacks: List[training.TrainerCallback] = None,
-                 distributed: bool = False, local_rank: int = 0,
-                 world_size: int = 1, num_gradient_accumulation_steps: int = 1,
-                 use_amp: bool = False) -> None:
+                 callbacks: List[training.TrainerCallback] = None, distributed: bool = False, local_rank: int = 0,
+                 world_size: int = 1, num_gradient_accumulation_steps: int = 1, use_amp: bool = False,
+                 enable_default_callbacks: bool = True, run_confidence_checks: bool = True,
+                 grad_scaling: bool = True, ddp_wrapped_model: Optional[DdpWrappedModel] = None) -> None:
+        # TODO extract param to constructor (+ constructor method?)
         super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs,
                          serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
-                         learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
-                         batch_callbacks, epoch_callbacks, end_callbacks, trainer_callbacks, distributed, local_rank, world_size,
-                         num_gradient_accumulation_steps, use_amp)
-        # TODO extract param to constructor (+ constructor method?)
+                         learning_rate_scheduler, momentum_scheduler, moving_average, callbacks, distributed,
+                         local_rank, world_size, num_gradient_accumulation_steps, use_amp, enable_default_callbacks,
+                         run_confidence_checks, grad_scaling, ddp_wrapped_model)
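+        # Validation is run only every validate_every_n epochs (see _try_train).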
         self.validate_every_n = 5
 
+
     @overrides
-    def _try_train(self) -> Dict[str, Any]:
+    def _try_train(self) -> Tuple[Dict[str, Any], int]:
         try:
-            epoch_counter = self._restore_checkpoint()
+            self._maybe_restore_checkpoint()
         except RuntimeError:
             traceback.print_exc()
             raise checks.ConfigurationError(
@@ -84,7 +87,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float = None
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -93,15 +95,12 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         for key, value in self._metric_tracker.best_epoch_metrics.items():
             metrics["best_validation_" + key] = value
 
-        for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1, is_master=self._master)
-
-        for epoch in range(epoch_counter, self._num_epochs):
+        for epoch in range(self._num_epochs):
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
 
-            if self._master and self._checkpointer is not None:
-                self._checkpointer.save_checkpoint(epoch, self, save_model_only=True)
+            if self._primary and self._checkpointer is not None:
+                self._checkpointer.save_checkpoint(epoch, self)
 
             # Wait for the master to finish saving the model checkpoint
             if self._distributed:
@@ -114,9 +113,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 elif key.startswith("worker_") and key.endswith("_memory_MB"):
                     metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
 
+            this_epoch_val_metric: float = 0.0
             if self._validation_data_loader is not None:
                 val_metrics = {}
-                this_epoch_val_metric = None
                 if epoch % self.validate_every_n == 0:
                     with torch.no_grad():
                         # We have a validation set, so compute all the metrics on it.
@@ -134,13 +133,13 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                             batch_loss=None,
                             batch_reg_loss=None,
                             num_batches=num_batches,
-                            reset=True,
-                            world_size=self._world_size,
-                            cuda_device=self.cuda_device,
+                            reset=True,
                         )
 
                         # Check validation metric for early stopping
-                        this_epoch_val_metric = val_metrics[self._validation_metric]
+                        this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics)
                         # self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
@@ -148,11 +147,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                     logger.info("Ran out of patience.  Stopping training.")
                     break
 
-            if self._master:
-                self._tensorboard.log_metrics(
-                    train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
-                )  # +1 because tensorboard doesn't like 0
-
             # Create overall metrics dict
             training_elapsed_time = time.time() - training_start_time
             metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
@@ -174,9 +168,10 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                 self._metric_tracker.best_epoch_metrics = val_metrics
 
-            if self._serialization_dir and self._master:
+            if self._serialization_dir and self._primary:
                 common_util.dump_metrics(
-                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
+                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"),
+                    metrics,
                 )
 
             # The Scheduler API is agnostic to whether your schedule requires a validation metric -
@@ -186,100 +181,64 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if self._momentum_scheduler:
                 self._momentum_scheduler.step(this_epoch_val_metric)
 
-            if self._master and self._checkpointer is not None:
+            if self._primary and self._checkpointer is not None:
                 self._checkpointer.save_checkpoint(
-                    epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
-                )
+                    epoch, self)
 
             # Wait for the master to finish saving the checkpoint
             if self._distributed:
                 dist.barrier()
 
-            for callback in self._epoch_callbacks:
-                callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
+            if (
+                self._should_validate_this_epoch
+                and self._serialization_dir
+                and self._metric_tracker.is_best_so_far()
+            ):
+                if self._ddp_wrapped_model is not None and self._ddp_wrapped_model.is_sharded:
+                    # Each worker saves its own shard for now (we combine the shards later).
+                    self._best_model_filename = os.path.join(
+                        self._serialization_dir, f"best_w{self._rank}.th"
+                    )
+                else:
+                    self._best_model_filename = os.path.join(self._serialization_dir, "best.th")
+
+            # Wait for the primary process to finish saving the best model
+            if self._distributed:
+                dist.barrier()
+
+            for callback in self._callbacks:
+                callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary)
 
             epoch_elapsed_time = time.time() - epoch_start_time
             logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))
 
             if epoch < self._num_epochs - 1:
-                training_elapsed_time = time.time() - training_start_time
-                estimated_time_remaining = training_elapsed_time * (
-                        (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
+                time_per_epoch = training_elapsed_time / (
+                        (epoch + 1) - self._start_after_epochs_completed
                 )
+                estimated_time_remaining = (
+                    time_per_epoch * self._num_epochs
+                ) - training_elapsed_time
                 formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                 logger.info("Estimated training time remaining: %s", formatted_time)
 
             epochs_trained += 1
+        else:
+            epoch = self._num_epochs - 1
 
-        for callback in self._end_callbacks:
-            callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
+        if self._metric_tracker.is_best_so_far():
+            logger.info(
+                "Best validation performance so far. Copying weights to '%s/best.th'.",
+                self._serialization_dir,
+            )
+            model_state, _ = self.get_checkpoint_state()
+            torch.save(model_state, os.path.join(self._serialization_dir, "best.th"))
 
         # Load the best model state before returning
-        best_model_state = (
-            None if self._checkpointer is None else self._checkpointer.best_model_state()
-        )
-        if best_model_state:
-            self.model.load_state_dict(best_model_state)
-
-        return metrics
-
-    @classmethod
-    def from_partial_objects(
-            cls,
-            model: model.Model,
-            serialization_dir: str,
-            data_loader: data.DataLoader,
-            validation_data_loader: data.DataLoader = None,
-            local_rank: int = 0,
-            patience: int = None,
-            validation_metric: str = "-loss",
-            num_epochs: int = 20,
-            cuda_device: Optional[Union[int, torch.device]] = -1,
-            grad_norm: float = None,
-            grad_clipping: float = None,
-            distributed: bool = None,
-            world_size: int = 1,
-            num_gradient_accumulation_steps: int = 1,
-            use_amp: bool = False,
-            no_grad: List[str] = None,
-            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
-            learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
-            momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
-            tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
-            moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
-            batch_callbacks: List[training.BatchCallback] = None,
-            epoch_callbacks: List[training.EpochCallback] = None,
-            end_callbacks: List[training.EpochCallback] = None,
-            trainer_callbacks: List[training.TrainerCallback] = None,
-    ) -> "training.Trainer":
-        if tensorboard_writer is None:
-            tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
-        return super().from_partial_objects(
-            model=model,
-            serialization_dir=serialization_dir,
-            data_loader=data_loader,
-            validation_data_loader=validation_data_loader,
-            local_rank=local_rank,
-            patience=patience,
-            validation_metric=validation_metric,
-            num_epochs=num_epochs,
-            cuda_device=cuda_device,
-            grad_norm=grad_norm,
-            grad_clipping=grad_clipping,
-            distributed=distributed,
-            world_size=world_size,
-            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
-            use_amp=use_amp,
-            no_grad=no_grad,
-            optimizer=optimizer,
-            learning_rate_scheduler=learning_rate_scheduler,
-            momentum_scheduler=momentum_scheduler,
-            tensorboard_writer=tensorboard_writer,
-            moving_average=moving_average,
-            checkpointer=checkpointer,
-            batch_callbacks=batch_callbacks,
-            epoch_callbacks=epoch_callbacks,
-            end_callbacks=end_callbacks,
-            trainer_callbacks=trainer_callbacks,
-        )
+        if self._best_model_filename is None or self._metric_tracker.is_best_so_far():
+            self._finalize_model()
+        else:
+            # The model we're loading here has already been finalized.
+            self._load_model_state(self._best_model_filename)
+
+        return metrics, epoch
diff --git a/setup.py b/setup.py
index f4a82e59b63cbed2c1de8bebefe3e99d51db9f55..ef28081d5d2ec526a2882c05a4f239e3c7355954 100644
--- a/setup.py
+++ b/setup.py
@@ -3,26 +3,18 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.3.0',
+    'allennlp==2.9.0',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
     'jsonnet==0.15.0',
-    'filelock==3.0;python_version>="3.9"',
-    'numpy==1.19.4;python_version<"3.9"',
-    'numpy==1.22.0;python_version>="3.9"',
     'overrides==3.1.0',
     'requests==2.23.0',
     'sentencepiece==0.1.83;python_version<"3.8"',
     'sentencepiece==0.1.85;python_version>="3.8" and python_version<"3.9"',
     'sentencepiece==0.1.94;python_version>="3.9"',
-    'scipy<1.6.0;python_version<"3.7"',  # SciPy 1.6.0 works for 3.7+
-    'scipy==1.6.0;python_version>="3.7"',
     'spacy==2.3.2',
     'scikit-learn<=0.23.2;python_version<"3.9"',
     'scikit-learn==0.23.2;python_version>="3.9"',
-    'torch==1.7.1',
-    'tqdm==4.43.0',
-    'transformers==4.0.1',
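+    # torch, tqdm and transformers are no longer pinned here; allennlp 2.9 declares its own compatible versions.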
     'urllib3==1.25.11',
 ]
 
@@ -37,9 +29,7 @@ setup(
     url='https://gitlab.clarin-pl.eu/syntactic-tools/combo',
     keywords="nlp natural-language-processing dependency-parsing",
     setup_requires=['pytest-runner',
-    		    'pytest-pylint',
-    		    'numpy==1.22.0;python_version>="3.9"',
-    		    'scipy==1.6.0;python_version>="3.7"'],
+    		        'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
     python_requires='>=3.6',
     package_data={'combo': ['config.graph.template.jsonnet', 'config.template.jsonnet']},