From 7ef46307145785a93433acf2127ed766ea508838 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Thu, 18 Jan 2024 21:50:01 +0100 Subject: [PATCH] Fix classmethod registered arguments --- combo/data/vocabulary.py | 24 ++++++++++++++++++------ combo/modules/feedforward_predictor.py | 15 +++++++++++---- setup.py | 2 +- tests/config/test_archive.py | 2 +- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index 4f3565c..153ed73 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -118,7 +118,6 @@ class _NamespaceDependentDefaultDict(defaultdict[str, NamespaceVocabulary]): @Registry.register("from_files_vocabulary", "from_files") @Registry.register("from_pretrained_transformer_vocabulary", "from_pretrained_transformer") @Registry.register("from_data_loader_vocabulary", "from_data_loader") -@Registry.register("from_pretrained_transformer_and_instances_vocabulary", "from_pretrained_transformer_and_instances") @Registry.register("from_data_loader_extended_vocabulary", "from_data_loader_extended") class Vocabulary(FromParameters): @register_arguments @@ -239,7 +238,6 @@ class Vocabulary(FromParameters): self._vocab[namespace].append_tokens(tokens) @classmethod - @register_arguments def from_files(cls, directory: Union[str, os.PathLike], padding_token: Optional[str] = DEFAULT_PADDING_TOKEN, @@ -308,10 +306,14 @@ class Vocabulary(FromParameters): vocab.slices = json.load(slices_file) vocab.constructed_from = 'from_files' + vocab.constructed_args = { + "directory": directory.split("/")[-1], + "padding_token": padding_token, + "oov_token": oov_token + } return vocab @classmethod - @register_arguments def from_data_loader( cls, data_loader: "DataLoader", @@ -339,7 +341,19 @@ class Vocabulary(FromParameters): oov_token=oov_token, serialization_dir=serialization_dir ) - vocab.constructed_from = 'from_dataset_loader' + vocab.constructed_args = { + "data_loader": data_loader, + "min_count": min_count, + "max_vocab_size": max_vocab_size, + "non_padded_namespaces": non_padded_namespaces, + "pretrained_files": pretrained_files, + "only_include_pretrained_words": only_include_pretrained_words, + "tokens_to_add": tokens_to_add, + "min_pretrained_embeddings": min_pretrained_embeddings, + "padding_token": padding_token, + "oov_token": oov_token, + "serialization_dir": serialization_dir + } return vocab @classmethod @@ -389,7 +403,6 @@ class Vocabulary(FromParameters): return vocab @classmethod - @register_arguments def from_files_and_instances( cls, instances: Iterable["Instance"], @@ -431,7 +444,6 @@ class Vocabulary(FromParameters): return vocab @classmethod - @register_arguments def from_pretrained_transformer_and_instances( cls, instances: Iterable["Instance"], diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py index e585737..0e90000 100644 --- a/combo/modules/feedforward_predictor.py +++ b/combo/modules/feedforward_predictor.py @@ -57,7 +57,6 @@ class FeedForwardPredictor(Predictor): return loss.sum() / valid_positions @classmethod - @register_arguments def from_vocab(cls, vocabulary: Vocabulary, vocab_namespace: str, @@ -74,14 +73,22 @@ class FeedForwardPredictor(Predictor): assert vocab_namespace in vocabulary.get_namespaces(), \ f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!" - hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)] + hidden_dims_w_vocab_added = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)] ff_p = cls(FeedForward( input_dim=input_dim, num_layers=num_layers, - hidden_dims=hidden_dims, + hidden_dims=hidden_dims_w_vocab_added, activations=activations, dropout=dropout)) ff_p.constructed_from = "from_vocab" - return ff_p \ No newline at end of file + ff_p.constructed_args = { + "vocab_namespace": vocab_namespace, + "input_dim": input_dim, + "num_layers": num_layers, + "hidden_dims": hidden_dims, + "activations": activations, + "dropout": dropout + } + return ff_p diff --git a/setup.py b/setup.py index 50f4c06..723b582 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ REQUIREMENTS = [ ] setup( - name="combo", + name="combo-nlp", version="3.0.3", author="Maja Jablonska", author_email="maja.jablonska@ipipan.waw.pl", diff --git a/tests/config/test_archive.py b/tests/config/test_archive.py index 36aebab..fa17b0d 100644 --- a/tests/config/test_archive.py +++ b/tests/config/test_archive.py @@ -23,4 +23,4 @@ class ArchivalTest(unittest.TestCase): with TemporaryDirectory(TEMP_FILE_PATH) as t: archive(model, t) loaded_model = ComboModel.from_archive(os.path.join(t, 'model.tar.gz')) - self.assertDictEqual(loaded_model.serialize(), model.serialize()) + self.assertDictEqual(loaded_model.serialize(['vocabulary']), model.serialize(['vocabulary'])) -- GitLab