diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py
index da9dafbdb352fb24d429f2d9464aacc8bb0fdbbc..c12d5aa8d8a797c0988da1b276f0f00934fd7603 100644
--- a/combo/config/from_parameters.py
+++ b/combo/config/from_parameters.py
@@ -221,3 +221,12 @@ def override_parameters(parameters: Dict[str, Any], override_values: Dict[str, A
         overriden_parameters[ko] = vo
 
     return unflatten_dictionary(overriden_parameters)
+
+
+def override_or_add_parameters(parameters: Dict[str, Any], override_values: Dict[str, Any]) -> Dict[str, Any]:
+    overriden_parameters = flatten_dictionary(parameters)
+    override_values = flatten_dictionary(override_values)
+    for ko, vo in override_values.items():
+        overriden_parameters[ko] = vo
+
+    return unflatten_dictionary(overriden_parameters)
diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
index cb12d2a050b6e0748a8b10fd068d7492d33ae107..606b5ddcf7734961f7877fa84867ec5f71c44771 100644
--- a/combo/data/tokenizers/pretrained_transformer_tokenizer.py
+++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py
@@ -19,7 +19,7 @@ from combo.utils import sanitize_wordpiece
 logger = logging.getLogger(__name__)
 
 
-@Registry.register( 'pretrained_transformer_tokenizer')
+@Registry.register('pretrained_transformer_tokenizer')
 class PretrainedTransformerTokenizer(Tokenizer):
     """
     A `PretrainedTransformerTokenizer` uses a model from HuggingFace's
diff --git a/combo/default_model.py b/combo/default_model.py
index 96223570aec1fbfffd92b2150f5f9459b54b27ca..b3fe5dbf630332844753101333094a5e2cb064cf 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -42,7 +42,7 @@ def default_character_indexer(namespace=None,
     )
 
 
-def default_ud_dataset_reader() -> UniversalDependenciesDatasetReader:
+def default_ud_dataset_reader(pretrained_transformer_name: str) -> UniversalDependenciesDatasetReader:
     return UniversalDependenciesDatasetReader(
         features=["token", "char"],
         lemma_indexers={
@@ -53,7 +53,7 @@ def default_ud_dataset_reader() -> UniversalDependenciesDatasetReader:
             "char": default_character_indexer(),
             "feats": TokenFeatsIndexer(),
             "lemma": default_character_indexer(),
-            "token": PretrainedTransformerFixedMismatchedIndexer("bert-base-cased"),
+            "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name),
             "upostag": SingleIdTokenIndexer(
                 feature_name="pos_",
                 namespace="upostag"
@@ -89,7 +89,7 @@ def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
     )
 
 
-def default_model(vocabulary: Vocabulary) -> ComboModel:
+def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> ComboModel:
     return ComboModel(
         vocabulary=vocabulary,
         dependency_relation=DependencyRelationModel(
@@ -188,7 +188,7 @@ def default_model(vocabulary: Vocabulary) -> ComboModel:
                     ),
                     embedding_dim=64
                 ),
-                "token": TransformersWordEmbedder("allegro/herbert-base-cased", projection_dim=100)
+                "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
             }
         ),
         upos_tagger=FeedForwardPredictor.from_vocab(
diff --git a/combo/main.py b/combo/main.py
index 1ad700c4d8382cf79ff0ea8562565e6c621ed673..9736646f6468d9c2293bada9e5397c0828fc9a79 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -3,7 +3,8 @@ import logging
 import os
 import pathlib
 import tempfile
-from typing import Dict
+from itertools import chain
+from typing import Dict, Optional, Any
 
 import torch
 from absl import app, flags
@@ -15,12 +16,14 @@ from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger
 from combo.config import resolve
-from combo.default_model import default_ud_dataset_reader, default_data_loader
+from combo.default_model import default_ud_dataset_reader, default_data_loader, default_model
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
 from combo.data import api
 from config import override_parameters
-from data import LamboTokenizer, Sentence
+from config.from_parameters import override_or_add_parameters
+from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader
+from data.dataset_loaders import DataLoader
 from utils import ConfigurationError
 
 
 logging.setLoggerClass(ComboLogger)
@@ -29,10 +32,9 @@ _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
 _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
 
 
-def handle_error(error: Exception):
+def handle_error(error: Exception, prefix: str):
     msg = getattr(error, 'message', str(error))
-    logger.error(msg)
-    print(f'Error: {msg}')
+    logger.error(msg, prefix=prefix)
 
 
 FLAGS = flags.FLAGS
@@ -50,6 +52,8 @@ flags.DEFINE_string(name="training_data_path", default="", help="Training data p
 flags.DEFINE_alias(name="training_data", original_name="training_data_path")
 flags.DEFINE_string(name="validation_data_path", default="", help="Validation data path(s)")
 flags.DEFINE_alias(name="validation_data", original_name="validation_data_path")
+flags.DEFINE_string(name="vocabulary_path", default=None,
+                    help="Stored vocabulary files. If not provided in training mode, vocabulary is built from training files")
 flags.DEFINE_integer(name="lemmatizer_embedding_dim", default=300,
                      help="Lemmatizer embeddings dim")
 flags.DEFINE_integer(name="num_epochs", default=400,
@@ -71,7 +75,7 @@ flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
 flags.DEFINE_string(name="tensorboard_name", default="combo",
                     help="Name of the model in TensorBoard logs.")
-flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.json"),
+flags.DEFINE_string(name="config_path", default="",
                     help="Config file path.")
 
 # Finetune after training flags
@@ -104,111 +108,180 @@ flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   enum_values=["combo", "combo-spacy", "combo-lambo"],
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
 
+
+def build_vocabulary_from_instances(training_data_loader: DataLoader,
+                                    validation_data_loader: Optional[DataLoader],
+                                    logging_prefix: str) -> Vocabulary:
+    logger.info('Building a vocabulary from instances.', prefix=logging_prefix)
+    instances = chain(training_data_loader.iter_instances(),
+                      validation_data_loader.iter_instances()) \
+        if validation_data_loader \
+        else training_data_loader.iter_instances()
+    vocabulary = Vocabulary.from_instances_extended(
+        instances,
+        non_padded_namespaces=['head_labels'],
+        only_include_pretrained_words=False,
+        oov_token='_',
+        padding_token='__PAD__'
+    )
+    return vocabulary
+
+
+def _read_property_from_config(property_key: str,
+                               params: Dict[str, Any],
+                               logging_prefix: str) -> Optional[Any]:
+    property = None
+    if property_key in params:
+        logger.info(f'Reading {property_key.replace("_", " ")} from parameters.', prefix=logging_prefix)
+        try:
+            property = resolve(params[property_key])
+        except Exception as e:
+            handle_error(e, logging_prefix)
+    return property
+
+
+def read_dataset_reader_from_config(params: Dict[str, Any],
+                                    logging_prefix: str) -> Optional[DatasetReader]:
+    return _read_property_from_config('dataset_reader', params, logging_prefix)
+
+
+def read_data_loader_from_config(params: Dict[str, Any],
+                                 logging_prefix: str,
+                                 validation: bool = False) -> Optional[DataLoader]:
+    key = 'validation_data_loader' if validation else 'data_loader'
+    return _read_property_from_config(key, params, logging_prefix)
+
+
+def read_vocabulary_from_config(params: Dict[str, Any],
+                                logging_prefix: str) -> Optional[Vocabulary]:
+    vocabulary = None
+    if "vocabulary" in params:
+        logger.info('Reading vocabulary from saved directory.', prefix=logging_prefix)
+        if 'directory' in params['vocabulary']['parameters']:
+            params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]),
+                                                                           params['vocabulary']['parameters'][
+                                                                               'directory'])
+        try:
+            vocabulary = resolve(params['vocabulary'])
+        except Exception as e:
+            handle_error(e, logging_prefix)
+    return vocabulary
+
+
+def read_model_from_config(logging_prefix: str):
+    try:
+        checks.file_exists(FLAGS.config_path)
+    except ConfigurationError as e:
+        handle_error(e, logging_prefix)
+        return
+
+    with open(FLAGS.config_path, 'r') as f:
+        params = json.load(f)
+
+    if not FLAGS.use_pure_config:
+        params = override_parameters(params, _get_ext_vars(finetuning=False))
+        if 'feats' not in FLAGS.targets and 'morphological_feat' in params['model']['parameters']:
+            del params['model']['parameters']['morphological_feat']
+
+    dataset_reader = read_dataset_reader_from_config(params, logging_prefix)
+    training_data_loader = read_data_loader_from_config(params, logging_prefix, validation=False)
+    validation_data_loader = read_data_loader_from_config(params, logging_prefix, validation=True)
+    vocabulary = read_vocabulary_from_config(params, logging_prefix)
+
+    pass_down_parameters = {'vocabulary': vocabulary}
+    if not FLAGS.use_pure_config:
+        pass_down_parameters['model_name'] = FLAGS.pretrained_transformer_name
+
+    logger.info('Resolving the model from parameters.', prefix=logging_prefix)
+    model = resolve(params['model'],
+                    pass_down_parameters=pass_down_parameters)
+
+    return model, vocabulary, training_data_loader, validation_data_loader, dataset_reader
+
+
 def run(_):
     if FLAGS.mode == 'train':
+        model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = None, None, None, None, None
+
         if not FLAGS.finetuning:
             prefix = 'Training'
             logger.info('Setting up the model for training', prefix=prefix)
-            try:
-                checks.file_exists(FLAGS.config_path)
-            except ConfigurationError as e:
-                handle_error(e)
-                return
 
-            logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
-            with open(FLAGS.config_path, 'r') as f:
-                params = json.load(f)
-            params = override_parameters(params, _get_ext_vars(True))
+            if FLAGS.config_path:
+                logger.info(f'Reading parameters from configuration path {FLAGS.config_path}', prefix=prefix)
+                model, vocabulary, training_data_loader, validation_data_loader, dataset_reader = read_model_from_config(prefix)
 
-            if 'feats' not in FLAGS.features:
-                del params['model']['parameters']['morphological_feat']
+            if FLAGS.use_pure_config and model is None:
+                logger.error('Error in configuration - model could not be read from parameters. '
+                             'Correct the configuration or use --nopure_config '
+                             'to use default models.')
+                return
 
             serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
-            params['vocabulary']['parameters']['directory'] = os.path.join('/'.join(FLAGS.config_path.split('/')[:-1]),
-                                                                           params['vocabulary']['parameters'][
-                                                                               'directory'])
+            training_data_path = FLAGS.training_data_path
+            validation_data_path = FLAGS.validation_data_path
 
-            try:
-                vocabulary = resolve(params['vocabulary'])
-            except Exception as e:
-                handle_error(e)
-                return
+        else:
+            prefix = 'Finetuning'
 
             try:
-                model = resolve(override_parameters(params['model'], _get_ext_vars(False)),
-                                pass_down_parameters={'vocabulary': vocabulary})
-            except Exception as e:
-                handle_error(e)
+                checks.file_exists(FLAGS.finetuning_training_data_path)
+                if FLAGS.finetuning_validation_data_path:
+                    checks.file_exists(FLAGS.finetuning_validation_data_path)
+            except ConfigurationError as e:
+                handle_error(e, prefix)
+
+            logger.info('Loading the model for finetuning', prefix=prefix)
+            model, _, training_data_loader, validation_data_loader, dataset_reader = load_archive(FLAGS.model_path)
+
+            if model is None:
+                logger.error(f'Model could not be loaded from archive {str(FLAGS.model_path)}. Exiting', prefix=prefix)
                 return
 
-            dataset_reader = None
+            vocabulary = model.vocab
+
+            serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
+
+            training_data_path = FLAGS.finetuning_training_data_path
+            validation_data_path = FLAGS.finetuning_validation_data_path
 
-            if 'data_loader' in params:
-                logger.info(f'Resolving the training data loader from parameters', prefix=prefix)
-                try:
-                    train_data_loader = resolve(params['data_loader'])
-                except Exception as e:
-                    handle_error(e)
-                    return
+        if not dataset_reader and (FLAGS.test_data_path
+                                   or not training_data_loader
+                                   or (FLAGS.validation_data_path and not validation_data_loader)):
+            # Dataset reader is required to read training data and/or for training (and validation) data loader
+            dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name)
+
+        if not training_data_loader:
+            training_data_loader = default_data_loader(dataset_reader, training_data_path)
+        else:
+            if training_data_path:
+                training_data_loader.data_path = training_data_path
             else:
-                checks.file_exists(FLAGS.training_data_path)
-                logger.info(f'Using a default UD data loader with training data path {FLAGS.training_data_path}',
-                            prefix=prefix)
-                try:
-                    train_data_loader = default_data_loader(default_ud_dataset_reader(),
-                                                            FLAGS.training_data_path)
-                except Exception as e:
-                    handle_error(e)
-                    return
-
-            logger.info('Indexing training data loader')
-            train_data_loader.index_with(model.vocab)
-
-            validation_data_loader = None
-
-            if 'validation_data_loader' in params:
-                logger.info(f'Resolving the validation data loader from parameters', prefix=prefix)
-                validation_data_loader = resolve(params['validation_data_loader'])
-                logger.info('Indexing validation data loader', prefix=prefix)
-                validation_data_loader.index_with(model.vocab)
-            elif FLAGS.validation_data_path:
-                checks.file_exists(FLAGS.validation_data_path)
-                logger.info(f'Using a default UD data loader with validation data path {FLAGS.training_data_path}',
-                            prefix=prefix)
-                validation_data_loader = default_data_loader(default_ud_dataset_reader(),
-                                                             FLAGS.validation_data_path)
-                logger.info('Indexing validation data loader', prefix=prefix)
-                validation_data_loader.index_with(model.vocab)
+                logger.warning(f'No training data path provided - using the path from configuration: '
+                               + str(training_data_loader.data_path), prefix=prefix)
+        if FLAGS.validation_data_path and not validation_data_loader:
+            validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
         else:
-            prefix = 'Finetuning'
-            logger.info('Loading the model for finetuning', prefix=prefix)
-            model, _, train_data_loader, validation_data_loader, dataset_reader = load_archive(FLAGS.model_path)
+            if validation_data_path:
+                validation_data_loader.data_path = validation_data_path
+            else:
+                logger.warning(f'No validation data path provided - using the path from configuration: '
+                               + str(validation_data_loader.data_path), prefix=prefix)
 
-            serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
+        if not vocabulary:
+            vocabulary = build_vocabulary_from_instances(training_data_loader, validation_data_loader, prefix)
 
-            if not train_data_loader:
-                checks.file_exists(FLAGS.finetuning_training_data_path)
-                logger.info(
-                    f'Using a default UD data loader with training data path {FLAGS.finetuning_training_data_path}',
-                    prefix=prefix)
-                train_data_loader = default_data_loader(default_ud_dataset_reader(),
-                                                        FLAGS.finetuning_training_data_path)
-            if not validation_data_loader and FLAGS.finetuning_validation_data_path:
-                checks.file_exists(FLAGS.finetuning_validation_data_path)
-                logger.info(
-                    f'Using a default UD data loader with validation data path {FLAGS.finetuning_validation_data_path}',
-                    prefix=prefix)
-                validation_data_loader = default_data_loader(default_ud_dataset_reader(),
-                                                             FLAGS.finetuning_validation_data_path)
-            logger.info("Indexing train loader", prefix=prefix)
-            train_data_loader.index_with(model.vocab)
-            logger.info("Indexing validation loader", prefix=prefix)
+        if not model:
+            model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
+
+        logger.info('Indexing training data loader', prefix=prefix)
+        training_data_loader.index_with(model.vocab)
+        logger.info('Indexing validation data loader', prefix=prefix)
         validation_data_loader.index_with(model.vocab)
-        logger.info("Indexed", prefix=prefix)
 
-        nlp = TrainableCombo(model, torch.optim.Adam,
+        nlp = TrainableCombo(model,
+                             torch.optim.Adam,
                              optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
                              validation_metrics=['EM'])
 
@@ -222,31 +295,12 @@ def run(_):
                           gradient_clip_val=5,
                           devices=n_cuda_devices,
                           logger=tensorboard_logger)
-        try:
-            trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=validation_data_loader)
-        except Exception as e:
-            handle_error(e)
-            return
+        trainer.fit(model=nlp, train_dataloaders=training_data_loader, val_dataloaders=validation_data_loader)
 
         logger.info(f'Archiving the model in {serialization_dir}', prefix=prefix)
-        archive(model, serialization_dir, train_data_loader, validation_data_loader, dataset_reader)
+        archive(model, serialization_dir, training_data_loader, validation_data_loader, dataset_reader)
         logger.info(f"Model stored in: {serialization_dir}", prefix=prefix)
 
-        if FLAGS.test_data_path and FLAGS.output_file:
-            checks.file_exists(FLAGS.test_data_path)
-            if not dataset_reader:
-                logger.info(
-                    "No dataset reader in the configuration or archive file - using a default UD dataset reader",
-                    prefix=prefix)
-                dataset_reader = default_ud_dataset_reader()
-            logger.info("Predicting test examples", prefix=prefix)
-            test_trees = dataset_reader.read(FLAGS.test_data_path)
-            predictor = COMBO(model, dataset_reader)
-            with open(FLAGS.output_file, "w") as file:
-                for tree in tqdm(test_trees):
-                    file.writelines(api.sentence2conllu(predictor.predict_instance(tree),
-                                                        keep_semrel=dataset_reader.use_sem).serialize())
-
     elif FLAGS.mode == 'predict':
         prefix = 'Predicting'
         logger.info('Loading the model', prefix=prefix)
@@ -321,10 +375,10 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
             }
         },
         "data_loader": {
-            "data_path": (",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
-            "batch_size": FLAGS.batch_size,
-            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
+                "data_path": FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path,
+                "batch_size": FLAGS.batch_size,
+                "batches_per_epoch": FLAGS.batches_per_epoch,
                 "reader": {
                     "parameters": {
                         "features": FLAGS.features,
@@ -341,10 +395,10 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
             }
         },
         "validation_data_loader": {
-            "data_path": (",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
-            "batch_size": FLAGS.batch_size,
-            "batches_per_epoch": FLAGS.batches_per_epoch,
             "parameters": {
+                "data_path": FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path,
+                "batch_size": FLAGS.batch_size,
+                "batches_per_epoch": FLAGS.batches_per_epoch,
                 "reader": {
                     "parameters": {
                         "features": FLAGS.features,
@@ -375,6 +429,16 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         }
     }
 
+    if FLAGS.vocabulary_path:
+        to_override["vocabulary"] = {
+            "type": "from_files_vocabulary",
+            "parameters": {
+                "directory": FLAGS.vocabulary_path,
+                "padding_token": "__PAD__",
+                "oov_token": "_"
+            }
+        }
+
     return to_override
 
 
diff --git a/combo/predict.py b/combo/predict.py
index 6743d8f9d68488f06171700deb624b3015ddbcff..c4218f58e5f2cb3bd88433fa934d0420f9027310 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -30,7 +30,7 @@ class COMBO(PredictorModule):
     def __init__(self,
                  model: Model,
                  dataset_reader: DatasetReader,
-                 tokenizer: data.Tokenizer = tokenizers.LamboTokenizer(),
+                 tokenizer: data.Tokenizer,
                  batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)