Skip to content
Snippets Groups Projects
Commit a36bed8b authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Fix archival.py serialization paths

parent 5c1574ce
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -75,7 +75,7 @@ def load_archive(url_or_filename: Union[PathLike, str],
)
model = Model.load(archive_file, cuda_device=cuda_device)
with open(os.path.join(archive_file, 'model/config.json'), 'r') as f:
with open(os.path.join(archive_file, 'config.json'), 'r') as f:
config = json.load(f)
data_loader, validation_data_loader = None, None
......
......@@ -305,7 +305,7 @@ class Model(Module, FromParameters):
if serialization_dir is None:
serialization_dir = self.serialization_dir
weights_file = weights_file or os.path.join(serialization_dir, _DEFAULT_WEIGHTS)
vocab_serialization_dir = os.path.join(serialization_dir, "model/vocabulary")
vocab_serialization_dir = os.path.join(serialization_dir, "vocabulary")
self.vocab.save_to_files(vocab_serialization_dir)
serialized = {}
serialized['vocabulary'] = {
......@@ -318,7 +318,7 @@ class Model(Module, FromParameters):
}
serialized['model'] = self.serialize()
torch.save(self.state_dict(), weights_file)
with open(os.path.join(serialization_dir, 'model/params.json'), 'w') as f:
with open(os.path.join(serialization_dir, 'params.json'), 'w') as f:
json.dump(serialized, f)
@classmethod
......@@ -336,7 +336,7 @@ class Model(Module, FromParameters):
weights_file = weights_file or os.path.join(serialization_dir, _DEFAULT_WEIGHTS)
# Load vocabulary from file
vocab_dir = os.path.join(serialization_dir, "model/vocabulary")
vocab_dir = os.path.join(serialization_dir, "vocabulary")
# If the config specifies a vocabulary subclass, we need to use it.
vocab_params = config.get("vocabulary")
if vocab_params['type'] == 'from_files_vocabulary':
......@@ -436,13 +436,13 @@ class Model(Module, FromParameters):
vocabulary and the trained weights.
"""
if config is None:
with open(os.path.join(serialization_dir, 'model/config.json'), 'r') as f:
with open(os.path.join(serialization_dir, 'config.json'), 'r') as f:
config = json.load(f)
elif isinstance(config, str) or isinstance(config, PathLike):
with open(config, 'r') as f:
config = json.load(f)
weights_file = weights_file or os.path.join(serialization_dir, 'model/weights.th')
weights_file = weights_file or os.path.join(serialization_dir, 'weights.th')
# Peak at the class of the model.
model_type = (
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment