diff --git a/combo/combo_model.py b/combo/combo_model.py
index 2747ad2dd904e40f218dcb2bb316d5cfbbecaf04..b87eb1c0719fdd1e8fc4e1e4695559deb81644cb 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -32,7 +32,7 @@ class ComboModel(Model, FromParameters):
 
     @classmethod
     def pass_down_parameter_names(cls) -> List[str]:
-        return ['vocabulary', 'serialization_dir']
+        return ['vocabulary', 'serialization_dir', 'model_name']
 
     @register_arguments
     def __init__(self,
@@ -67,6 +67,10 @@ class ComboModel(Model, FromParameters):
         self.scores = metrics.SemanticMetrics()
         self._partial_losses = None
 
+        self.model_name = getattr(
+            getattr(self.text_field_embedder, '_token_embedders', {}).get('token'), 'constructed_args'
+        ).get('model_name', '')
+
     def forward_on_tensors(self,
                            sentence: Dict[str, Dict[str, torch.Tensor]],
                            metadata: List[Dict[str, Any]],
diff --git a/combo/data/dataset_loaders/simple_data_loader.py b/combo/data/dataset_loaders/simple_data_loader.py
index 0289275dde9da1f9e96948a0742becd2c0c6713c..a98dae156ade83281298c9a70e34301e9a28a7d1 100644
--- a/combo/data/dataset_loaders/simple_data_loader.py
+++ b/combo/data/dataset_loaders/simple_data_loader.py
@@ -12,7 +12,6 @@ import torch
 from combo.common.util import lazy_groups_of
 from combo.common.tqdm import Tqdm
 from combo.config import Registry
-from combo.config.from_parameters import register_arguments
 from combo.data.dataset_loaders.data_collator import DefaultDataCollator
 from combo.data.dataset_readers import DatasetReader
 from combo.data.instance import Instance
@@ -28,7 +27,6 @@ class SimpleDataLoader(DataLoader):
     """
     A very simple `DataLoader` that is mostly used for testing.
     """
-    @register_arguments
     def __init__(
         self,
         instances: List[Instance],
@@ -90,7 +88,6 @@ class SimpleDataLoader(DataLoader):
         self.cuda_device = device
 
     @classmethod
-    @register_arguments
     def from_dataset_reader(
         cls,
         reader: DatasetReader,
@@ -106,5 +103,13 @@
             instance_iter = Tqdm.tqdm(instance_iter, desc="loading instances")
         instances = list(instance_iter)
         new_obj = cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch, collate_fn=collate_fn)
+        new_obj.constructed_args = {
+            'reader': reader,
+            'data_path': data_path,
+            'batch_size': batch_size,
+            'shuffle': shuffle,
+            'batches_per_epoch': batches_per_epoch,
+            'quiet': quiet
+        }
         new_obj.constructed_from = 'from_dataset_reader'
         return new_obj
diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py
index b7bc935432c0767590ecb512dd0d964238546e53..fbd11595ac9b2f5f27b8da913686d90ec321475d 100644
--- a/combo/data/dataset_readers/dataset_reader.py
+++ b/combo/data/dataset_readers/dataset_reader.py
@@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset
 
 from combo.config import FromParameters, Registry
 from combo.config.from_parameters import register_arguments
-from combo.data import SpacyTokenizer, SingleIdTokenIndexer
+from combo.data import LamboTokenizer, SingleIdTokenIndexer
 from combo.data.instance import Instance
 from combo.data.token_indexers import TokenIndexer
 from combo.data.tokenizers import Tokenizer
@@ -35,7 +35,7 @@ class DatasetReader(IterableDataset, FromParameters):
                  token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None:
         super(DatasetReader).__init__()
         # self.__file_path = None
-        self.__tokenizer = tokenizer or SpacyTokenizer()
+        self.__tokenizer = tokenizer or LamboTokenizer()
         self.__token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
 
     @property
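Note (illustrative, not part of the patch): the model_name lookup added to ComboModel.__init__ in the first hunk above assumes the text field embedder holds a 'token' entry in _token_embedders and that @register_arguments attached a constructed_args dict to it. A minimal standalone sketch of the same chain, with a defensive default so a missing 'token' embedder yields an empty name rather than an AttributeError:

    token_embedder = getattr(text_field_embedder, '_token_embedders', {}).get('token')
    model_name = getattr(token_embedder, 'constructed_args', {}).get('model_name', '')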
diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index eb336ae9131a03a0403f10070f966149ffd80ddd..5b726f3641f14cbcd19153b9dbd049415f3f3754 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -5,7 +5,7 @@ Author: Mateusz Klimaszewski
 import copy
 import pathlib
 from abc import ABC
-from typing import List, Any, Dict, Iterable, Tuple
+from typing import List, Any, Dict, Iterable, Tuple, Optional
 
 import conllu
 import torch
@@ -24,6 +24,7 @@ from combo.data.fields.text_field import TextField
 from combo.data.token_indexers import TokenIndexer
 from combo.data.vocabulary import get_slices_if_not_provided
 from combo.utils import checks, pad_sequence_to_length
+from combo.data import Tokenizer
 
 
 def parser_field_to_dataset_reader_field(field: str):
@@ -40,6 +41,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
     @register_arguments
     def __init__(
             self,
+            tokenizer: Optional[Tokenizer] = None,
             token_indexers: Dict[str, TokenIndexer] = None,
             lemma_indexers: Dict[str, TokenIndexer] = None,
             features: List[str] = None,
@@ -47,7 +49,7 @@
             use_sem: bool = False,
             **kwargs,
     ) -> None:
-        super().__init__(token_indexers=token_indexers)
+        super().__init__(tokenizer=tokenizer, token_indexers=token_indexers)
         if features is None:
             features = ["token", "char"]
         if targets is None:
diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py
index 3fc60e64d3d19dcab6d99c1c2106986625c8b6d5..a23bbeed03080f7e93e00d4dbaf1e731399cc0cc 100644
--- a/combo/data/token_indexers/pretrained_transformer_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_indexer.py
@@ -282,7 +282,7 @@ class PretrainedTransformerIndexer(TokenIndexer):
     def __eq__(self, other):
         if isinstance(other, PretrainedTransformerIndexer):
             for key in self.__dict__:
-                if key == "_tokenizer":
+                if key == "tokenizer":
                     # This is a reference to a function in the huggingface code, which we can't
                     # really modify to make this clean. So we special-case it.
                     continue
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
index 161c71d2b1cf81f63aeb8c8e20d4ac82b16f6631..f096dbc3f203a9147db0c8274ce0947bd467973f 100644
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -126,7 +126,7 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
     def __eq__(self, other):
         if isinstance(other, PretrainedTransformerMismatchedIndexer):
             for key in self.__dict__:
-                if key == "_tokenizer":
+                if key == "tokenizer":
                     # This is a reference to a function in the huggingface code, which we can't
                     # really modify to make this clean. So we special-case it.
                     continue
diff --git a/combo/main.py b/combo/main.py
index 4faa916624c1fde4fb48110fadf12ff9292c6cea..441f36f0f36621d60773bacdad1ff9865d061d75 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -16,13 +16,12 @@ from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks, ComboLogger
 
 from combo.config import resolve
-from combo.default_model import default_ud_dataset_reader, default_data_loader, default_model
+from combo.default_model import default_ud_dataset_reader, default_data_loader
 from combo.modules.archival import load_archive, archive
 from combo.predict import COMBO
 from combo.data import api
 from config import override_parameters
-from config.from_parameters import override_or_add_parameters
-from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader
+from data import LamboTokenizer, Vocabulary, DatasetReader
 from data.dataset_loaders import DataLoader
 from modules.model import Model
 from utils import ConfigurationError
@@ -169,31 +168,38 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
 
 def _read_property_from_config(property_key: str,
                                params: Dict[str, Any],
-                               logging_prefix: str) -> Optional[Any]:
+                               logging_prefix: str,
+                               pass_down_parameters: Dict[str, Any] = None) -> Optional[Any]:
     property = None
+    pass_down_parameters = pass_down_parameters or {}
     if property_key in params:
         logger.info(f'Reading {property_key.replace("_", " ")} from parameters.', prefix=logging_prefix)
         try:
-            property = resolve(params[property_key])
+            property = resolve(params[property_key], pass_down_parameters=pass_down_parameters)
         except Exception as e:
             handle_error(e, logging_prefix)
     return property
 
 
-def read_dataset_reader_from_config(params: Dict[str, Any], logging_prefix: str) -> Optional[DataLoader]:
-    return _read_property_from_config('dataset_reader', params, logging_prefix)
+def read_dataset_reader_from_config(params: Dict[str, Any],
+                                    logging_prefix: str,
+                                    pass_down_parameters: Dict[str, Any] = None) -> Optional[DataLoader]:
+    return _read_property_from_config('dataset_reader', params, logging_prefix, pass_down_parameters)
 
 
 def read_data_loader_from_config(params: Dict[str, Any],
                                  logging_prefix: str,
-                                 validation: bool = False) -> Optional[DataLoader]:
+                                 validation: bool = False,
+                                 pass_down_parameters: Dict[str, Any] = None) -> Optional[DataLoader]:
     key = 'validation_data_loader' if validation else 'data_loader'
-    return _read_property_from_config(key, params, logging_prefix)
+    return _read_property_from_config(key, params, logging_prefix, pass_down_parameters)
 
 
 def read_vocabulary_from_config(params: Dict[str, Any],
-                                logging_prefix: str) -> Optional[Vocabulary]:
+                                logging_prefix: str,
+                                pass_down_parameters: Dict[str, Any] = None) -> Optional[Vocabulary]:
     vocabulary = None
+    pass_down_parameters = pass_down_parameters or {}
     if "vocabulary" in params:
         logger.info('Reading vocabulary from saved directory.', prefix=logging_prefix)
         if 'directory' in params['vocabulary']['parameters']:
@@ -201,7 +207,7 @@ def read_vocabulary_from_config(params: Dict[str, Any],
                 params['vocabulary']['parameters'][
                     'directory'])
         try:
-            vocabulary = resolve(params['vocabulary'])
+            vocabulary = resolve(params['vocabulary'], pass_down_parameters)
         except Exception as e:
             handle_error(e, logging_prefix)
     return vocabulary
@@ -222,10 +228,21 @@ def read_model_from_config(logging_prefix: str) -> Optional[Tuple[Model, Dataset
     if 'feats' not in FLAGS.targets and 'morphological_feat' in params['model']['parameters']:
         del params['model']['parameters']['morphological_feat']
 
-    dataset_reader = read_dataset_reader_from_config(params, logging_prefix)
-    training_data_loader = read_data_loader_from_config(params, logging_prefix, validation=False)
-    validation_data_loader = read_data_loader_from_config(params, logging_prefix, validation=True)
-    vocabulary = read_vocabulary_from_config(params, logging_prefix)
+    pass_down_parameters = {}
+    model_name = params.get("model_name")
+    if model_name:
+        pass_down_parameters = {"model_name": model_name}
+
+    dataset_reader = read_dataset_reader_from_config(params, logging_prefix, pass_down_parameters)
+    training_data_loader = read_data_loader_from_config(params, logging_prefix,
+                                                        validation=False, pass_down_parameters=pass_down_parameters)
+    if (not FLAGS.validation_data_path or not FLAGS.finetuning_validation_data_path) and 'validation_data_loader' in params:
+        logger.warning('Validation data loader is in parameters, but no validation data path was provided!')
+        validation_data_loader = None
+    else:
+        validation_data_loader = read_data_loader_from_config(params, logging_prefix,
+                                                              validation=True, pass_down_parameters=pass_down_parameters)
+    vocabulary = read_vocabulary_from_config(params, logging_prefix, pass_down_parameters)
 
     dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
         dataset_reader,
@@ -265,6 +282,7 @@ def run(_):
                         'to use default models.')
             return
 
+        pathlib.Path(FLAGS.serialization_dir).mkdir(parents=True, exist_ok=True)
         serialization_dir = tempfile.mkdtemp(prefix='combo', dir=FLAGS.serialization_dir)
 
     else:
@@ -286,6 +304,7 @@
 
         vocabulary = model.vocab
 
+        pathlib.Path(FLAGS.serialization_dir).mkdir(parents=True, exist_ok=True)
         serialization_dir = tempfile.mkdtemp(prefix='combo', suffix='-finetuning', dir=FLAGS.serialization_dir)
 
         dataset_reader, training_data_loader, validation_data_loader, vocabulary = get_defaults(
@@ -352,6 +371,8 @@
         except ConfigurationError as e:
             handle_error(e, prefix)
 
+        pathlib.Path(FLAGS.output_file).mkdir(parents=True, exist_ok=True)
+
         logger.info("Predicting examples from file", prefix=prefix)
 
         predictions = []
@@ -394,6 +415,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         return {}
 
     to_override = {
+        "model_name": FLAGS.pretrained_transformer_name,
         "model": {
             "parameters": {
                 "lemmatizer": {
                     "parameters": {
                         "embedding_dim": FLAGS.lemmatizer_embedding_dim
                     }
                 },
-                "text_field_embedder": {
-                    "parameters": {
-                        "token_embedders": {
-                            "parameters": {
-                                "token": {
-                                    "parameters": {
-                                        "model_name": FLAGS.pretrained_transformer_name
-                                    }
-                                }
-                            }
-                        }
-                    }
-                },
                 "serialization_dir": FLAGS.serialization_dir
             }
         },
@@ -425,14 +434,7 @@
                 "reader": {
                     "parameters": {
                         "features": FLAGS.features,
-                        "targets": FLAGS.targets,
-                        "token_indexers": {
-                            "token": {
-                                "parameters": {
-                                    "model_name": FLAGS.pretrained_transformer_name
-                                }
-                            }
-                        }
+                        "targets": FLAGS.targets
                     }
                 }
             }
@@ -445,14 +447,7 @@
                 "reader": {
                     "parameters": {
                         "features": FLAGS.features,
-                        "targets": FLAGS.targets,
-                        "token_indexers": {
-                            "token": {
-                                "parameters": {
-                                    "model_name": FLAGS.pretrained_transformer_name
-                                }
-                            }
-                        }
+                        "targets": FLAGS.targets
                     }
                 }
             }
@@ -460,14 +455,7 @@
         "dataset_reader": {
             "parameters": {
                 "features": FLAGS.features,
-                "targets": FLAGS.targets,
-                "token_indexers": {
-                    "token": {
-                        "parameters": {
-                            "model_name": FLAGS.pretrained_transformer_name
-                        }
-                    }
-                }
+                "targets": FLAGS.targets
             }
         }
     }
diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index aae88ddc69b8f0c37fc06c49f03659fe0f096d8e..3b39f3d7e3d31b91b76a6a3a3d0501a471e133bf 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -48,7 +48,7 @@ def archive(model: Model,
             'padding_token': model.vocab._padding_token,
             'oov_token': model.vocab._oov_token
         }
-    }, 'model': model.serialize(pass_down_parameter_names=['vocabulary', 'optimizer', 'scheduler'])}
+    }, 'model': model.serialize(pass_down_parameter_names=['vocabulary', 'optimizer', 'scheduler', 'model_name'])}
 
     if data_loader:
         parameters['data_loader'] = data_loader.serialize()
@@ -65,6 +65,9 @@ def archive(model: Model,
     if model.scheduler:
         parameters['training']['scheduler'] = model.scheduler.serialize(pass_down_parameter_names=['optimizer'])
 
+    if model.model_name:
+        parameters['model_name'] = model.model_name
+
     with (TemporaryDirectory(os.path.join('tmp')) as t,
           BytesIO() as out_stream,
           tarfile.open(os.path.join(serialization_dir, 'model.tar.gz'), 'w|gz') as tar_file):
@@ -93,13 +96,16 @@ def load_archive(url_or_filename: Union[PathLike, str],
 
         config = json.load(f)
 
     data_loader, validation_data_loader, dataset_reader = None, None, None
+    pass_down_parameters = {}
+    if config.get("model_name"):
+        pass_down_parameters = {"model_name": config.get("model_name")}
     if 'data_loader' in config:
-        data_loader = resolve(config['data_loader'])
+        data_loader = resolve(config['data_loader'], pass_down_parameters=pass_down_parameters)
     if 'validation_data_loader' in config:
-        validation_data_loader = resolve(config['validation_data_loader'])
+        validation_data_loader = resolve(config['validation_data_loader'], pass_down_parameters=pass_down_parameters)
     if 'dataset_reader' in config:
-        dataset_reader = resolve(config['dataset_reader'])
+        dataset_reader = resolve(config['dataset_reader'], pass_down_parameters=pass_down_parameters)
 
     return Archive(model=model,
                    config=config,
diff --git a/combo/modules/model.py b/combo/modules/model.py
index 846545942b71d9af7c7e2ad3e678dcfe2e3bf0bb..64bb5667992ea4e6f99f2911c9b685f201aad693 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -362,7 +362,7 @@ class Model(Module, pl.LightningModule, FromParameters):
         # stored in our model. We don't need any pretrained weight file or initializers anymore,
         # and we don't want the code to look for it, so we remove it from the parameters here.
         remove_keys_from_params(model_params)
-        model = resolve(model_params)
+        model = resolve(model_params, pass_down_parameters={'model_name': config.get("model_name", "")})
 
         # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
         # in sync with the weights
diff --git a/combo/predict.py b/combo/predict.py
index e30889d33eaa59956d08ea73a888b351e00f8103..4450d2f33a8881156cb60ccac2bde96903dd477b 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -11,7 +11,7 @@ from combo import data
 from combo.common import util
 from combo.config import Registry
 from combo.config.from_parameters import register_arguments
-from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
+from combo.data import Instance, conllu2sentence, tokens2conllu, sentence2conllu
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict
@@ -30,7 +30,6 @@ class COMBO(PredictorModule):
     def __init__(self,
                  model: Model,
                  dataset_reader: DatasetReader,
-                 tokenizer: data.Tokenizer,
                  batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
@@ -38,7 +37,7 @@ class COMBO(PredictorModule):
         self.vocab = model.vocab
         self.dataset_reader.generate_labels = False
         self.dataset_reader.lazy = True
-        self._tokenizer = tokenizer
+        self.tokenizer = dataset_reader.tokenizer
         self.without_sentence_embedding = False
         self.line_to_conllu = line_to_conllu
 
@@ -136,7 +135,7 @@ class COMBO(PredictorModule):
     def _json_to_instance(self, json_dict) -> Instance:
         sentence = json_dict["sentence"]
         if isinstance(sentence, str):
-            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+            tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
         elif isinstance(sentence, list):
             tokens = sentence
         else:
@@ -257,16 +256,9 @@ class COMBO(PredictorModule):
         return tree, predictions["sentence_embedding"], embeddings, \
             deprel_tree_distribution, deprel_label_distribution
 
-
-    @classmethod
-    def with_spacy_tokenizer(cls, model: Model,
-                             dataset_reader: DatasetReader):
-        return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
-
     @classmethod
     def from_pretrained(cls,
                         path: str,
-                        tokenizer=tokenizers.SpacyTokenizer(),
                         batch_size: int = 1024,
                         cuda_device: int = -1):
         if os.path.exists(path):
@@ -281,5 +273,5 @@ class COMBO(PredictorModule):
 
         archive = load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = archive.dataset_reader or default_ud_dataset_reader()
-        return cls(model, dataset_reader, tokenizer, batch_size)
+        dataset_reader = archive.dataset_reader or default_ud_dataset_reader(archive.config.get("model_name"))
+        return cls(model, dataset_reader, batch_size)
diff --git a/tests/config/test_archive.py b/tests/config/test_archive.py
index f8bae7a29c9687c6ff78a3dfff5d2638aac87ffc..36aebab52ec164de962f2d3155c82c6ee1dd3b1a 100644
--- a/tests/config/test_archive.py
+++ b/tests/config/test_archive.py
@@ -18,7 +18,7 @@ def _test_vocabulary() -> Vocabulary:
 class ArchivalTest(unittest.TestCase):
     def test_serialize_model(self):
         vocabulary = _test_vocabulary()
-        model = default_model(vocabulary)
+        model = default_model('bert-base-cased', vocabulary)
         t = '.'
         with TemporaryDirectory(TEMP_FILE_PATH) as t:
             archive(model, t)
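Note (illustrative, not part of the patch): with the tokenizer argument removed from COMBO, the predictor now reuses the dataset reader's tokenizer (a LamboTokenizer by default), so loading a pretrained model only needs a path or model name. A usage sketch, assuming the predictor is callable on a raw sentence as in the project examples; the model name below is a placeholder:

    from combo.predict import COMBO

    nlp = COMBO.from_pretrained("polish-herbert-base")  # placeholder pretrained model
    prediction = nlp("To jest przykładowe zdanie.")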