Commit c9c0e3c7 authored by Mateusz Klimaszewski

Simplified transformer word embedder and fix validation loss.

parent 847f3a90
Merge requests: !13 Refactor merge develop to master, !12 Refactor
@@ -113,12 +113,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
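The simplified constructor relies on the train_parameters argument that AllenNLP's PretrainedTransformerMismatchedEmbedder already accepts (it is what the new super().__init__ call passes): train_parameters=not freeze_transformer freezes the transformer weights inside the base class, which makes the removed requires_grad loop and the overridden train/eval methods redundant. Below is a minimal sketch of the behaviour the new code relies on; the model name is a hypothetical placeholder, not taken from this repository.

# Sketch only: with train_parameters=False the embedder freezes every
# transformer weight itself, matching the loop this commit removes.
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

freeze_transformer = True
embedder = PretrainedTransformerMismatchedEmbedder(
    "bert-base-cased",  # hypothetical model name; any HuggingFace model id
    train_parameters=not freeze_transformer,
)

# No parameter of the frozen transformer should require gradients.
assert not any(p.requires_grad for p in embedder.parameters())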
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
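The added .clone() in DependencyRelationModel is most likely the validation-loss fix from the commit message: relation_prediction[:, 1:] is a view that shares storage with relation_prediction, so masking the root label in place also overwrites the tensor the loss is later computed from, skewing the reported validation loss. Cloning gives the masking code a private copy. A standalone PyTorch sketch of the view-versus-clone behaviour (illustrative shapes, not the project's code):

import torch

relation_prediction = torch.randn(2, 5, 4)  # (batch, tokens incl. root, labels)

view = relation_prediction[:, 1:]           # basic slice: shares storage
view[..., 0] = -1e9                         # in-place mask leaks into the base tensor
assert bool(torch.all(relation_prediction[:, 1:, 0] == -1e9))

relation_prediction = torch.randn(2, 5, 4)
copy = relation_prediction[:, 1:].clone()   # private storage
copy[..., 0] = -1e9                         # the base tensor stays untouched
assert not bool(torch.any(relation_prediction[:, 1:, 0] == -1e9))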
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                              patience: int = None,
                              validation_metric: str = "-loss",
                              num_epochs: int = 20,
-                             cuda_device: int = -1,
+                             cuda_device: Optional[Union[int, torch.device]] = -1,
                              grad_norm: float = None,
                              grad_clipping: float = None,
                              distributed: bool = None,
                              world_size: int = 1,
                              num_gradient_accumulation_steps: int = 1,
-                             opt_level: Optional[str] = None,
                              use_amp: bool = False,
-                             optimizer: common.Lazy[optimizers.Optimizer] = None,
+                             no_grad: List[str] = None,
+                             optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
                              learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
                              momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
                              tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
                              moving_average: common.Lazy[moving_average.MovingAverage] = None,
-                             checkpointer: common.Lazy[training.Checkpointer] = None,
+                             checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
                              batch_callbacks: List[training.BatchCallback] = None,
                              epoch_callbacks: List[training.EpochCallback] = None,
+                             end_callbacks: List[training.EpochCallback] = None,
+                             trainer_callbacks: List[training.TrainerCallback] = None,
                              ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
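The trainer hunks appear to realign from_partial_objects with a newer AllenNLP GradientDescentTrainer signature: the apex-style opt_level argument is dropped in favour of the existing use_amp flag, no_grad, end_callbacks and trainer_callbacks are accepted and passed through, and optimizer/checkpointer get common.Lazy defaults instead of None. A rough sketch of why the Lazy defaults are convenient, using AllenNLP's Lazy and Checkpointer as they appear in the diff (the serialization directory is illustrative):

# Sketch only: Lazy wraps a deferred constructor, so a default of
# Lazy(Checkpointer) lets the factory always call .construct(...) with the
# runtime serialization directory instead of special-casing None.
from allennlp.common import Lazy
from allennlp.training import Checkpointer

lazy_checkpointer = Lazy(Checkpointer)
checkpointer = lazy_checkpointer.construct(serialization_dir="/tmp/model")  # illustrative path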