Skip to content
Snippets Groups Projects
Commit c0835180 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix training loops and metrics.

parent b8c83784
Branches
No related tags found
2 merge requests!15Merge develop to master,!14Install and docs
This commit is part of merge request !15. Comments created here will be created in the context of that merge request.
...@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer): ...@@ -16,6 +16,7 @@ class FinishingTrainingCheckpointer(training.Checkpointer):
epoch: Union[int, str], epoch: Union[int, str],
trainer: "allen_trainer.Trainer", trainer: "allen_trainer.Trainer",
is_best_so_far: bool = False, is_best_so_far: bool = False,
save_model_only: bool = False,
) -> None: ) -> None:
if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1: if trainer._learning_rate_scheduler.decreases <= 1 or epoch == trainer._num_epochs - 1:
super().save_checkpoint(epoch, trainer, is_best_so_far) super().save_checkpoint(epoch, trainer, is_best_so_far)
......
...@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): ...@@ -84,7 +84,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
logger.info("Beginning training.") logger.info("Beginning training.")
val_metrics: Dict[str, float] = {} val_metrics: Dict[str, float] = {}
this_epoch_val_metric: float this_epoch_val_metric: float = None
metrics: Dict[str, Any] = {} metrics: Dict[str, Any] = {}
epochs_trained = 0 epochs_trained = 0
training_start_time = time.time() training_start_time = time.time()
...@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer): ...@@ -141,7 +141,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
# Check validation metric for early stopping # Check validation metric for early stopping
this_epoch_val_metric = val_metrics[self._validation_metric] this_epoch_val_metric = val_metrics[self._validation_metric]
self._metric_tracker.add_metric(this_epoch_val_metric) # self._metric_tracker.add_metric(this_epoch_val_metric)
train_metrics["patience"] = self._metric_tracker._patience train_metrics["patience"] = self._metric_tracker._patience
if self._metric_tracker.should_stop_early(): if self._metric_tracker.should_stop_early():
......
...@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric): ...@@ -241,10 +241,10 @@ class SemanticMetrics(metrics.Metric):
self.feats_score.correct_indices * self.feats_score.correct_indices *
self.lemma_score.correct_indices * self.lemma_score.correct_indices *
self.attachment_scores.correct_indices * self.attachment_scores.correct_indices *
enhanced_indices) enhanced_indices) * mask.flatten()
total, correct_indices = self.detach_tensors(total, correct_indices) total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
self.em_score = (correct_indices.float().sum() / total).item() self.em_score = (correct_indices / total).item()
def get_metric(self, reset: bool) -> Dict[str, float]: def get_metric(self, reset: bool) -> Dict[str, float]:
metrics_dict = { metrics_dict = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment