From c0835180c3092ba52b19ab8b98a488468915f0a0 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Wed, 6 Jan 2021 11:12:38 +0100
Subject: [PATCH] Fix training loops and metrics.

---
 combo/training/checkpointer.py | 1 +
 combo/training/trainer.py      | 4 ++--
 combo/utils/metrics.py         | 6 +++---
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py
index c148ed6..bae4403 100644
--- a/combo/training/checkpointer.py
+++ b/combo/training/checkpointer.py
@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
             epoch: Union[int, str],
             trainer: "allen_trainer.Trainer",
             is_best_so_far: bool = False,
+            save_model_only: bool = False,
     ) -> None:
         if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
             super().save_checkpoint(epoch, trainer, is_best_so_far)
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index f74873e..3bee8fc 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float
+        this_epoch_val_metric: float = None
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
 
                         # Check validation metric for early stopping
                         this_epoch_val_metric = val_metrics[self._validation_metric]
-                        self._metric_tracker.add_metric(this_epoch_val_metric)
+                        # self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
                 if self._metric_tracker.should_stop_early():
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 682e885..1a17540 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric):
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
                            self.attachment_scores.correct_indices *
-                           enhanced_indices)
+                           enhanced_indices) * mask.flatten()
 
-        total, correct_indices = self.detach_tensors(total, correct_indices)
-        self.em_score = (correct_indices.float().sum() / total).item()
+        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_indices / total).item()
 
     def get_metric(self, reset: bool) -> Dict[str, float]:
         metrics_dict = {
-- 
GitLab