Skip to content
Snippets Groups Projects
Commit 90c6fbee authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add logging metrics to console.

parent 2f459607
Branches
Tags
No related merge requests found
import logging
from typing import Dict, Optional, List
import torch
......@@ -5,6 +6,8 @@ from allennlp import models, common
from allennlp.data import dataloader
from allennlp.training import optimizers
logger = logging.getLogger(__name__)
class NullTensorboardWriter(common.FromParams):
......@@ -50,13 +53,36 @@ class NullTensorboardWriter(common.FromParams):
pass
def log_metrics(
self,
train_metrics: dict,
val_metrics: dict = None,
epoch: int = None,
log_to_console: bool = False,
self,
train_metrics: dict,
val_metrics: dict = None,
epoch: int = None,
log_to_console: bool = False,
) -> None:
pass
metric_names = set(train_metrics.keys())
if val_metrics is not None:
metric_names.update(val_metrics.keys())
val_metrics = val_metrics or {}
if log_to_console:
dual_message_template = "%s | %8.3f | %8.3f"
no_val_message_template = "%s | %8.3f | %8s"
no_train_message_template = "%s | %8s | %8.3f"
header_template = "%s | %-10s"
name_length = max(len(x) for x in metric_names)
logger.info(header_template, "Training".rjust(name_length + 13), "Validation")
for name in metric_names:
train_metric = train_metrics.get(name)
val_metric = val_metrics.get(name)
if val_metric is not None and train_metric is not None:
logger.info(
dual_message_template, name.ljust(name_length), train_metric, val_metric
)
elif val_metric is not None:
logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
elif train_metric is not None:
logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")
def enable_activation_logging(self, model: models.Model) -> None:
pass
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment