From 512bd5e2b59063634b690f17deba15199e014f81 Mon Sep 17 00:00:00 2001
From: pszenny <pszenny@e-science.pl>
Date: Fri, 14 Jan 2022 18:49:36 +0100
Subject: [PATCH] Fix learning rate range test schedule

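Replace the cumulative LR multiplication in the range test with a
schedule computed directly from the epoch index. The encoder LR now
sweeps a geometric grid of num_lr_in_test points between start_lr and
end_lr, and the LR of the remaining parameter group advances one grid
step per full encoder sweep. The update factor uses num_lr_in_test - 1
in the exponent so the last grid point lands exactly on end_lr. Weight
reloading and the best-metric reset are realigned with the epochs on
which the callback updates the learning rates.

A minimal standalone sketch of the schedule (names taken from the
diff; the constants and the helper lr_pair are illustrative, not part
of the trainer API):

    # Constants as used in the patch.
    start_lr = 1e-7
    end_lr = 1e-3
    num_lr_in_test = 25        # grid points per sweep
    epochs_to_change_lr = 15   # epochs trained per grid point

    # Exponent num_lr_in_test - 1 maps the last grid point onto end_lr.
    lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))

    def lr_pair(epoch: int) -> tuple:
        # One "test" lasts epochs_to_change_lr epochs.
        test_number = epoch // epochs_to_change_lr
        encoder_exponent = test_number % num_lr_in_test    # fast axis
        rest_exponent = test_number // num_lr_in_test      # slow axis
        return (start_lr * lr_update_factor ** encoder_exponent,
                start_lr * lr_update_factor ** rest_exponent)

    # lr_pair(0) == (1e-07, 1e-07); after num_lr_in_test tests the
    # encoder LR wraps back to start_lr and the rest LR moves up once.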
---
 combo/training/trainer.py | 48 ++++++++++++++++++++++++++++----------------------
 1 file changed, 26 insertions(+), 22 deletions(-)

diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index ae4c0d0..1b6cdd7 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -33,20 +33,26 @@ class TransferPatienceEpochCallback(training.EpochCallback):
     def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
                  is_master: bool) -> None:
 
-        #LR range test variables
         path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
-        end_lr = 0.001
-        start_lr = 0.0000001
-        num_lr_in_test = 25
-        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test)
-        # # # # # # #
-
         param_group = trainer.optimizer.param_groups
 
         if epoch == 0:
             with open(path_to_result_file, "a") as file:
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
         else:
+
+            # LR range test: geometric grid of num_lr_in_test points from start_lr to end_lr;
+            # the encoder LR sweeps the grid while the rest LR advances one step per full sweep
+            end_lr = 1e-3
+            start_lr = 1e-7
+            num_lr_in_test = 25
+            epochs_to_change_lr = trainer.epochs_to_change_lr
+            lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))
+            test_number = epoch // epochs_to_change_lr
+            encoder_exponent = test_number % num_lr_in_test
+            rest_exponent = test_number // num_lr_in_test
+            # # # # # # #
+
             with open(path_to_result_file, "a") as file:
                 file.write(str(metrics["training_loss"]) + ";" +
                            str(metrics["validation_loss"]) + ";" +
@@ -84,13 +90,12 @@ class TransferPatienceEpochCallback(training.EpochCallback):
                            )
 
                 # END CONDITIONS
-                if param_group[1]["lr"] >= end_lr and param_group[0]["lr"] >= end_lr:
+                if param_group[1]["lr"] > end_lr and param_group[0]["lr"] > end_lr:
                     raise Exception('End of LR test')
 
-                param_group[0]["lr"] = param_group[0]["lr"] * lr_update_factor
-                if param_group[0]["lr"] >= end_lr:
-                    param_group[0]["lr"] = start_lr
-                    param_group[1]["lr"] = param_group[1]["lr"] * lr_update_factor
+                param_group[0]["lr"] = start_lr * (lr_update_factor ** encoder_exponent)
+                param_group[1]["lr"] = start_lr * (lr_update_factor ** rest_exponent)
 
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
 
@@ -124,6 +129,8 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                          num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
         self.validate_every_n = 1
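+        # epochs to train with each LR pair before the callback advances the range test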
+        self.epochs_to_change_lr = 15
 
     @overrides
     def _try_train(self) -> Dict[str, Any]:
@@ -153,12 +160,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
         for epoch in range(epoch_counter, self._num_epochs):
 
-            epochs_to_change_lr = 15
-
-            # every epochs_to_change_lr epoch loads weights after 1 epoch
-            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
-                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
-
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
 
@@ -195,7 +196,8 @@
             for key, value in val_metrics.items():
                 metrics["validation_" + key] = value
 
-            if self._metric_tracker.is_best_so_far():
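+            # also refresh the best_* metrics on the first epoch after each weight reload (new LR test baseline)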
+            if self._metric_tracker.is_best_so_far() or (epoch % self.epochs_to_change_lr == 1 and epoch > 1):
                 # Update all the best_ metrics.
                 # (Otherwise they just stay the same as they were.)
                 metrics["best_epoch"] = epoch
@@ -205,7 +207,8 @@
                 self._metric_tracker.best_epoch_metrics = val_metrics
 
             for callback in self._epoch_callbacks:
-                if ((epoch-1) % epochs_to_change_lr == 0 and epoch > 1 ) or epoch == 0:
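+                # fire the LR-test callback every epochs_to_change_lr epochs, including epoch 0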
+                if epoch % self.epochs_to_change_lr == 0:
                     callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
             # The Scheduler API is agnostic to whether your schedule requires a validation metric -
@@ -231,8 +234,9 @@
             if epoch == 0:
                 torch.save(self.model.state_dict(), os.path.join(self._serialization_dir, "initial.th"))
 
-            if (epoch-1) % epochs_to_change_lr == 0 and epoch > 1:
-                self._metric_tracker.best_epoch_metrics = val_metrics
+            # every epochs_to_change_lr epochs, restore the snapshot saved after epoch 0 so each LR test starts from the same weights
+            if epoch % self.epochs_to_change_lr == 0 and epoch > 1:
+                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
 
         for callback in self._end_callbacks:
             callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
-- 
GitLab