diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py index 74c457bd6fb4218ab39a4172caffb506101d07c5..e5857373c1879e0f856121daae42c6fd77a0f541 100644 --- a/combo/modules/feedforward_predictor.py +++ b/combo/modules/feedforward_predictor.py @@ -11,6 +11,7 @@ from combo.utils import ConfigurationError @Registry.register("feedforward_predictor") +@Registry.register("feedforward_predictor_from_vocab", constructor_method="from_vocab") class FeedForwardPredictor(Predictor): """Feedforward predictor. Should be used on top of Seq2Seq encoder.""" @@ -56,6 +57,7 @@ class FeedForwardPredictor(Predictor): return loss.sum() / valid_positions @classmethod + @register_arguments def from_vocab(cls, vocabulary: Vocabulary, vocab_namespace: str, @@ -74,9 +76,12 @@ class FeedForwardPredictor(Predictor): f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!" hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)] - return cls(FeedForward( + ff_p = cls(FeedForward( input_dim=input_dim, num_layers=num_layers, hidden_dims=hidden_dims, activations=activations, dropout=dropout)) + + ff_p.constructed_from = "from_vocab" + return ff_p \ No newline at end of file