Commit 512bd5e2 authored by Łukasz Pszenny

learning rate test fix

parent 1a0f387d
Pipeline #4257 passed in 4 minutes and 55 seconds
@@ -33,20 +33,26 @@ class TransferPatienceEpochCallback(training.EpochCallback):
     def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
                  is_master: bool) -> None:
-        #LR range test variables
         path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
-        end_lr = 0.001
-        start_lr = 0.0000001
-        num_lr_in_test = 25
-        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test)
-        # # # # # # #
         param_group = trainer.optimizer.param_groups
         if epoch == 0:
             with open(path_to_result_file, "a") as file:
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
         else:
+            # LR range test variables
+            path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
+            end_lr = 0.001
+            start_lr = 0.0000001
+            num_lr_in_test = 25
+            epochs_to_change_lr = trainer.epochs_to_change_lr
+            lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))
+            test_number = int(epoch / epochs_to_change_lr)
+            encoder_exponent = test_number % num_lr_in_test
+            rest_exponent = int(test_number / num_lr_in_test)
+            # # # # # # #
             with open(path_to_result_file, "a") as file:
                 file.write(str(metrics["training_loss"]) + ";" +
                            str(metrics["validation_loss"]) + ";" +
@@ -84,13 +90,11 @@ class TransferPatienceEpochCallback(training.EpochCallback):
                            )
                 # END CONDITIONS
-                if param_group[1]["lr"] >= end_lr and param_group[0]["lr"] >= end_lr:
+                if param_group[1]["lr"] > end_lr and param_group[0]["lr"] > end_lr:
                     raise Exception('End of LR test')
-                param_group[0]["lr"] = param_group[0]["lr"] * lr_update_factor
-                if param_group[0]["lr"] >= end_lr:
-                    param_group[0]["lr"] = start_lr
-                    param_group[1]["lr"] = param_group[1]["lr"] * lr_update_factor
+                param_group[0]["lr"] = start_lr * (lr_update_factor ** encoder_exponent)
+                param_group[1]["lr"] = start_lr * (lr_update_factor ** rest_exponent)
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
@@ -124,6 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                          num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
         self.validate_every_n = 1
+        self.epochs_to_change_lr = 15

     @overrides
     def _try_train(self) -> Dict[str, Any]:
@@ -153,12 +158,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         for epoch in range(epoch_counter, self._num_epochs):
-            epochs_to_change_lr = 15
-            # every epochs_to_change_lr epoch loads weights after 1 epoch
-            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
-                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
@@ -195,7 +194,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 for key, value in val_metrics.items():
                     metrics["validation_" + key] = value

-            if self._metric_tracker.is_best_so_far():
+            if self._metric_tracker.is_best_so_far() or (epoch % self.epochs_to_change_lr == 1 and epoch > 1):
                 # Update all the best_ metrics.
                 # (Otherwise they just stay the same as they were.)
                 metrics["best_epoch"] = epoch
@@ -205,7 +204,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 self._metric_tracker.best_epoch_metrics = val_metrics

             for callback in self._epoch_callbacks:
-                if ((epoch-1) % epochs_to_change_lr == 0 and epoch > 1 ) or epoch == 0:
+                if epoch % self.epochs_to_change_lr == 0:
                     callback(self, metrics=metrics, epoch=epoch, is_master=self._master)

             # The Scheduler API is agnostic to whether your schedule requires a validation metric -
@@ -231,8 +230,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if epoch == 0:
                 torch.save(self.model.state_dict(), os.path.join(self._serialization_dir, "initial.th"))
-            if (epoch-1) % epochs_to_change_lr == 0 and epoch > 1:
-                self._metric_tracker.best_epoch_metrics = val_metrics
+            # every epochs_to_change_lr epoch loads weights after 1 epoch
+            if epoch % self.epochs_to_change_lr == 0 and epoch > 1:
+                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))

             for callback in self._end_callbacks:
                 callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
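For orientation, a rough runnable sketch of how the trainer loop sequences these pieces after this commit; model, serialization_dir and the epoch count below are placeholder stand-ins, and training/validation are elided:

    import os
    import torch
    import torch.nn as nn

    # Placeholder stand-ins for the real AllenNLP model and serialization directory.
    model = nn.Linear(4, 2)
    serialization_dir = "/tmp/lr_range_demo"
    os.makedirs(serialization_dir, exist_ok=True)
    epochs_to_change_lr = 15

    for epoch in range(46):
        # ... one training epoch (plus periodic validation) would run here ...

        if epoch % epochs_to_change_lr == 0:
            # the TransferPatienceEpochCallback fires only on these boundary epochs:
            # it logs the losses of the test that just finished and sets the next LR pair
            pass

        if epoch == 0:
            # snapshot the starting weights once
            torch.save(model.state_dict(), os.path.join(serialization_dir, "initial.th"))
        if epoch % epochs_to_change_lr == 0 and epoch > 1:
            # restore that snapshot so every learning-rate pair starts from identical weights
            model.load_state_dict(torch.load(os.path.join(serialization_dir, "initial.th")))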