diff --git a/combo/modules/archival.py b/combo/modules/archival.py index 56a293e7cd54f944d25c5396e77da1bb6cd38d6c..496753f711f60fce9555d724fe72c7e49c7a8435 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -1,4 +1,6 @@ import os +import shutil +import tempfile from os import PathLike from pathlib import Path from typing import Any, Dict, Union, NamedTuple, Optional @@ -14,11 +16,19 @@ from combo.config import resolve from combo.data.dataset_loaders import DataLoader from combo.data.dataset_readers import DatasetReader from combo.modules.model import Model -from combo.utils import ConfigurationError +from contextlib import contextmanager + +import logging +from combo.utils import ComboLogger + +logging.setLoggerClass(ComboLogger) +logger = logging.getLogger(__name__) + CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) CACHE_DIRECTORY = str(CACHE_ROOT / "cache") +PREFIX = 'Loading archive' class Archive(NamedTuple): model: Model @@ -82,36 +92,73 @@ def archive(model: Model, return serialization_dir +@contextmanager +def extracted_archive(resolved_archive_file, cleanup=True): + tempdir = None + try: + tempdir = tempfile.mkdtemp(dir=CACHE_DIRECTORY) + with tarfile.open(resolved_archive_file) as archive: + subdir_and_files = [ + tarinfo for tarinfo in archive.getmembers() + if (any([tarinfo.name.endswith(f) for f in ['config.json', 'weights.th']]) + or 'vocabulary' in tarinfo.name) + ] + for f in subdir_and_files: + if 'vocabulary' in f.name and not f.name.endswith('vocabulary'): + f.name = os.path.join('vocabulary', os.path.basename(f.name)) + else: + f.name = os.path.basename(f.name) + archive.extractall(path=tempdir, members=subdir_and_files) + yield tempdir + finally: + if tempdir is not None and cleanup: + shutil.rmtree(tempdir, ignore_errors=True) + + def load_archive(url_or_filename: Union[PathLike, str], cache_dir: Union[PathLike, str] = None, cuda_device: int = -1) -> Archive: - archive_file = 
cached_path.cached_path( + + resolved_archive_file = cached_path.cached_path( url_or_filename, cache_dir=cache_dir or CACHE_DIRECTORY, - extract_archive=True ) - model = Model.load(archive_file, cuda_device=cuda_device) - - config_path = os.path.join(archive_file, 'config.json') - if not os.path.exists(config_path): - config_path = os.path.join(archive_file, 'model/config.json') - if not os.path.exists(config_path): - raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model") - with open(config_path, 'r') as f: - config = json.load(f) - - data_loader, validation_data_loader, dataset_reader = None, None, None - pass_down_parameters = {} - if config.get("model_name"): - pass_down_parameters = {"model_name": config.get("model_name")} - - if 'data_loader' in config: - data_loader = resolve(config['data_loader'], pass_down_parameters=pass_down_parameters) - if 'validation_data_loader' in config: - validation_data_loader = resolve(config['validation_data_loader'], pass_down_parameters=pass_down_parameters) - if 'dataset_reader' in config: - dataset_reader = resolve(config['dataset_reader'], pass_down_parameters=pass_down_parameters) - + + with extracted_archive(resolved_archive_file) as archive_file: + model = Model.load(archive_file, cuda_device=cuda_device) + + config_path = os.path.join(archive_file, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + + data_loader, validation_data_loader, dataset_reader = None, None, None + pass_down_parameters = {} + if config.get("model_name"): + pass_down_parameters = {"model_name": config.get("model_name")} + + + if 'data_loader' in config: + try: + data_loader = resolve(config['data_loader'], + pass_down_parameters=pass_down_parameters) + except Exception as e: + logger.warning(f'Error while loading Training Data Loader: {str(e)}. 
Setting Data Loader to None', + prefix=PREFIX) + if 'validation_data_loader' in config: + try: + validation_data_loader = resolve(config['validation_data_loader'], + pass_down_parameters=pass_down_parameters) + except Exception as e: + logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Validation Data Loader to None', + prefix=PREFIX) + if 'dataset_reader' in config: + try: + dataset_reader = resolve(config['dataset_reader'], + pass_down_parameters=pass_down_parameters) + except Exception as e: + logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None', + prefix=PREFIX) + return Archive(model=model, config=config, data_loader=data_loader, diff --git a/combo/modules/model.py b/combo/modules/model.py index 5f2beb699179cdc7a0ef4f4b1b03757e6c62e264..223bacc2ce07ad77e57876a8be367d46b92c9b58 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters): # Load vocabulary from file vocab_dir = os.path.join(serialization_dir, "vocabulary") - if not os.path.exists(vocab_dir): - vocab_dir =os.path.join(serialization_dir, "model/vocabulary") - if not os.path.exists(vocab_dir): - raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model") # If the config specifies a vocabulary subclass, we need to use it. 
vocab_params = config.get("vocabulary") if vocab_params['type'] == 'from_files_vocabulary': diff --git a/combo/utils/logging.py b/combo/utils/logging.py index cd04b6a3ebff83682b92fffac9a02add90d2debc..2d8d99a39236629b7d7844eed10c42acf463312c 100644 --- a/combo/utils/logging.py +++ b/combo/utils/logging.py @@ -27,7 +27,7 @@ class ComboLogger(logging.Logger): self.log(level=logging.INFO, msg=msg, prefix=prefix) @overrides(check_signature=False) - def warn(self, msg: str, prefix: str = None): + def warning(self, msg: str, prefix: str = None): self.log(level=logging.WARN, msg=msg, prefix=prefix) @overrides(check_signature=False)