Skip to content
Snippets Groups Projects
Commit 451eaad5 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Workaround for different paths of archival models

parent 3c47dbf1
Branches
No related tags found
1 merge request!46Merge COMBO 3.0 into master
......@@ -14,7 +14,7 @@ from combo.config import resolve
from combo.data.dataset_loaders import DataLoader
from combo.data.dataset_readers import DatasetReader
from combo.modules.model import Model
from combo.utils import ConfigurationError
CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
......@@ -92,7 +92,12 @@ def load_archive(url_or_filename: Union[PathLike, str],
)
model = Model.load(archive_file, cuda_device=cuda_device)
with open(os.path.join(archive_file, 'config.json'), 'r') as f:
config_path = os.path.join(archive_file, 'config.json')
if not os.path.exists(config_path):
config_path = os.path.join(archive_file, 'model/config.json')
if not os.path.exists(config_path):
raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model")
with open(config_path, 'r') as f:
config = json.load(f)
data_loader, validation_data_loader, dataset_reader = None, None, None
......
......@@ -349,6 +349,10 @@ class Model(Module, pl.LightningModule, FromParameters):
# Load vocabulary from file
vocab_dir = os.path.join(serialization_dir, "vocabulary")
if not os.path.exists(vocab_dir):
vocab_dir =os.path.join(serialization_dir, "model/vocabulary")
if not os.path.exists(vocab_dir):
raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model")
# If the config specifies a vocabulary subclass, we need to use it.
vocab_params = config.get("vocabulary")
if vocab_params['type'] == 'from_files_vocabulary':
......@@ -403,12 +407,12 @@ class Model(Module, pl.LightningModule, FromParameters):
filter_out_authorized_missing_keys(model)
# if unexpected_keys or missing_keys:
# raise RuntimeError(
# f"Error loading state dict for {model.__class__.__name__}\n\t"
# f"Missing keys: {missing_keys}\n\t"
# f"Unexpected keys: {unexpected_keys}"
# )
if unexpected_keys or missing_keys:
raise RuntimeError(
f"Error loading state dict for {model.__class__.__name__}\n\t"
f"Missing keys: {missing_keys}\n\t"
f"Unexpected keys: {unexpected_keys}"
)
return model
......@@ -447,13 +451,21 @@ class Model(Module, pl.LightningModule, FromParameters):
vocabulary and the trained weights.
"""
if config is None:
try:
with open(os.path.join(serialization_dir, 'config.json'), 'r') as f:
config = json.load(f)
except FileNotFoundError:
with open(os.path.join(serialization_dir, 'model/config.json'), 'r') as f:
config = json.load(f)
elif isinstance(config, str) or isinstance(config, PathLike):
with open(config, 'r') as f:
config = json.load(f)
weights_file = weights_file or os.path.join(serialization_dir, 'weights.th')
if not os.path.exists(weights_file):
weights_file = os.path.join(serialization_dir, 'model/weights.th')
if not os.path.exists(weights_file):
raise ConfigurationError("weights.th not in " + serialization_dir + " or " + serialization_dir + "/model")
# Peak at the class of the model.
model_type = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment