Skip to content
Snippets Groups Projects
Commit b03c33b2 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add a tensorboard logger

parent ad097cf0
1 merge request!46Merge COMBO 3.0 into master
...@@ -8,6 +8,7 @@ from typing import Dict ...@@ -8,6 +8,7 @@ from typing import Dict
import torch import torch
from absl import app, flags from absl import app, flags
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm import tqdm from tqdm import tqdm
from combo.training.trainable_combo import TrainableCombo from combo.training.trainable_combo import TrainableCombo
...@@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None, ...@@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None,
help="Model serialization directory (default - system temp dir).") help="Model serialization directory (default - system temp dir).")
flags.DEFINE_boolean(name="tensorboard", default=False, flags.DEFINE_boolean(name="tensorboard", default=False,
help="When provided model will log tensorboard metrics.") help="When provided model will log tensorboard metrics.")
flags.DEFINE_string(name="tensorboard_name", default="combo",
help="Name of the model in TensorBoard logs.")
flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"), flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"),
help="Config file path.") help="Config file path.")
...@@ -208,10 +211,14 @@ def run(_): ...@@ -208,10 +211,14 @@ def run(_):
n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices
tensorboard_logger = TensorBoardLogger(os.path.join(serialization_dir, 'tensorboard_logs'),
name=FLAGS.tensorboard_name) if FLAGS.tensorboard else None
trainer = pl.Trainer(max_epochs=FLAGS.num_epochs, trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
default_root_dir=serialization_dir, default_root_dir=serialization_dir,
gradient_clip_val=5, gradient_clip_val=5,
devices=n_cuda_devices) devices=n_cuda_devices,
logger=tensorboard_logger)
try: try:
trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader) trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
except Exception as e: except Exception as e:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment