From 512bd5e2b59063634b690f17deba15199e014f81 Mon Sep 17 00:00:00 2001 From: pszenny <pszenny@e-science.pl> Date: Fri, 14 Jan 2022 18:49:36 +0100 Subject: [PATCH] learning rate test fix --- combo/training/trainer.py | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/combo/training/trainer.py b/combo/training/trainer.py index ae4c0d0..1b6cdd7 100644 --- a/combo/training/trainer.py +++ b/combo/training/trainer.py @@ -33,20 +33,26 @@ class TransferPatienceEpochCallback(training.EpochCallback): def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int, is_master: bool) -> None: - #LR range test variables path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt" - end_lr = 0.001 - start_lr = 0.0000001 - num_lr_in_test = 25 - lr_update_factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test) - # # # # # # # - param_group = trainer.optimizer.param_groups if epoch == 0: with open(path_to_result_file, "a") as file: file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";") else: + + # LR range test variables + path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt" + end_lr = 0.001 + start_lr = 0.0000001 + num_lr_in_test = 25 + epochs_to_change_lr = trainer.epochs_to_change_lr + lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1)) + test_number = int(epoch / epochs_to_change_lr) + encoder_exponent = test_number % num_lr_in_test + rest_exponent = int(test_number / num_lr_in_test) + # # # # # # # + with open(path_to_result_file, "a") as file: file.write(str(metrics["training_loss"]) + ";" + str(metrics["validation_loss"]) + ";" + @@ -84,13 +90,11 @@ class TransferPatienceEpochCallback(training.EpochCallback): ) # END CONDITIONS - if param_group[1]["lr"] >= end_lr and param_group[0]["lr"] >= end_lr: + if param_group[1]["lr"] > end_lr and param_group[0]["lr"] > end_lr: raise Exception('End of LR test') - param_group[0]["lr"] = param_group[0]["lr"] * lr_update_factor - if param_group[0]["lr"] >= end_lr: - param_group[0]["lr"] = start_lr - param_group[1]["lr"] = param_group[1]["lr"] * lr_update_factor + param_group[0]["lr"] = start_lr * (lr_update_factor ** encoder_exponent) + param_group[1]["lr"] = start_lr * (lr_update_factor ** rest_exponent) file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";") @@ -124,6 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): num_gradient_accumulation_steps, use_amp) # TODO extract param to constructor (+ constructor method?) self.validate_every_n = 1 + self.epochs_to_change_lr = 15 @overrides def _try_train(self) -> Dict[str, Any]: @@ -153,12 +158,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer): for epoch in range(epoch_counter, self._num_epochs): - epochs_to_change_lr = 15 - - # every epochs_to_change_lr epoch loads weights after 1 epoch - if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1: - self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th"))) - epoch_start_time = time.time() train_metrics = self._train_epoch(epoch) @@ -195,7 +194,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): for key, value in val_metrics.items(): metrics["validation_" + key] = value - if self._metric_tracker.is_best_so_far(): + if self._metric_tracker.is_best_so_far() or (epoch % self.epochs_to_change_lr == 1 and epoch > 1): # Update all the best_ metrics. # (Otherwise they just stay the same as they were.) metrics["best_epoch"] = epoch @@ -205,7 +204,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): self._metric_tracker.best_epoch_metrics = val_metrics for callback in self._epoch_callbacks: - if ((epoch-1) % epochs_to_change_lr == 0 and epoch > 1 ) or epoch == 0: + if epoch % self.epochs_to_change_lr == 0: callback(self, metrics=metrics, epoch=epoch, is_master=self._master) # The Scheduler API is agnostic to whether your schedule requires a validation metric - @@ -231,8 +230,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer): if epoch == 0: torch.save(self.model.state_dict(), os.path.join(self._serialization_dir, "initial.th")) - if (epoch-1) % epochs_to_change_lr == 0 and epoch > 1: - self._metric_tracker.best_epoch_metrics = val_metrics + # every epochs_to_change_lr epoch loads weights after 1 epoch + if epoch % self.epochs_to_change_lr == 0 and epoch > 1: + self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th"))) for callback in self._end_callbacks: callback(self, metrics=metrics, epoch=epoch, is_master=self._master) -- GitLab