diff --git a/combo/main.py b/combo/main.py
index 2faaf1882501d642018cabfe0c9780af26eca8d6..5dee611b547d1a6e562d273921d904bf417a5652 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -9,13 +9,14 @@
 from absl import app, flags
 import pytorch_lightning as pl
 from combo.training.trainable_combo import TrainableCombo
-from combo.utils import checks
+from combo.utils import checks, ComboLogger
-from config import resolve
-from default_model import default_ud_dataset_reader, default_data_loader
-from modules.archival import load_archive, archive
-from predict import COMBO
+from combo.config import resolve
+from combo.default_model import default_ud_dataset_reader, default_data_loader
+from combo.modules.archival import load_archive, archive
+from combo.predict import COMBO
 
+logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
 _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
@@ -54,14 +55,14 @@
 flags.DEFINE_string(name="serialization_dir", default=None,
                     help="Model serialization directory (default - system temp dir).")
 flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
+flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"),
+                    help="Config file path.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
-                    help="Training data path(s)")
+                    help="Training data path(s)")
 flags.DEFINE_string(name="finetuning_validation_data_path", default="",
-                    help="Validation data path(s)")
-flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "params.json"),
-                    help="Config file path.")
+                    help="Validation data path(s)")
 
 # Test after training flags
 flags.DEFINE_string(name="test_path", default=None,
@@ -99,51 +100,81 @@ def get_predictor() -> COMBO:
 def run(_):
     if FLAGS.mode == 'train':
         if not FLAGS.finetuning:
+            prefix = 'Training'
+            logger.info('Setting up the model for training', prefix=prefix)
             checks.file_exists(FLAGS.config_path)
+
+            logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
             with open(FLAGS.config_path, 'r') as f:
                 params = json.load(f)
             params = {**params, **_get_ext_vars()}
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
-            model = resolve(params['model'])
+            try:
+                vocabulary = resolve(params['vocabulary'])
+            except KeyError:
+                logger.error('No vocabulary in config.json!')
+                return
+
+            model = resolve(params['model'], pass_down_parameters={'vocabulary': vocabulary})
+            dataset_reader = None
 
             if 'data_loader' in params:
+                logger.info(f'Resolving the training data loader from parameters', prefix=prefix)
                 train_data_loader = resolve(params['data_loader'])
             else:
                 checks.file_exists(FLAGS.training_data_path)
+                logger.info(f'Using a default UD data loader with training data path {FLAGS.training_data_path}',
+                            prefix=prefix)
                 train_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                         FLAGS.training_data_path)
+
+            logger.info('Indexing training data loader')
             train_data_loader.index_with(model.vocab)
 
             validation_data_loader = None
+
             if 'validation_data_loader' in params:
+                logger.info(f'Resolving the validation data loader from parameters', prefix=prefix)
                 validation_data_loader = resolve(params['validation_data_loader'])
+                logger.info('Indexing validation data loader', prefix=prefix)
                 validation_data_loader.index_with(model.vocab)
             elif FLAGS.validation_data_path:
                 checks.file_exists(FLAGS.validation_data_path)
+                logger.info(f'Using a default UD data loader with validation data path {FLAGS.validation_data_path}',
+                            prefix=prefix)
                 validation_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                              FLAGS.validation_data_path)
+                logger.info('Indexing validation data loader', prefix=prefix)
                 validation_data_loader.index_with(model.vocab)
         else:
-            model, train_data_loader, validation_data_loader = load_archive(FLAGS.model_path)
+            prefix = 'Finetuning'
+            logger.info('Loading the model for finetuning', prefix=prefix)
+            model, _, train_data_loader, validation_data_loader, dataset_reader = load_archive(FLAGS.model_path)
             serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
 
             if not train_data_loader:
                 checks.file_exists(FLAGS.finetuning_training_data_path)
+                logger.info(
+                    f'Using a default UD data loader with training data path {FLAGS.finetuning_training_data_path}',
+                    prefix=prefix)
                 train_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                         FLAGS.finetuning_training_data_path)
 
             if not validation_data_loader and FLAGS.finetuning_validation_data_path:
                 checks.file_exists(FLAGS.finetuning_validation_data_path)
+                logger.info(
+                    f'Using a default UD data loader with validation data path {FLAGS.finetuning_validation_data_path}',
+                    prefix=prefix)
                 validation_data_loader = default_data_loader(default_ud_dataset_reader(),
                                                              FLAGS.finetuning_validation_data_path)
 
-            print("Indexing train loader")
+            logger.info("Indexing train loader", prefix=prefix)
             train_data_loader.index_with(model.vocab)
-            print("Indexing validation loader")
+            logger.info("Indexing validation loader", prefix=prefix)
             validation_data_loader.index_with(model.vocab)
-            print("Indexed")
+            logger.info("Indexed", prefix=prefix)
 
         nlp = TrainableCombo(model, torch.optim.Adam,
                              optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
@@ -153,8 +184,9 @@ def run(_):
                              gradient_clip_val=5)
         trainer.fit(model=nlp, train_dataloaders=train_data_loader,
                     val_dataloaders=validation_data_loader)
-        archive(model, serialization_dir)
-        logger.info(f"Training model stored in: {serialization_dir}")
+        logger.info(f'Archiving the fine-tuned model in {serialization_dir}', prefix=prefix)
+        archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
+        logger.info(f"Training model stored in: {serialization_dir}", prefix=prefix)
 
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py
index b36a4d9c69c70253fa56c6e7fce558a75463e0c5..450322600814973fff3589ae347ab7fb973a48b9 100644
--- a/combo/utils/__init__.py
+++ b/combo/utils/__init__.py
@@ -1,4 +1,5 @@
 from .checks import *
 from .sequence import *
 from .exceptions import *
-from .typing import *
\ No newline at end of file
+from .typing import *
+from .logging import ComboLogger
diff --git a/combo/utils/logging.py b/combo/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd04b6a3ebff83682b92fffac9a02add90d2debc
--- /dev/null
+++ b/combo/utils/logging.py
@@ -0,0 +1,43 @@
+import logging
+from overrides import overrides
+from datetime import datetime
+
+
+class ComboLogger(logging.Logger):
+    def __init__(self, name: str, prefix: str = None, display_date: bool = True):
+        super().__init__(name)
+        self.__prefix = prefix or ''
+        self.__display_date = display_date
+
+    @overrides(check_signature=False)
+    def log(self, level: int, msg: str, prefix: str = None):
+        prefix = prefix or self.__prefix
+        super().log(level, '[{date} UTC {prefix}] {msg}'.format(
+            date=datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'),
+            prefix=prefix,
+            msg=msg
+        ))
+
+    @overrides(check_signature=False)
+    def debug(self, msg: str, prefix: str = None):
+        self.log(level=logging.DEBUG, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def info(self, msg: str, prefix: str = None):
+        self.log(level=logging.INFO, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def warn(self, msg: str, prefix: str = None):
+        self.log(level=logging.WARN, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def error(self, msg: str, prefix: str = None):
+        self.log(level=logging.ERROR, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def fatal(self, msg: str, prefix: str = None):
+        self.log(level=logging.FATAL, msg=msg, prefix=prefix)
+
+    @overrides(check_signature=False)
+    def critical(self, msg: str, prefix: str = None):
+        self.log(level=logging.CRITICAL, msg=msg, prefix=prefix)
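
A minimal usage sketch of the prefix-aware logger added above, assuming this patch is applied. The logger name "combo.example" and the "Training" prefix are only illustrative, and the output line shown in the final comment is indicative of the message format built by ComboLogger.log, not a captured log.

    import logging
    from combo.utils import ComboLogger

    # Register ComboLogger before any logging.getLogger() call creates a logger;
    # loggers created earlier keep the standard logging.Logger class.
    logging.setLoggerClass(ComboLogger)
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    logger = logging.getLogger("combo.example")  # illustrative logger name
    logger.info("Indexing training data loader", prefix="Training")
    # e.g. [2023-01-01 00:00:00 UTC Training] Indexing training data loader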