Skip to content
Snippets Groups Projects
Commit a28a2670 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Remove weights path from parameters

parent 53dd3bd7
No related branches found
No related tags found
1 merge request!46Merge COMBO 3.0 into master
......@@ -12,7 +12,7 @@ from combo.predictors import Predictor
from combo.training.trainable_combo import TrainableCombo
from combo.utils import checks
from config import Registry, resolve
from config import resolve
from models import ComboModel
from predict import COMBO
......
......@@ -478,26 +478,6 @@ class Model(Module, FromParameters):
model.extend_embedder_vocab()
return model
@classmethod
@overrides
def from_parameters(cls,
parameters: Dict[str, Any] = None,
constructor_method_name: str = None,
pass_down_parameters: Dict[str, Any] = None):
constructed_model = super().from_parameters(parameters)
constructed_model.load_state_dict(torch.load(parameters['weights']))
return constructed_model
@overrides
def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
vocabulary_dir = os.path.join(self.serialization_dir, 'vocabulary')
self.vocab._serialization_dir = vocabulary_dir
weights_path = os.path.join(self.serialization_dir, 'weights.pth')
torch.save(self.state_dict(), weights_path)
serialized_model = super().serialize()
serialized_model['parameters']['weights'] = weights_path
return serialized_model
def remove_weights_related_keys_from_params(
params: Params, keys: List[str] = ["pretrained_file", "initializer"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment