Commit 84cb75fb authored by Mateusz Klimaszewski

Add freezing of the pretrained transformer.

parent 7eb1302a
@@ -117,18 +117,21 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
                  model_name: str,
                  projection_dim: int,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
-                 projection_dropout_rate: Optional[float] = 0.0):
+                 projection_dropout_rate: Optional[float] = 0.0,
+                 freeze_transformer: bool = True):
         super().__init__(model_name)
-        self.transformers_encoder = modeling_auto.AutoModel.from_pretrained(model_name)
-        self.output_dim = self.transformers_encoder.config.hidden_size
+        if freeze_transformer:
+            for param in self._matched_embedder.parameters():
+                param.requires_grad = False
         if projection_dim:
-            self.projection_layer = base.Linear(in_features=self.output_dim,
+            self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
                                                 dropout_rate=projection_dropout_rate,
                                                 activation=projection_activation)
             self.output_dim = projection_dim
         else:
             self.projection_layer = None
+            self.output_dim = super().get_output_dim()

     def forward(
             self,
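The substantive change in this hunk is the new `freeze_transformer` flag: when set, every parameter of the inherited `self._matched_embedder` (the wrapped HuggingFace model) has `requires_grad` disabled, so gradient updates reach only the projection layer. Below is a minimal, self-contained sketch of that same pattern, assuming only the `transformers` package; "bert-base-cased" is a placeholder for whatever `model_name` the embedder is given, not a model used by this repository.

    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-cased")

    # Disable gradients for every transformer parameter, as the
    # constructor does for self._matched_embedder.parameters().
    for param in model.parameters():
        param.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} of {total:,} parameters")  # trainable: 0 of ...

With `requires_grad` off, backward passes never populate `.grad` for those weights and standard PyTorch optimizers skip them; building the optimizer over `filter(lambda p: p.requires_grad, model.parameters())` additionally avoids allocating optimizer state for the frozen weights.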
@@ -148,3 +151,11 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
     @overrides
     def get_output_dim(self):
         return self.output_dim
+
+    @overrides
+    def train(self, mode: bool):
+        self.projection_layer.train(mode)
+
+    @overrides
+    def eval(self):
+        self.projection_layer.eval()
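The second hunk overrides `train`/`eval` so that mode switches toggle only the projection head; the frozen transformer never re-enters training mode, which keeps its dropout disabled. A minimal sketch of that pattern with a toy backbone, using plain `torch.nn` (no AllenNLP); `FrozenBackboneHead` and its layer sizes are invented for illustration. Unlike the diff, the sketch returns `self` (the usual `nn.Module.train` contract) and always constructs a head; in the embedder above, `self.projection_layer` is `None` when `projection_dim` is falsy, so the new overrides implicitly assume a projection layer was configured.

    import torch.nn as nn

    class FrozenBackboneHead(nn.Module):
        def __init__(self):
            super().__init__()
            # Stand-in for the pretrained transformer: frozen weights, eval mode.
            self.backbone = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
            for param in self.backbone.parameters():
                param.requires_grad = False
            self.backbone.eval()
            # Stand-in for the trainable projection head.
            self.head = nn.Linear(8, 2)

        def train(self, mode: bool = True):
            # Toggle only the head; the frozen backbone stays in eval mode,
            # so its dropout never re-activates during training.
            self.head.train(mode)
            return self  # nn.Module.train conventionally returns self

    model = FrozenBackboneHead()
    model.train()
    print(model.backbone.training, model.head.training)  # -> False True
    model.eval()  # nn.Module.eval() delegates to train(False)
    print(model.backbone.training, model.head.training)  # -> False False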