From c0835180c3092ba52b19ab8b98a488468915f0a0 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 6 Jan 2021 11:12:38 +0100 Subject: [PATCH] Fix training loops and metrics. --- combo/training/checkpointer.py | 1 + combo/training/trainer.py | 4 ++-- combo/utils/metrics.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py index c148ed6..bae4403 100644 --- a/combo/training/checkpointer.py +++ b/combo/training/checkpointer.py @@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer): epoch: Union[int, str], trainer: "allen_trainer.Trainer", is_best_so_far: bool = False, + save_model_only: bool = False, ) -> None: if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1: super().save_checkpoint(epoch, trainer, is_best_so_far) diff --git a/combo/training/trainer.py b/combo/training/trainer.py index f74873e..3bee8fc 100644 --- a/combo/training/trainer.py +++ b/combo/training/trainer.py @@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): logger.info("Beginning training.") val_metrics: Dict[str, float] = {} - this_epoch_val_metric: float + this_epoch_val_metric: float = None metrics: Dict[str, Any] = {} epochs_trained = 0 training_start_time = time.time() @@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): # Check validation metric for early stopping this_epoch_val_metric = val_metrics[self._validation_metric] - self._metric_tracker.add_metric(this_epoch_val_metric) + # self._metric_tracker.add_metric(this_epoch_val_metric) train_metrics["patience"] = self._metric_tracker._patience if self._metric_tracker.should_stop_early(): diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py index 682e885..1a17540 100644 --- a/combo/utils/metrics.py +++ b/combo/utils/metrics.py @@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric): self.feats_score.correct_indices * self.lemma_score.correct_indices * self.attachment_scores.correct_indices * - enhanced_indices) + enhanced_indices) * mask.flatten() - total, correct_indices = self.detach_tensors(total, correct_indices) - self.em_score = (correct_indices.float().sum() / total).item() + total, correct_indices = self.detach_tensors(total, correct_indices.float().sum()) + self.em_score = (correct_indices / total).item() def get_metric(self, reset: bool) -> Dict[str, float]: metrics_dict = { -- GitLab