diff --git a/combo/main.py b/combo/main.py index 572570bcff9376bcd817dbaf4ff1aa97f6fadba3..5174a7491069ec8f83bc78b201554600c6139210 100755 --- a/combo/main.py +++ b/combo/main.py @@ -8,6 +8,7 @@ from typing import Dict import torch from absl import app, flags import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger from tqdm import tqdm from combo.training.trainable_combo import TrainableCombo @@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None, help="Model serialization directory (default - system temp dir).") flags.DEFINE_boolean(name="tensorboard", default=False, help="When provided model will log tensorboard metrics.") +flags.DEFINE_string(name="tensorboard_name", default="combo", + help="Name of the model in TensorBoard logs.") flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"), help="Config file path.") @@ -208,10 +211,14 @@ def run(_): n_cuda_devices = "auto" if FLAGS.n_cuda_devices == -1 else FLAGS.n_cuda_devices + tensorboard_logger = TensorBoardLogger(os.path.join(serialization_dir, 'tensorboard_logs'), + name=FLAGS.tensorboard_name) if FLAGS.tensorboard else None + trainer = pl.Trainer(max_epochs=FLAGS.num_epochs, default_root_dir=serialization_dir, gradient_clip_val=5, - devices=n_cuda_devices) + devices=n_cuda_devices, + logger=tensorboard_logger) try: trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader) except Exception as e: