diff --git a/combo/modules/archival.py b/combo/modules/archival.py index 3b39f3d7e3d31b91b76a6a3a3d0501a471e133bf..56a293e7cd54f944d25c5396e77da1bb6cd38d6c 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -14,7 +14,7 @@ from combo.config import resolve from combo.data.dataset_loaders import DataLoader from combo.data.dataset_readers import DatasetReader from combo.modules.model import Model - +from combo.utils import ConfigurationError CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) CACHE_DIRECTORY = str(CACHE_ROOT / "cache") @@ -92,7 +92,12 @@ def load_archive(url_or_filename: Union[PathLike, str], ) model = Model.load(archive_file, cuda_device=cuda_device) - with open(os.path.join(archive_file, 'config.json'), 'r') as f: + config_path = os.path.join(archive_file, 'config.json') + if not os.path.exists(config_path): + config_path = os.path.join(archive_file, 'model/config.json') + if not os.path.exists(config_path): + raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model") + with open(config_path, 'r') as f: config = json.load(f) data_loader, validation_data_loader, dataset_reader = None, None, None diff --git a/combo/modules/model.py b/combo/modules/model.py index 64bb5667992ea4e6f99f2911c9b685f201aad693..5f2beb699179cdc7a0ef4f4b1b03757e6c62e264 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -349,6 +349,10 @@ class Model(Module, pl.LightningModule, FromParameters): # Load vocabulary from file vocab_dir = os.path.join(serialization_dir, "vocabulary") + if not os.path.exists(vocab_dir): + vocab_dir =os.path.join(serialization_dir, "model/vocabulary") + if not os.path.exists(vocab_dir): + raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model") # If the config specifies a vocabulary subclass, we need to use it. vocab_params = config.get("vocabulary") if vocab_params['type'] == 'from_files_vocabulary': @@ -403,12 +407,12 @@ class Model(Module, pl.LightningModule, FromParameters): filter_out_authorized_missing_keys(model) - # if unexpected_keys or missing_keys: - # raise RuntimeError( - # f"Error loading state dict for {model.__class__.__name__}\n\t" - # f"Missing keys: {missing_keys}\n\t" - # f"Unexpected keys: {unexpected_keys}" - # ) + if unexpected_keys or missing_keys: + raise RuntimeError( + f"Error loading state dict for {model.__class__.__name__}\n\t" + f"Missing keys: {missing_keys}\n\t" + f"Unexpected keys: {unexpected_keys}" + ) return model @@ -447,13 +451,21 @@ class Model(Module, pl.LightningModule, FromParameters): vocabulary and the trained weights. """ if config is None: - with open(os.path.join(serialization_dir, 'config.json'), 'r') as f: - config = json.load(f) + try: + with open(os.path.join(serialization_dir, 'config.json'), 'r') as f: + config = json.load(f) + except FileNotFoundError: + with open(os.path.join(serialization_dir, 'model/config.json'), 'r') as f: + config = json.load(f) elif isinstance(config, str) or isinstance(config, PathLike): with open(config, 'r') as f: config = json.load(f) weights_file = weights_file or os.path.join(serialization_dir, 'weights.th') + if not os.path.exists(weights_file): + weights_file = os.path.join(serialization_dir, 'model/weights.th') + if not os.path.exists(weights_file): + raise ConfigurationError("weights.th not in " + serialization_dir + " or " + serialization_dir + "/model") # Peak at the class of the model. model_type = (