From b03c33b222157e9374c4b9fd6cc0c499f9ee3baa Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Wed, 15 Nov 2023 21:10:50 +1100 Subject: [PATCH] Add a tensorboard logger --- combo/main.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/combo/main.py b/combo/main.py index 572570b..5174a74 100755 --- a/combo/main.py +++ b/combo/main.py @@ -8,6 +8,7 @@ from typing import Dict import torch from absl import app, flags import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger from tqdm import tqdm from combo.training.trainable_combo import TrainableCombo @@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None, help="Model serialization directory (default - system temp dir).") flags.DEFINE_boolean(name="tensorboard", default=False, help="When provided model will log tensorboard metrics.") +flags.DEFINE_string(name="tensorboard_name", default="combo", + help="Name of the model in TensorBoard logs.") flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"), help="Config file path.") @@ -208,10 +211,14 @@ def run(_): n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices + tensorboard_logger = TensorBoardLogger(os.path.join(serialization_dir, 'tensorboard_logs'), + name=FLAGS.tensorboard_name) if FLAGS.tensorboard else None + trainer = pl.Trainer(max_epochs=FLAGS.num_epochs, default_root_dir=serialization_dir, gradient_clip_val=5, - devices=n_cuda_devices) + devices=n_cuda_devices, + logger=tensorboard_logger) try: trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader) except Exception as e: -- GitLab