From a28a2670fa0213ba0a4a6c369ebdfe5a64b346b1 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Tue, 10 Oct 2023 19:43:38 +1100 Subject: [PATCH] Remove weights path from parameters --- combo/main.py | 2 +- combo/modules/model.py | 20 -------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/combo/main.py b/combo/main.py index 6a1ac4b..36c9f68 100755 --- a/combo/main.py +++ b/combo/main.py @@ -12,7 +12,7 @@ from combo.predictors import Predictor from combo.training.trainable_combo import TrainableCombo from combo.utils import checks -from config import Registry, resolve +from config import resolve from models import ComboModel from predict import COMBO diff --git a/combo/modules/model.py b/combo/modules/model.py index e62bead..694c86d 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -478,26 +478,6 @@ class Model(Module, FromParameters): model.extend_embedder_vocab() return model - @classmethod - @overrides - def from_parameters(cls, - parameters: Dict[str, Any] = None, - constructor_method_name: str = None, - pass_down_parameters: Dict[str, Any] = None): - constructed_model = super().from_parameters(parameters) - constructed_model.load_state_dict(torch.load(parameters['weights'])) - return constructed_model - - @overrides - def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]: - vocabulary_dir = os.path.join(self.serialization_dir, 'vocabulary') - self.vocab._serialization_dir = vocabulary_dir - weights_path = os.path.join(self.serialization_dir, 'weights.pth') - torch.save(self.state_dict(), weights_path) - serialized_model = super().serialize() - serialized_model['parameters']['weights'] = weights_path - return serialized_model - def remove_weights_related_keys_from_params( params: Params, keys: List[str] = ["pretrained_file", "initializer"] -- GitLab