diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 4bffa5f8a8c8e36077fb5d6ce996c1a0c7ab6fc7..1628337b4173bedbf161f6a75a7675281acc391c 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -208,61 +208,3 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         self.model.load_state_dict(best_model_state)
 
         return metrics, epoch
-
-    # @classmethod
-    # def from_partial_objects(
-    #         cls,
-    #         model: model.Model,
-    #         serialization_dir: str,
-    #         data_loader: data.DataLoader,
-    #         validation_data_loader: data.DataLoader = None,
-    #         local_rank: int = 0,
-    #         patience: int = None,
-    #         validation_metric: Union[str, List[str]] = "-loss",
-    #         num_epochs: int = 20,
-    #         cuda_device: Optional[Union[int, torch.device]] = None,
-    #         grad_norm: float = None,
-    #         grad_clipping: float = None,
-    #         distributed: bool = False,
-    #         world_size: int = 1,
-    #         num_gradient_accumulation_steps: int = 1,
-    #         use_amp: bool = False,
-    #         no_grad: List[str] = None,
-    #         optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
-    #         learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
-    #         momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
-    #         moving_average: common.Lazy[moving_average.MovingAverage] = None,
-    #         checkpointer: common.Lazy[checkpointer.Checkpointer] = common.Lazy(checkpointer.Checkpointer),
-    #         callbacks: List[common.Lazy[training.TrainerCallback]] = None,
-    #         enable_default_callbacks: bool = True,
-    # ) -> "training.Trainer":
-    #     if tensorboard_writer is None:
-    #         tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
-    #     return super().from_partial_objects(
-    #         model=model,
-    #         serialization_dir=serialization_dir,
-    #         data_loader=data_loader,
-    #         validation_data_loader=validation_data_loader,
-    #         local_rank=local_rank,
-    #         patience=patience,
-    #         validation_metric=validation_metric,
-    #         num_epochs=num_epochs,
-    #         cuda_device=cuda_device,
-    #         grad_norm=grad_norm,
-    #         grad_clipping=grad_clipping,
-    #         distributed=distributed,
-    #         world_size=world_size,
-    #         num_gradient_accumulation_steps=num_gradient_accumulation_steps,
-    #         use_amp=use_amp,
-    #         no_grad=no_grad,
-    #         optimizer=optimizer,
-    #         learning_rate_scheduler=learning_rate_scheduler,
-    #         momentum_scheduler=momentum_scheduler,
-    #         tensorboard_writer=tensorboard_writer,
-    #         moving_average=moving_average,
-    #         checkpointer=checkpointer,
-    #         batch_callbacks=batch_callbacks,
-    #         epoch_callbacks=epoch_callbacks,
-    #         end_callbacks=end_callbacks,
-    #         trainer_callbacks=trainer_callbacks,
-    #     )