From 7ef46307145785a93433acf2127ed766ea508838 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Thu, 18 Jan 2024 21:50:01 +0100
Subject: [PATCH] Fix classmethod registered arguments

---
 combo/data/vocabulary.py               | 24 ++++++++++++++++++------
 combo/modules/feedforward_predictor.py | 15 +++++++++++----
 setup.py                               |  2 +-
 tests/config/test_archive.py           |  2 +-
 4 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 4f3565c..153ed73 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -118,7 +118,6 @@ class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]):
 @Registry.register("from_files_vocabulary", "from_files")
 @Registry.register("from_pretrained_transformer_vocabulary", "from_pretrained_transformer")
 @Registry.register("from_data_loader_vocabulary", "from_data_loader")
-@Registry.register("from_pretrained_transformer_and_instances_vocabulary", "from_pretrained_transformer_and_instances")
 @Registry.register("from_data_loader_extended_vocabulary", "from_data_loader_extended")
 class Vocabulary(FromParameters):
     @register_arguments
@@ -239,7 +238,6 @@ class Vocabulary(FromParameters):
             self._vocab[namespace].append_tokens(tokens)
 
     @classmethod
-    @register_arguments
     def from_files(cls,
                    directory: Union[str, os.PathLike],
                    padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
@@ -308,10 +306,14 @@ class Vocabulary(FromParameters):
                 vocab.slices = json.load(slices_file)
 
         vocab.constructed_from = 'from_files'
+        vocab.constructed_args = {
+            "directory": directory.split("/")[-1],
+            "padding_token": padding_token,
+            "oov_token": oov_token
+        }
         return vocab
 
     @classmethod
-    @register_arguments
     def from_data_loader(
             cls,
             data_loader: "DataLoader",
@@ -339,7 +341,19 @@ class Vocabulary(FromParameters):
             oov_token=oov_token,
             serialization_dir=serialization_dir
         )
-        vocab.constructed_from = 'from_dataset_loader'
+        vocab.constructed_args = {
+            "data_loader": data_loader,
+            "min_count": min_count,
+            "max_vocab_size": max_vocab_size,
+            "non_padded_namespaces": non_padded_namespaces,
+            "pretrained_files": pretrained_files,
+            "only_include_pretrained_words": only_include_pretrained_words,
+            "tokens_to_add": tokens_to_add,
+            "min_pretrained_embeddings": min_pretrained_embeddings,
+            "padding_token": padding_token,
+            "oov_token": oov_token,
+            "serialization_dir": serialization_dir
+        }
         return vocab
 
     @classmethod
@@ -389,7 +403,6 @@ class Vocabulary(FromParameters):
         return vocab
 
     @classmethod
-    @register_arguments
     def from_files_and_instances(
             cls,
             instances: Iterable["Instance"],
@@ -431,7 +444,6 @@ class Vocabulary(FromParameters):
         return vocab
 
     @classmethod
-    @register_arguments
     def from_pretrained_transformer_and_instances(
             cls,
             instances: Iterable["Instance"],
diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py
index e585737..0e90000 100644
--- a/combo/modules/feedforward_predictor.py
+++ b/combo/modules/feedforward_predictor.py
@@ -57,7 +57,6 @@ class FeedForwardPredictor(Predictor):
         return loss.sum() / valid_positions
 
     @classmethod
-    @register_arguments
     def from_vocab(cls,
                    vocabulary: Vocabulary,
                    vocab_namespace: str,
@@ -74,14 +73,22 @@ class FeedForwardPredictor(Predictor):
 
         assert vocab_namespace in vocabulary.get_namespaces(), \
             f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
-        hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)]
+        hidden_dims_w_vocab_added = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)]
 
         ff_p = cls(FeedForward(
             input_dim=input_dim,
             num_layers=num_layers,
-            hidden_dims=hidden_dims,
+            hidden_dims=hidden_dims_w_vocab_added,
             activations=activations,
             dropout=dropout))
 
         ff_p.constructed_from = "from_vocab"
-        return ff_p
\ No newline at end of file
+        ff_p.constructed_args = {
+            "vocab_namespace": vocab_namespace,
+            "input_dim": input_dim,
+            "num_layers": num_layers,
+            "hidden_dims": hidden_dims,
+            "activations": activations,
+            "dropout": dropout
+        }
+        return ff_p
diff --git a/setup.py b/setup.py
index 50f4c06..723b582 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@ REQUIREMENTS = [
 ]
 
 setup(
-    name="combo",
+    name="combo-nlp",
     version="3.0.3",
     author="Maja Jablonska",
     author_email="maja.jablonska@ipipan.waw.pl",
diff --git a/tests/config/test_archive.py b/tests/config/test_archive.py
index 36aebab..fa17b0d 100644
--- a/tests/config/test_archive.py
+++ b/tests/config/test_archive.py
@@ -23,4 +23,4 @@ class ArchivalTest(unittest.TestCase):
         with TemporaryDirectory(TEMP_FILE_PATH) as t:
             archive(model, t)
             loaded_model = ComboModel.from_archive(os.path.join(t, 'model.tar.gz'))
-        self.assertDictEqual(loaded_model.serialize(), model.serialize())
+        self.assertDictEqual(loaded_model.serialize(['vocabulary']), model.serialize(['vocabulary']))
-- 
GitLab