diff --git a/combo/main.py b/combo/main.py index 9eb0a03469f7b9d344cd2797cca8fc8874b259e9..1e281069011b7b94992359d44d81d5d55c94fbc9 100755 --- a/combo/main.py +++ b/combo/main.py @@ -27,7 +27,6 @@ from combo.modules.model import Model from combo.utils import ConfigurationError from combo.utils.matrices import extract_combo_matrices - logging.setLoggerClass(ComboLogger) logger = logging.getLogger(__name__) _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"] @@ -81,6 +80,8 @@ flags.DEFINE_string(name="config_path", default="", help="Config file path.") flags.DEFINE_boolean(name="save_matrices", default=True, help="Save relation distribution matrices.") +flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"], + help="") # Finetune after training flags flags.DEFINE_string(name="finetuning_training_data_path", default="", @@ -115,10 +116,19 @@ def build_vocabulary_from_instances(training_data_loader: DataLoader, validation_data_loader: Optional[DataLoader], logging_prefix: str) -> Vocabulary: logger.info('Building a vocabulary from instances.', prefix=logging_prefix) - instances = chain(training_data_loader.iter_instances(), - validation_data_loader.iter_instances()) \ - if validation_data_loader \ - else training_data_loader.iter_instances() + if "train" in FLAGS.datasets_for_vocabulary and "valid" in FLAGS.datasets_for_vocabulary: + instances = chain(training_data_loader.iter_instances(), + validation_data_loader.iter_instances()) \ + if validation_data_loader \ + else training_data_loader.iter_instances() + elif "train" in FLAGS.datasets_for_vocabulary: + instances = training_data_loader.iter_instances() + elif "valid" in FLAGS.datasets_for_vocabulary: + instances = validation_data_loader.iter_instances() + else: + logger.error("train and valid are the only allowed values for --datasets_for_vocabulary!", + prefix=logging_prefix) + raise ValueError("train and valid are the only allowed values for --datasets_for_vocabulary!") vocabulary = Vocabulary.from_instances_extended( instances, non_padded_namespaces=['head_labels'], @@ -165,6 +175,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader], return dataset_reader, training_data_loader, validation_data_loader, vocabulary + def _read_property_from_config(property_key: str, params: Dict[str, Any], logging_prefix: str, @@ -212,13 +223,19 @@ def read_vocabulary_from_config(params: Dict[str, Any], return vocabulary -def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]: +def read_model_from_config(logging_prefix: str) -> Optional[ + Tuple[Model, DatasetReader, DataLoader, DataLoader, Vocabulary]]: try: checks.file_exists(FLAGS.config_path) except ConfigurationError as e: handle_error(e, logging_prefix) return + if FLAGS.serialization_dir is None: + logger.error(f'--serialization_dir was not passed as an argument!') + print(f'--serialization_dir was not passed as an argument!') + return + with open(FLAGS.config_path, 'r') as f: params = json.load(f) @@ -235,12 +252,14 @@ def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, Dataset dataset_reader = read_dataset_reader_from_config(params, logging_prefix, pass_down_parameters) training_data_loader = read_data_loader_from_config(params, logging_prefix, validation=False, pass_down_parameters=pass_down_parameters) - if (not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params: + if ( + not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params: logger.warning('Validation data loader is in parameters, but no validation data path was provided!') validation_data_loader = None else: validation_data_loader = read_data_loader_from_config(params, logging_prefix, - validation=True, pass_down_parameters=pass_down_parameters) + validation=True, + pass_down_parameters=pass_down_parameters) vocabulary = read_vocabulary_from_config(params, logging_prefix, pass_down_parameters) dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults( @@ -273,7 +292,8 @@ def run(_): if FLAGS.config_path: logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix) - model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config(prefix) + model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = read_model_from_config( + prefix) else: dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults( dataset_reader,