
Herbert configuration and AllenNLP 1.2.0 update.

Merged: Mateusz Klimaszewski requested to merge herberta-config into master
Viewing commit 9d339859
3 files changed: +14 −10
@@ -54,12 +54,12 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                  batch_callbacks: List[training.BatchCallback] = None,
                  epoch_callbacks: List[training.EpochCallback] = None,
                  distributed: bool = False, local_rank: int = 0,
                  world_size: int = 1, num_gradient_accumulation_steps: int = 1,
-                 opt_level: Optional[str] = None) -> None:
+                 use_amp: bool = False) -> None:
         super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs,
                          serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
                          learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
                          batch_callbacks, epoch_callbacks, distributed, local_rank, world_size,
-                         num_gradient_accumulation_steps, opt_level)
+                         num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
         self.validate_every_n = 5
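The dropped opt_level argument configured NVIDIA Apex mixed precision; AllenNLP 1.2.0 uses native PyTorch AMP instead, toggled by the boolean use_amp. Roughly, use_amp=True makes the trainer run the forward pass under autocast and scale the loss with a GradScaler. A minimal, self-contained sketch of that pattern in plain PyTorch (toy model, optimizer and batches are placeholders, not this project's code; requires a CUDA device):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    # Toy stand-ins; in the trainer these are the wrapped model, optimizer and data loader.
    model = torch.nn.Linear(16, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    batches = [torch.randn(8, 16).cuda() for _ in range(4)]

    scaler = GradScaler()                   # loss scaling guards against fp16 underflow
    for x in batches:
        optimizer.zero_grad()
        with autocast():                    # forward pass runs in mixed precision
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()       # backward on the scaled loss
        scaler.step(optimizer)              # unscales gradients, then optimizer.step()
        scaler.update()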
@@ -125,7 +125,8 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 self.model,
                 val_loss,
                 val_reg_loss,
-                num_batches,
+                num_batches=num_batches,
+                batch_loss=None,
                 reset=True,
                 world_size=self._world_size,
                 cuda_device=self.cuda_device,
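The extra keyword arguments track an upstream signature change: the metrics helper called here (its name is outside the hunk) accepts a batch_loss argument in AllenNLP 1.2.0, so num_batches is now passed by keyword to keep the bindings explicit. A generic illustration of the failure mode this avoids, using a hypothetical helper rather than the library function:

    # Hypothetical old signature: get_metrics(model, loss, reg_loss, num_batches, reset=False)
    # Hypothetical new signature, with a parameter inserted before num_batches:
    def get_metrics(model, loss, reg_loss, batch_loss=None, num_batches=0, reset=False):
        return {"loss": loss / max(num_batches, 1), "batch_loss": batch_loss}

    # A positional call written against the old order would bind 4 to batch_loss;
    # keyword arguments keep the call correct under either signature.
    print(get_metrics(None, 10.0, 0.0, num_batches=4, batch_loss=None, reset=True))
    # -> {'loss': 2.5, 'batch_loss': None}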
@@ -231,7 +232,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                              world_size: int = 1,
                              num_gradient_accumulation_steps: int = 1,
-                             opt_level: Optional[str] = None,
                              no_grad: List[str] = None,
+                             use_amp: bool = False,
                              optimizer: common.Lazy[optimizers.Optimizer] = None,
                              learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
                              momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
@@ -258,8 +259,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
            distributed=distributed,
            world_size=world_size,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
-           opt_level=opt_level,
-           no_grad=no_grad,
+           use_amp=use_amp,
            optimizer=optimizer,
            learning_rate_scheduler=learning_rate_scheduler,
            momentum_scheduler=momentum_scheduler,
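On the configuration side, experiment files that previously set the Apex level under the trainer block would switch to the new flag. A sketch of the relevant fragment as a Python dict (keys other than use_amp/opt_level, and the values shown, are assumptions about a typical AllenNLP setup, not this repository's actual config):

    # Before (Apex-based mixed precision):
    # trainer = {"type": "gradient_descent", "num_epochs": 30, "opt_level": "O1"}

    # After the AllenNLP 1.2.0 update, native AMP is a boolean switch:
    trainer = {
        "type": "gradient_descent",   # assumed registered name; subclasses may register their own
        "num_epochs": 30,             # illustrative value
        "use_amp": True,              # torch.cuda.amp instead of Apex opt_level
    }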