
Refactor

Merged Mateusz Klimaszewski requested to merge refactor into develop
5 files changed: +67 −75
@@ -107,18 +107,16 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
 
     def __init__(self,
                  model_name: str,
-                 projection_dim: int,
+                 projection_dim: int = 0,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
     def get_output_dim(self):
         return self.output_dim
-
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
 
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
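
Context for the change (a sketch, not part of the diff): AllenNLP's PretrainedTransformerMismatchedEmbedder accepts a train_parameters flag, so passing train_parameters=not freeze_transformer lets the base class disable gradients on the transformer weights itself. That is why the manual requires_grad loop and the train()/eval() overrides can be dropped. A minimal illustration, assuming the AllenNLP version used by the project; the model name below is only an example:

from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

# train_parameters=False corresponds to freeze_transformer=True after the refactor:
# the base class turns off gradients for the transformer weights on its own.
embedder = PretrainedTransformerMismatchedEmbedder(
    "bert-base-cased",      # example model name, not taken from the MR
    train_parameters=False,
)

# All transformer parameters are frozen.
assert all(not p.requires_grad for p in embedder._matched_embedder.parameters())

The projection layer created when projection_dim > 0 lives outside the frozen embedder, so its parameters keep requires_grad=True and continue to train as before.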