diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py index 3796e66afa0d43bb61e7e427cb245955bb56330b..d7ab2c5d1bbc1255d2505b833a2497606a99862e 100644 --- a/combo/training/tensorboard_writer.py +++ b/combo/training/tensorboard_writer.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, Optional, List import torch @@ -5,6 +6,8 @@ from allennlp import models, common from allennlp.data import dataloader from allennlp.training import optimizers +logger = logging.getLogger(__name__) + class NullTensorboardWriter(common.FromParams): @@ -50,13 +53,36 @@ class NullTensorboardWriter(common.FromParams): pass def log_metrics( - self, - train_metrics: dict, - val_metrics: dict = None, - epoch: int = None, - log_to_console: bool = False, + self, + train_metrics: dict, + val_metrics: dict = None, + epoch: int = None, + log_to_console: bool = False, ) -> None: - pass + metric_names = set(train_metrics.keys()) + if val_metrics is not None: + metric_names.update(val_metrics.keys()) + val_metrics = val_metrics or {} + + if log_to_console: + dual_message_template = "%s | %8.3f | %8.3f" + no_val_message_template = "%s | %8.3f | %8s" + no_train_message_template = "%s | %8s | %8.3f" + header_template = "%s | %-10s" + name_length = max(len(x) for x in metric_names) + logger.info(header_template, "Training".rjust(name_length + 13), "Validation") + + for name in metric_names: + train_metric = train_metrics.get(name) + val_metric = val_metrics.get(name) + if val_metric is not None and train_metric is not None: + logger.info( + dual_message_template, name.ljust(name_length), train_metric, val_metric + ) + elif val_metric is not None: + logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric) + elif train_metric is not None: + logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A") def enable_activation_logging(self, model: models.Model) -> None: pass