Skip to content
Snippets Groups Projects
Commit 3c877323 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add feedforward_from_vocab

parent d616cd18
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -11,6 +11,7 @@ from combo.utils import ConfigurationError
@Registry.register("feedforward_predictor")
@Registry.register("feedforward_predictor_from_vocab", constructor_method="from_vocab")
class FeedForwardPredictor(Predictor):
"""Feedforward predictor. Should be used on top of Seq2Seq encoder."""
......@@ -56,6 +57,7 @@ class FeedForwardPredictor(Predictor):
return loss.sum() / valid_positions
@classmethod
@register_arguments
def from_vocab(cls,
vocabulary: Vocabulary,
vocab_namespace: str,
......@@ -74,9 +76,12 @@ class FeedForwardPredictor(Predictor):
f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!"
hidden_dims = hidden_dims + [vocabulary.get_vocab_size(vocab_namespace)]
return cls(FeedForward(
ff_p = cls(FeedForward(
input_dim=input_dim,
num_layers=num_layers,
hidden_dims=hidden_dims,
activations=activations,
dropout=dropout))
ff_p.constructed_from = "from_vocab"
return ff_p
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment