Commit 5fae577a authored by Martyna Wiącek

fixed name of gelu function

parent 25ccba60
Merge request !47: Fixed multiword prediction + bug that made the code write empty predictions
@@ -19,7 +19,7 @@ from combo.modules.morpho import MorphologicalFeatures
 from combo.modules.parser import DependencyRelationModel, HeadPredictionModel
 from combo.modules.text_field_embedders import BasicTextFieldEmbedder
 from combo.modules.token_embedders import CharacterBasedWordEmbedder, TransformersWordEmbedder
-from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation
+from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation, GELUActivation
 from combo.modules import FeedForwardPredictor
 from combo.nn.base import Linear
 from combo.nn.regularizers.regularizers import L2Regularizer
@@ -128,10 +128,10 @@ def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> C
         ),
         lemmatizer=LemmatizerModel(
             vocabulary=vocabulary,
-            activations=[ReLUActivation(), ReLUActivation(), ReLUActivation(), LinearActivation()],
+            activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
             char_vocab_namespace="token_characters",
             dilation=[1, 2, 4, 1],
-            embedding_dim=256,
+            embedding_dim=300,
             filters=[256, 256, 256],
             input_projection_layer=Linear(
                 activation=TanhActivation(),
@@ -183,7 +183,7 @@ def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> C
             "char": CharacterBasedWordEmbedder(
                 vocabulary=vocabulary,
                 dilated_cnn_encoder=DilatedCnnEncoder(
-                    activations=[ReLUActivation(), ReLUActivation(), LinearActivation()],
+                    activations=[GELUActivation(), GELUActivation(), LinearActivation()],
                     dilation=[1, 2, 4],
                     filters=[512, 256, 64],
                     input_dim=64,
...
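Both hunks above swap ReLU for GELU on the hidden layers of a dilated CNN stack while keeping a linear (identity) activation on the final layer. As a rough illustration of how such a filters/dilation/activations configuration is typically wired up, here is a schematic sketch in plain PyTorch; combo's actual DilatedCnnEncoder is not shown in this diff and may differ in padding, kernel size, and how the final layer width is chosen:

import torch

def build_dilated_cnn(input_dim, filters, dilation, activations, kernel_size=3):
    """Stack Conv1d layers, pairing each with its dilation rate and activation."""
    layers = []
    in_channels = input_dim
    for out_channels, d, act in zip(filters, dilation, activations):
        # "same" padding for stride-1 convolutions so the sequence length is kept
        layers.append(torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                      dilation=d, padding=d * (kernel_size - 1) // 2))
        layers.append(act)
        in_channels = out_channels
    return torch.nn.Sequential(*layers)

# GELU on the hidden layers, identity on the last one, mirroring the char
# embedder hunk above (input_dim=64, filters=[512, 256, 64], dilation=[1, 2, 4]).
encoder = build_dilated_cnn(64, [512, 256, 64], [1, 2, 4],
                            [torch.nn.GELU(), torch.nn.GELU(), torch.nn.Identity()])
chars = torch.randn(8, 64, 20)  # (batch, embedding channels, characters)
print(encoder(chars).shape)     # torch.Size([8, 64, 20])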
@@ -39,7 +39,7 @@ class ReLUActivation(Activation):
 @Registry.register('gelu')
-class ReLUActivation(Activation):
+class GELUActivation(Activation):
     def __init__(self):
         super().__init__()
         self.__torch_activation = torch.nn.GELU()
...
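The hunk above is the fix the commit message refers to: the wrapper registered under the 'gelu' key was mistakenly still named ReLUActivation, a copy-paste slip that shadowed the real ReLUActivation defined earlier in the same module. Below is a self-contained sketch of the mechanism; the Registry and Activation stand-ins are hypothetical simplifications, and only the decorated class mirrors the diff (its forward method is an assumption, since the diff truncates the class body after the constructor):

import torch

# Hypothetical stand-ins for combo's Registry and Activation; the real ones
# live elsewhere in the combo package and are not shown in this diff.
_REGISTRY: dict[str, type] = {}

class Registry:
    @staticmethod
    def register(name: str):
        def decorator(cls: type) -> type:
            _REGISTRY[name] = cls
            return cls
        return decorator

class Activation(torch.nn.Module):
    pass

@Registry.register('gelu')
class GELUActivation(Activation):
    def __init__(self):
        super().__init__()
        self.__torch_activation = torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed body: delegate to the wrapped torch activation.
        return self.__torch_activation(x)

# Before this commit the class above was itself named ReLUActivation, so its
# definition overwrote the real ReLU wrapper in the module namespace even
# though the registry key 'gelu' was already correct.
print(_REGISTRY['gelu']()(torch.randn(2, 3)).shape)  # torch.Size([2, 3])

Renaming the class restores the module-level ReLUActivation binding; the registry key itself needed no change.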