Skip to content
Snippets Groups Projects
Commit 90c6fbee authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add logging metrics to console.

parent 2f459607
No related branches found
No related tags found
No related merge requests found
import logging
from typing import Dict, Optional, List from typing import Dict, Optional, List
import torch import torch
...@@ -5,6 +6,8 @@ from allennlp import models, common ...@@ -5,6 +6,8 @@ from allennlp import models, common
from allennlp.data import dataloader from allennlp.data import dataloader
from allennlp.training import optimizers from allennlp.training import optimizers
logger = logging.getLogger(__name__)
class NullTensorboardWriter(common.FromParams): class NullTensorboardWriter(common.FromParams):
...@@ -56,7 +59,30 @@ class NullTensorboardWriter(common.FromParams): ...@@ -56,7 +59,30 @@ class NullTensorboardWriter(common.FromParams):
epoch: int = None, epoch: int = None,
log_to_console: bool = False, log_to_console: bool = False,
) -> None: ) -> None:
pass metric_names = set(train_metrics.keys())
if val_metrics is not None:
metric_names.update(val_metrics.keys())
val_metrics = val_metrics or {}
if log_to_console:
dual_message_template = "%s | %8.3f | %8.3f"
no_val_message_template = "%s | %8.3f | %8s"
no_train_message_template = "%s | %8s | %8.3f"
header_template = "%s | %-10s"
name_length = max(len(x) for x in metric_names)
logger.info(header_template, "Training".rjust(name_length + 13), "Validation")
for name in metric_names:
train_metric = train_metrics.get(name)
val_metric = val_metrics.get(name)
if val_metric is not None and train_metric is not None:
logger.info(
dual_message_template, name.ljust(name_length), train_metric, val_metric
)
elif val_metric is not None:
logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
elif train_metric is not None:
logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")
def enable_activation_logging(self, model: models.Model) -> None: def enable_activation_logging(self, model: models.Model) -> None:
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment