diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py index 4fb1b6b0ddfea399304f0507d4b0fb558c8f7788..e601fc53671cca61f62353b965f7d516af832941 100644 --- a/combo/config/from_parameters.py +++ b/combo/config/from_parameters.py @@ -162,10 +162,6 @@ class FromParameters: return {'type': Registry.get_class_name(type(self), constructor_method), 'parameters': self._to_params(pass_down_parameter_names)} - def save(self, path: str, pass_down_parameters: Dict[str, Any] = None): - with open(path, 'wb') as f: - json.dump(self._to_params(pass_down_parameters), f) - def resolve(parameters: Dict[str, Any], pass_down_parameters: Dict[str, Any] = None) -> Any: pass_down_parameters = pass_down_parameters or {} diff --git a/combo/modules/model.py b/combo/modules/model.py index 7a4d9e437d8be69c03bcfaa6e74d35c57bca6753..c3da4bdc158905d9c61742cb198bb4aa9033b33d 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -2,7 +2,7 @@ Adapted from AllenNLP https://github.com/allenai/allennlp/blob/main/allennlp/models/model.py """ - +import json import logging import os import re @@ -299,6 +299,28 @@ class Model(Module, FromParameters): # so we set this to false so we don't warn again. self._warn_for_unseparable_batches.add(output_key) + def save(self, + serialization_dir: Optional[Union[str, PathLike]] = None, + weights_file: Optional[Union[str, PathLike]] = None): + if serialization_dir is None: + serialization_dir = self.serialization_dir + weights_file = weights_file or os.path.join(serialization_dir, _DEFAULT_WEIGHTS) + vocab_serialization_dir = os.path.join(serialization_dir, "vocabulary") + self.vocab.save_to_files(vocab_serialization_dir) + serialized = {} + serialized['vocabulary'] = { + 'type': 'from_files_vocabulary', + 'parameters': { + 'directory': vocab_serialization_dir, + 'padding_token': self.vocab._padding_token, + 'oov_token': self.vocab._oov_token + } + } + serialized['model'] = self.serialize() + torch.save(self.state_dict(), weights_file) + with open(os.path.join(serialization_dir, 'params.json'), 'w') as f: + json.dump(serialized, f) + @classmethod def _load( cls,