From 3c8773237881c4645848e37a606a5a6d1ff2ae06 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Tue, 16 Jan 2024 13:19:30 +0100 Subject: [PATCH] Add feedforward_from_vocab --- combo/modules/feedforward_predictor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py index 74c457b..e585737 100644 --- a/combo/modules/feedforward_predictor.py +++ b/combo/modules/feedforward_predictor.py @@ -11,6 +11,7 @@ from combo.utils import ConfigurationError @Registry.register("feedforward_predictor") +@Registry.register("feedforward_predictor_from_vocab", constructor_method="from_vocab") class FeedForwardPredictor(Predictor): """Feedforward predictor. Should be used on top of Seq2Seq encoder.""" @@ -56,6 +57,7 @@ class FeedForwardPredictor(Predictor): return loss.sum() / valid_positions @classmethod + @register_arguments def from_vocab(cls, vocabulary: Vocabulary, vocab_namespace: str, @@ -74,9 +76,12 @@ class FeedForwardPredictor(Predictor): f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!" hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)] - return cls(FeedForward( + ff_p = cls(FeedForward( input_dim=input_dim, num_layers=num_layers, hidden_dims=hidden_dims, activations=activations, dropout=dropout)) + + ff_p.constructed_from = "from_vocab" + return ff_p \ No newline at end of file -- GitLab