diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py
index c148ed664d4a14676bc516888d212da240bef3ea..bae4403f6a5cfa0bc3482f2670a221478cd2cae3 100644
--- a/combo/training/checkpointer.py
+++ b/combo/training/checkpointer.py
@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
             epoch: Union[int, str],
             trainer: "allen_trainer.Trainer",
             is_best_so_far: bool = False,
+            save_model_only: bool = False,
     ) -> None:
         if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
             super().save_checkpoint(epoch, trainer, is_best_so_far)
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index f74873ed1f57c6218898d6aad79c7cf776d7bf8e..3bee8fcaaf1c29189583a551a1100ac4f0215a65 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         logger.info("Beginning training.")
 
         val_metrics: Dict[str, float] = {}
-        this_epoch_val_metric: float
+        this_epoch_val_metric: float = None
         metrics: Dict[str, Any] = {}
         epochs_trained = 0
         training_start_time = time.time()
@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 # Check validation metric for early stopping
                 this_epoch_val_metric = val_metrics[self._validation_metric]
-                self._metric_tracker.add_metric(this_epoch_val_metric)
+                # self._metric_tracker.add_metric(this_epoch_val_metric)
 
                 train_metrics["patience"] = self._metric_tracker._patience
 
                 if self._metric_tracker.should_stop_early():
diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py
index 682e8859264a3414bf86d3a1e408ce5b3588a6f3..1a175402ec00b011221f2f7c1c76d7496b68281c 100644
--- a/combo/utils/metrics.py
+++ b/combo/utils/metrics.py
@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric):
                            self.feats_score.correct_indices *
                            self.lemma_score.correct_indices *
                            self.attachment_scores.correct_indices *
-                           enhanced_indices)
+                           enhanced_indices) * mask.flatten()
 
-        total, correct_indices = self.detach_tensors(total, correct_indices)
-        self.em_score = (correct_indices.float().sum() / total).item()
+        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
+        self.em_score = (correct_indices / total).item()
 
     def get_metric(self, reset: bool) -> Dict[str, float]:
         metrics_dict = {
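
For orientation, a minimal sketch of the save gate that FinishingTrainingCheckpointer keeps (visible in the context lines of the first hunk): a checkpoint is written only when the learning-rate scheduler has at most one decrease remaining or training is on its final epoch; the new save_model_only flag is accepted purely for signature compatibility with the base Checkpointer. The stub classes below are assumptions made so the sketch runs standalone.

from typing import Union

class _SchedulerStub:
    """Hypothetical stand-in for the trainer's LR scheduler."""
    decreases = 1  # number of LR decreases still to come

class _TrainerStub:
    """Hypothetical stand-in exposing the two attributes the gate reads."""
    _learning_rate_scheduler = _SchedulerStub()
    _num_epochs = 10

def should_save_checkpoint(trainer, epoch: Union[int, str]) -> bool:
    # Mirrors the condition in FinishingTrainingCheckpointer.save_checkpoint:
    # save only near the end of LR scheduling or on the last epoch.
    return (trainer._learning_rate_scheduler.decreases <= 1
            or epoch == trainer._num_epochs - 1)

print(should_save_checkpoint(_TrainerStub(), epoch=3))  # True: <= 1 decrease left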
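
And a hedged sketch of the masked exact-match computation that the combo/utils/metrics.py hunk arrives at. It is not part of the patch: the example tensors, the three stand-in sub-metrics, and the total = mask.sum() denominator are all assumptions for illustration (the real SemanticMetrics combines several per-token correctness tensors, and total is computed elsewhere in the class).

import torch

# Per-token correctness flags from a few sub-metrics (hypothetical values).
upos_correct = torch.tensor([1, 1, 0, 1])
feats_correct = torch.tensor([1, 1, 1, 1])
lemma_correct = torch.tensor([1, 0, 1, 1])
mask = torch.tensor([1, 1, 1, 0])  # last position is padding

# A token counts toward EM only if every sub-metric is correct
# AND the mask marks it as a real (non-padding) token.
correct_indices = upos_correct * feats_correct * lemma_correct * mask

total = mask.sum()  # denominator: real tokens only (an assumption here)
em_score = (correct_indices.float().sum() / total).item()
print(em_score)  # 0.333...: 1 of 3 real tokens is fully correct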