Skip to content
Snippets Groups Projects
Commit a28a2670 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Remove weights path from parameters

parent 53dd3bd7
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -12,7 +12,7 @@ from combo.predictors import Predictor
from combo.training.trainable_combo import TrainableCombo
from combo.utils import checks
from config import Registry, resolve
from config import resolve
from models import ComboModel
from predict import COMBO
......
......@@ -478,26 +478,6 @@ class Model(Module, FromParameters):
model.extend_embedder_vocab()
return model
@classmethod
@overrides
def from_parameters(cls,
parameters: Dict[str, Any] = None,
constructor_method_name: str = None,
pass_down_parameters: Dict[str, Any] = None):
constructed_model = super().from_parameters(parameters)
constructed_model.load_state_dict(torch.load(parameters['weights']))
return constructed_model
@overrides
def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
vocabulary_dir = os.path.join(self.serialization_dir, 'vocabulary')
self.vocab._serialization_dir = vocabulary_dir
weights_path = os.path.join(self.serialization_dir, 'weights.pth')
torch.save(self.state_dict(), weights_path)
serialized_model = super().serialize()
serialized_model['parameters']['weights'] = weights_path
return serialized_model
def remove_weights_related_keys_from_params(
params: Params, keys: List[str] = ["pretrained_file", "initializer"]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment