diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py index c12d5aa8d8a797c0988da1b276f0f00934fd7603..5f5580b1f3da8bea2ccb137097df729af2966c79 100644 --- a/combo/config/from_parameters.py +++ b/combo/config/from_parameters.py @@ -142,6 +142,7 @@ class FromParameters: def _to_params(self, pass_down_parameter_names: List[str] = None) -> Dict[str, str]: parameters_to_serialize = self.constructed_args or {} pass_down_parameter_names = pass_down_parameter_names or [] + parameters_dict = {} for pn, param_value in parameters_to_serialize.items(): if pn in pass_down_parameter_names: @@ -151,8 +152,6 @@ class FromParameters: return parameters_dict def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]: - pass_down_parameter_names = pass_down_parameter_names or [] - constructor_method = self.constructed_from if self.constructed_from else '__init__' if not getattr(self, constructor_method): raise ConfigurationError('Class ' + str(type(self)) + ' has no constructor method ' + constructor_method) diff --git a/combo/main.py b/combo/main.py index c815130f50df92f560760fd52c29e6b09eaab955..974b70d622b5fe4acf03dcdd77853c3dc28d3474 100755 --- a/combo/main.py +++ b/combo/main.py @@ -4,7 +4,7 @@ import os import pathlib import tempfile from itertools import chain -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Tuple import torch from absl import app, flags @@ -24,6 +24,7 @@ from config import override_parameters from config.from_parameters import override_or_add_parameters from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader from data.dataset_loaders import DataLoader +from modules.model import Model from utils import ConfigurationError logging.setLoggerClass(ComboLogger) @@ -127,6 +128,42 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader, return vocabulary +def get_defaults(dataset_reader: Optional[DatasetReader], + training_data_loader: Optional[DataLoader], + validation_data_loader: Optional[DataLoader], + vocabulary: Optional[Vocabulary], + training_data_path: str, + validation_data_path: str, + prefix: str) -> Tuple[DatasetReader, DataLoader, DataLoader, Vocabulary]: + if not dataset_reader and (FLAGS.test_data_path + or not training_data_loader + or (FLAGS.validation_data_path and not validation_data_loader)): + # Dataset reader is required to read training data and/or for training (and validation) data loader + dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name) + + if not training_data_loader: + training_data_loader = default_data_loader(dataset_reader, training_data_path) + else: + if training_data_path: + training_data_loader.data_path = training_data_path + else: + logger.warning(f'No training data path provided - using the path from configuration: ' + + str(training_data_loader.data_path), prefix=prefix) + + if FLAGS.validation_data_path and not validation_data_loader: + validation_data_loader = default_data_loader(dataset_reader, validation_data_path) + else: + if validation_data_path: + validation_data_loader.data_path = validation_data_path + else: + logger.warning(f'No validation data path provided - using the path from configuration: ' + + str(validation_data_loader.data_path), prefix=prefix) + + if not vocabulary: + vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix) + + return dataset_reader, training_data_loader, validation_data_loader, vocabulary + def _read_property_from_config(property_key: str, params: Dict[str, Any], logging_prefix: str) -> Optional[Any]: @@ -167,7 +204,7 @@ def read_vocabulary_from_config(params: Dict[str, Any], return vocabulary -def read_model_from_config(logging_prefix: str): +def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]: try: checks.file_exists(FLAGS.config_path) except ConfigurationError as e: @@ -187,20 +224,29 @@ def read_model_from_config(logging_prefix: str): validation_data_loader = read_data_loader_from_config(params, logging_prefix, validation=True) vocabulary = read_vocabulary_from_config(params, logging_prefix) + dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults( + dataset_reader, + training_data_loader, + validation_data_loader, + vocabulary, + FLAGS.training_data_path if not FLAGS.finetuning else FLAGS.finetuning_training_data_path, + FLAGS.validation_data_path if not FLAGS.finetuning else FLAGS.finetuning_validation_data_path, + logging_prefix + ) + pass_down_parameters = {'vocabulary': vocabulary} if not FLAGS.use_pure_config: pass_down_parameters['model_name'] = FLAGS.pretrained_transformer_name logger.info('Resolving the model from parameters.', prefix=logging_prefix) - model = resolve(params['model'], - pass_down_parameters=pass_down_parameters) + model = resolve(params['model'], pass_down_parameters=pass_down_parameters) - return model, vocabulary, training_data_loader, validation_data_loader, dataset_reader + return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary def run(_): if FLAGS.mode == 'train': - model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = None, None, None, None, None + model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None if not FLAGS.finetuning: prefix = 'Training' @@ -208,7 +254,7 @@ def run(_): if FLAGS.config_path: logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix) - model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = read_model_from_config(prefix) + model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix) if FLAGS.use_pure_config and model is None: logger.error('Error in configuration - model could not be read from parameters. ' + @@ -218,9 +264,6 @@ def run(_): serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir) - training_data_path = FLAGS.training_data_path - validation_data_path = FLAGS.validation_data_path - else: prefix = 'Finetuning' @@ -242,38 +285,15 @@ def run(_): serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir) - training_data_path = FLAGS.finetuning_training_data_path - validation_data_path = FLAGS.finetuning_validation_data_path - - if not dataset_reader and (FLAGS.test_data_path - or not training_data_loader - or (FLAGS.validation_data_path and not validation_data_loader)): - # Dataset reader is required to read training data and/or for training (and validation) data loader - dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name) - - if not training_data_loader: - training_data_loader = default_data_loader(dataset_reader, training_data_path) - else: - if training_data_path: - training_data_loader.data_path = training_data_path - else: - logger.warning(f'No training data path provided - using the path from configuration: ' + - str(training_data_loader.data_path), prefix=prefix) - - if FLAGS.validation_data_path and not validation_data_loader: - validation_data_loader = default_data_loader(dataset_reader, validation_data_path) - else: - if validation_data_path: - validation_data_loader.data_path = validation_data_path - else: - logger.warning(f'No validation data path provided - using the path from configuration: ' + - str(validation_data_loader.data_path), prefix=prefix) - - if not vocabulary: - vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix) - - if not model: - model = default_model(FLAGS.pretrained_transformer_name, vocabulary) + dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults( + dataset_reader, + training_data_loader, + validation_data_loader, + vocabulary, + FLAGS.finetuning_training_data_path, + FLAGS.finetuning_validation_data_path, + prefix + ) logger.info('Indexing training data loader', prefix=prefix) training_data_loader.index_with(model.vocab)