From a28a2670fa0213ba0a4a6c369ebdfe5a64b346b1 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 10 Oct 2023 19:43:38 +1100
Subject: [PATCH] Remove weights path from parameters

---
 combo/main.py          |  2 +-
 combo/modules/model.py | 20 --------------------
 2 files changed, 1 insertion(+), 21 deletions(-)

diff --git a/combo/main.py b/combo/main.py
index 6a1ac4b..36c9f68 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -12,7 +12,7 @@ from combo.predictors import Predictor
 from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks
 
-from config import Registry, resolve
+from config import resolve
 from models import ComboModel
 from predict import COMBO
 
diff --git a/combo/modules/model.py b/combo/modules/model.py
index e62bead..694c86d 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -478,26 +478,6 @@ class Model(Module, FromParameters):
             model.extend_embedder_vocab()
         return model
 
-    @classmethod
-    @overrides
-    def from_parameters(cls,
-                        parameters: Dict[str, Any] = None,
-                        constructor_method_name: str = None,
-                        pass_down_parameters: Dict[str, Any] = None):
-        constructed_model = super().from_parameters(parameters)
-        constructed_model.load_state_dict(torch.load(parameters['weights']))
-        return constructed_model
-
-    @overrides
-    def serialize(self, pass_down_parameter_names: List[str] = None) -> Dict[str, Any]:
-        vocabulary_dir = os.path.join(self.serialization_dir, 'vocabulary')
-        self.vocab._serialization_dir = vocabulary_dir
-        weights_path = os.path.join(self.serialization_dir, 'weights.pth')
-        torch.save(self.state_dict(), weights_path)
-        serialized_model = super().serialize()
-        serialized_model['parameters']['weights'] = weights_path
-        return serialized_model
-
 
 def remove_weights_related_keys_from_params(
     params: Params, keys: List[str] = ["pretrained_file", "initializer"]
-- 
GitLab