From 3c8773237881c4645848e37a606a5a6d1ff2ae06 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 16 Jan 2024 13:19:30 +0100
Subject: [PATCH] Add feedforward_from_vocab

---
 combo/modules/feedforward_predictor.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py
index 74c457b..e585737 100644
--- a/combo/modules/feedforward_predictor.py
+++ b/combo/modules/feedforward_predictor.py
@@ -11,6 +11,7 @@ from combo.utils import ConfigurationError
 
 
 @Registry.register("feedforward_predictor")
+@Registry.register("feedforward_predictor_from_vocab", constructor_method="from_vocab")
 class FeedForwardPredictor(Predictor):
     """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
 
@@ -56,6 +57,7 @@ class FeedForwardPredictor(Predictor):
         return loss.sum() / valid_positions
 
     @classmethod
+    @register_arguments
     def from_vocab(cls,
                    vocabulary: Vocabulary,
                    vocab_namespace: str,
@@ -74,9 +76,12 @@ class FeedForwardPredictor(Predictor):
             f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
         hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)]
 
-        return cls(FeedForward(
+        ff_p = cls(FeedForward(
             input_dim=input_dim,
             num_layers=num_layers,
             hidden_dims=hidden_dims,
             activations=activations,
             dropout=dropout))
+
+        ff_p.constructed_from = "from_vocab"
+        return ff_p
\ No newline at end of file
-- 
GitLab