Commit 000157b0 authored by Mateusz Klimaszewski

Remove old trainer constructor, use default instead.

parent f3f26738
Pipeline #2881 passed with stage in 3 minutes and 6 seconds
@@ -208,61 +208,3 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
self.model.load_state_dict(best_model_state)
return metrics, epoch
# @classmethod
# def from_partial_objects(
# cls,
# model: model.Model,
# serialization_dir: str,
# data_loader: data.DataLoader,
# validation_data_loader: data.DataLoader = None,
# local_rank: int = 0,
# patience: int = None,
# validation_metric: Union[str, List[str]] = "-loss",
# num_epochs: int = 20,
# cuda_device: Optional[Union[int, torch.device]] = None,
# grad_norm: float = None,
# grad_clipping: float = None,
# distributed: bool = False,
# world_size: int = 1,
# num_gradient_accumulation_steps: int = 1,
# use_amp: bool = False,
# no_grad: List[str] = None,
# optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
# learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
# momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
# moving_average: common.Lazy[moving_average.MovingAverage] = None,
# checkpointer: common.Lazy[checkpointer.Checkpointer] = common.Lazy(checkpointer.Checkpointer),
# callbacks: List[common.Lazy[training.TrainerCallback]] = None,
# enable_default_callbacks: bool = True,
# ) -> "training.Trainer":
# if tensorboard_writer is None:
# tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
# return super().from_partial_objects(
# model=model,
# serialization_dir=serialization_dir,
# data_loader=data_loader,
# validation_data_loader=validation_data_loader,
# local_rank=local_rank,
# patience=patience,
# validation_metric=validation_metric,
# num_epochs=num_epochs,
# cuda_device=cuda_device,
# grad_norm=grad_norm,
# grad_clipping=grad_clipping,
# distributed=distributed,
# world_size=world_size,
# num_gradient_accumulation_steps=num_gradient_accumulation_steps,
# use_amp=use_amp,
# no_grad=no_grad,
# optimizer=optimizer,
# learning_rate_scheduler=learning_rate_scheduler,
# momentum_scheduler=momentum_scheduler,
# tensorboard_writer=tensorboard_writer,
# moving_average=moving_average,
# checkpointer=checkpointer,
# batch_callbacks=batch_callbacks,
# epoch_callbacks=epoch_callbacks,
# end_callbacks=end_callbacks,
# trainer_callbacks=trainer_callbacks,
# )
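After this commit, trainer construction goes through the parent's default from_partial_objects rather than the commented-out override above. A minimal sketch of what "use default instead" amounts to, assuming the AllenNLP-style training.GradientDescentTrainer base class referenced in the diff; the import path and class body here are illustrative, not the actual COMBO code:

from allennlp import training


class GradientDescentTrainer(training.GradientDescentTrainer):
    """COMBO trainer after this commit: no from_partial_objects override.

    Configuration-driven construction now resolves the Lazy optimizer,
    checkpointer, schedulers, and callbacks through the inherited
    training.GradientDescentTrainer.from_partial_objects; only the
    training-loop overrides (e.g. the epoch loop returning metrics and
    restoring best_model_state) remain in this subclass.
    """

Dropping the override also drops the stale parameters (tensorboard_writer, batch_callbacks, epoch_callbacks, end_callbacks, trainer_callbacks) that the commented-out body still passed to super() even though its signature no longer declared them.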