From d80a60c1296bc4147f210fe64b2c9fc97178020b Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sun, 12 Nov 2023 18:08:49 +1100
Subject: [PATCH] Add a custom logger

---
 combo/main.py           | 64 ++++++++++++++++++++++++++++++-----------
 combo/utils/__init__.py |  3 +-
 combo/utils/logging.py  | 43 +++++++++++++++++++++++++++
 3 files changed, 93 insertions(+), 17 deletions(-)
 create mode 100644 combo/utils/logging.py

diff --git a/combo/main.py b/combo/main.py
index 2faaf18..5dee611 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -9,13 +9,14 @@ from absl import app, flags
 import pytorch_lightning as pl
 
 from combo.training.trainable_combo import TrainableCombo
-from combo.utils import checks
+from combo.utils import checks, ComboLogger
 
-from config import resolve
-from default_model import default_ud_dataset_reader, default_data_loader
-from modules.archival import load_archive, archive
-from predict import COMBO
+from combo.config import resolve
+from combo.default_model import default_ud_dataset_reader, default_data_loader
+from combo.modules.archival import load_archive, archive
+from combo.predict import COMBO
 
+# setLoggerClass must run before the getLogger() call below so the module logger is created as a ComboLogger.
+logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
 _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
@@ -54,14 +55,14 @@ flags.DEFINE_string(name="serialization_dir", default=None,
                     help="Model serialization directory (default - system temp dir).")
 flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
+flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"),
+                    help="Config file path.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
-                  help="Training data path(s)")
+                    help="Training data path(s)")
 flags.DEFINE_string(name="finetuning_validation_data_path", default="",
-                  help="Validation data path(s)")
-flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "params.json"),
-                    help="Config file path.")
+                    help="Validation data path(s)")
 
 # Test after training flags
 flags.DEFINE_string(name="test_path", default=None,
@@ -99,51 +100,81 @@ def get_predictor() -> COMBO:
 def run(_):
     if FLAGS.mode == 'train':
         if not FLAGS.finetuning:
+            prefix = 'Training'
+            logger.info('Setting up the model for training', prefix=prefix)
             checks.file_exists(FLAGS.config_path)
+
+            logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
             with open(FLAGS.config_path, 'r') as f:
                 params = json.load(f)
             params = {**params, **_get_ext_vars()}
 
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
-            model = resolve(params['model'])
+            try:
+                vocabulary = resolve(params['vocabulary'])
+            except KeyError:
+                logger.error(f'No vocabulary defined in the configuration file {FLAGS.config_path}!', prefix=prefix)
+                return
+
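+            # Pass the resolved vocabulary down so components resolved from the model config can share it.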
+            model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
+            dataset_reader = None
 
             if 'data_loader' in params:
+                logger.info('Resolving the training data loader from parameters', prefix=prefix)
                 train_data_loader = resolve(params['data_loader'])
             else:
                 checks.file_exists(FLAGS.training_data_path)
+                logger.info(f'Using a default UD data loader with training data path {FLAGS.training_data_path}',
+                            prefix=prefix)
                 train_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                         FLAGS.training_data_path)
+
+            logger.info('Indexing training data loader', prefix=prefix)
             train_data_loader.index_with(model.vocab)
 
             validation_data_loader = None
+
             if 'validation_data_loader' in params:
+                logger.info('Resolving the validation data loader from parameters', prefix=prefix)
                 validation_data_loader = resolve(params['validation_data_loader'])
+                logger.info('Indexing validation data loader', prefix=prefix)
                 validation_data_loader.index_with(model.vocab)
             elif FLAGS.validation_data_path:
                 checks.file_exists(FLAGS.validation_data_path)
+                logger.info(f'Using a default UD data loader with validation data path {FLAGS.validation_data_path}',
+                            prefix=prefix)
                 validation_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                              FLAGS.validation_data_path)
+                logger.info('Indexing validation data loader', prefix=prefix)
                 validation_data_loader.index_with(model.vocab)
 
         else:
-            model, train_data_loader, validation_data_loader = load_archive(FLAGS.model_path)
+            prefix = 'Finetuning'
+            logger.info('Loading the model for finetuning', prefix=prefix)
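+            # The archive bundles the data loaders and dataset reader from the original training run alongside the model.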
+            model, _, train_data_loader, validation_data_loader, dataset_reader = load_archive(FLAGS.model_path)
 
             serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
 
             if not train_data_loader:
                 checks.file_exists(FLAGS.finetuning_training_data_path)
+                logger.info(
+                    f'Using a default UD data loader with training data path {FLAGS.finetuning_training_data_path}',
+                    prefix=prefix)
                 train_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                         FLAGS.finetuning_training_data_path)
             if not validation_data_loader and FLAGS.finetuning_validation_data_path:
                 checks.file_exists(FLAGS.finetuning_validation_data_path)
+                logger.info(
+                    f'Using a default UD data loader with validation data path {FLAGS.finetuning_validation_data_path}',
+                    prefix=prefix)
                 validation_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                              FLAGS.finetuning_validation_data_path)
-        print("Indexing train loader")
+        logger.info("Indexing train loader", prefix=prefix)
         train_data_loader.index_with(model.vocab)
-        print("Indexing validation loader")
+        logger.info("Indexing validation loader", prefix=prefix)
         validation_data_loader.index_with(model.vocab)
-        print("Indexed")
+        logger.info("Indexed", prefix=prefix)
 
         nlp = TrainableCombo(model, torch.optim.Adam,
                              optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
@@ -153,8 +184,9 @@ def run(_):
                              gradient_clip_val=5)
         trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
 
-        archive(model, serialization_dir)
-        logger.info(f"Training model stored in: {serialization_dir}")
+        logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
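+        # Store the data loaders and dataset reader together with the model so a later finetuning run can reload them.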
+        archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
+        logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
 
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py
index b36a4d9..4503226 100644
--- a/combo/utils/__init__.py
+++ b/combo/utils/__init__.py
@@ -1,4 +1,5 @@
 from .checks import *
 from .sequence import *
 from .exceptions import *
-from .typing import *
\ No newline at end of file
+from .typing import *
+from .logging import ComboLogger
diff --git a/combo/utils/logging.py b/combo/utils/logging.py
new file mode 100644
index 0000000..cd04b6a
--- /dev/null
+++ b/combo/utils/logging.py
@@ -0,0 +1,43 @@
+import logging
+from overrides import overrides
+from datetime import datetime, timezone
+
+
+class ComboLogger(logging.Logger):
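+    """Logger that prepends a UTC timestamp and an optional per-call prefix to every message."""
+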
+    def __init__(self, name: str, prefix: str = None, display_date: bool = True):
+        super().__init__(name)
+        self.__prefix = prefix or ''
+        self.__display_date = display_date
+
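+    # check_signature=False is needed because these overrides replace the base Logger
+    # signatures (msg, *args, **kwargs) with an explicit prefix keyword.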
+    @overrides(check_signature=False)
+    def log(self, level: int, msg: str, prefix: str = None):
+        prefix = prefix or self.__prefix
+        super().log(level, '[{date} UTC {prefix}] {msg}'.format(
+            date=datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S'),
+            prefix=prefix,
+            msg=msg
+        ))
+
+    @overrides(check_signature=False)
+    def debug(self, msg: str, prefix: str = None):
+        self.log(level=logging.DEBUG, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def info(self, msg: str, prefix: str = None):
+        self.log(level=logging.INFO, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def warning(self, msg: str, prefix: str = None):
+        self.log(level=logging.WARNING, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def error(self, msg: str, prefix: str = None):
+        self.log(level=logging.ERROR, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def fatal(self, msg: str, prefix: str = None):
+        self.log(level=logging.FATAL, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def critical(self, msg: str, prefix: str = None):
+        self.log(level=logging.CRITICAL, msg=msg, prefix=prefix)
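+
+
+# Illustrative usage (assumes logging.setLoggerClass(ComboLogger) has been called before the
+# logger is created, as combo/main.py does):
+#     logger = logging.getLogger(__name__)
+#     logger.info('Setting up the model for training', prefix='Training')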
-- 
GitLab