Commit c0835180 authored by Mateusz Klimaszewski

Fix training loops and metrics.

parent b8c83784
Merge requests: !15 (Merge develop to master), !14 (Install and docs)
@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
         epoch: Union[int, str],
         trainer: "allen_trainer.Trainer",
         is_best_so_far: bool = False,
+        save_model_only: bool = False,
     ) -> None:
         if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
             super().save_checkpoint(epoch, trainer, is_best_so_far)
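Note on the hunk above: the new save_model_only flag tracks an updated Checkpointer.save_checkpoint signature, and the guard only writes a checkpoint once the learning-rate schedule is nearly exhausted or the final epoch is reached. A minimal sketch of that guard with stand-in classes rather than the real allennlp types (decreases is assumed here to count remaining LR reductions; only the boolean condition comes from the diff):

# Sketch of the save condition with hypothetical stand-in classes.
from dataclasses import dataclass

@dataclass
class FakeScheduler:
    decreases: int  # assumption: remaining learning-rate reductions

@dataclass
class FakeTrainer:
    _learning_rate_scheduler: FakeScheduler
    _num_epochs: int

def should_save(trainer: FakeTrainer, epoch: int) -> bool:
    # Checkpoint only near the end of the LR schedule, or on the last epoch.
    return (trainer._learning_rate_scheduler.decreases <= 1
            or epoch == trainer._num_epochs - 1)

trainer = FakeTrainer(FakeScheduler(decreases=3), _num_epochs=10)
print(should_save(trainer, epoch=4))  # False: schedule still has room, skip saving
print(should_save(trainer, epoch=9))  # True: final epoch always checkpoints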
@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float
+        this_epoch_val_metric: float = None
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
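The added "= None" is more than cosmetic: inside a function, a bare "name: float" annotation registers the name as a local variable without binding a value, so reading it before the first validation pass raises UnboundLocalError. A self-contained reproduction of the failure mode (this is a toy function, not the trainer itself):

def train_loop(validate: bool) -> float:
    this_epoch_val_metric: float  # annotation only; no value is bound yet
    if validate:
        this_epoch_val_metric = 0.5
    return this_epoch_val_metric  # crashes when validation was skipped

try:
    train_loop(validate=False)
except UnboundLocalError as err:  # a subclass of NameError
    print("without '= None':", err)

With the default in place, a training run that never reaches validation returns None instead of crashing.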
@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             # Check validation metric for early stopping
             this_epoch_val_metric = val_metrics[self._validation_metric]
-            self._metric_tracker.add_metric(this_epoch_val_metric)
+            # self._metric_tracker.add_metric(this_epoch_val_metric)
             train_metrics["patience"] = self._metric_tracker._patience
 
             if self._metric_tracker.should_stop_early():
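Commenting out add_metric means the metric tracker never receives validation scores, so should_stop_early() can never fire: patience-based early stopping is effectively disabled even though the "patience" field is still reported. A simplified stand-in tracker illustrating the effect (the patience logic below is modeled loosely on allennlp's MetricTracker, not copied from it):

class PatienceTracker:
    def __init__(self, patience: int):
        self._patience = patience
        self._best = None
        self._epochs_without_improvement = 0

    def add_metric(self, value: float) -> None:
        # Reset the counter on improvement, otherwise burn one unit of patience.
        if self._best is None or value > self._best:
            self._best = value
            self._epochs_without_improvement = 0
        else:
            self._epochs_without_improvement += 1

    def should_stop_early(self) -> bool:
        return self._epochs_without_improvement >= self._patience

tracker = PatienceTracker(patience=2)
for metric in [0.60, 0.59, 0.58, 0.57]:
    # tracker.add_metric(metric)  # commented out, as in the diff
    pass                          # history never grows
print(tracker.should_stop_early())  # False: training runs for all epochs

With the call disabled, training always runs for the full configured number of epochs regardless of validation trends.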
@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric):
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
                            self.attachment_scores.correct_indices *
-                           enhanced_indices)
+                           enhanced_indices) * mask.flatten()
 
-        total, correct_indices = self.detach_tensors(total, correct_indices)
-        self.em_score = (correct_indices.float().sum() / total).item()
+        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_indices / total).item()
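The mask.flatten() factor stops padded positions from counting as exact-match hits, and the summation now happens before the tensors are detached. A toy recomputation of em_score under that fix, using torch (tensor names mirror the diff; the values and the way total is derived here are illustrative assumptions, not the project's code):

import torch

# Per-token correctness flags from each sub-metric (made-up values).
upos = torch.tensor([1, 1, 0, 1])
feats = torch.tensor([1, 1, 1, 1])
lemma = torch.tensor([1, 0, 1, 1])
attachment = torch.tensor([1, 1, 1, 1])
enhanced = torch.tensor([1, 1, 1, 1])
mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

# A token is an exact match only if every sub-metric got it right
# AND it is a real (unpadded) token.
correct_indices = (upos * feats * lemma * attachment * enhanced) * mask.flatten()
total = mask.sum()  # assumption: count only real tokens
em_score = (correct_indices.float().sum() / total).item()
print(em_score)  # 1 fully correct token out of 3 real ones ~= 0.333

Without the mask term, the padded fourth position could be scored as correct by every sub-metric and inflate the exact-match numerator.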