From c9c0e3c7d757615f849f0ffc4f69655a11d79a5b Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Mon, 4 Jan 2021 10:54:41 +0100
Subject: [PATCH] Simplified transformer word embedder and fix validation loss.

---
 combo/models/embeddings.py | 24 ++++--------------------
 combo/models/parser.py     |  2 +-
 combo/training/trainer.py  | 13 +++++++++----
 3 files changed, 14 insertions(+), 25 deletions(-)

diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 6ad2559..49f1ab9 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -113,12 +113,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
diff --git a/combo/models/parser.py b/combo/models/parser.py
index dfb53ab..511edff 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index aeb9f09..26bd75f 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             patience: int = None,
             validation_metric: str = "-loss",
             num_epochs: int = 20,
-            cuda_device: int = -1,
+            cuda_device: Optional[Union[int, torch.device]] = -1,
             grad_norm: float = None,
             grad_clipping: float = None,
             distributed: bool = None,
             world_size: int = 1,
             num_gradient_accumulation_steps: int = 1,
-            opt_level: Optional[str] = None,
             use_amp: bool = False,
-            optimizer: common.Lazy[optimizers.Optimizer] = None,
+            no_grad: List[str] = None,
+            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
             learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
             momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
             tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
             moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
+            end_callbacks: List[training.EpochCallback] = None,
+            trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
        )
--
GitLab
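Editor's note on the parser.py change (not part of the patch itself): the added .clone() is the piece that repairs the validation loss. relation_prediction[:, 1:] returns a view that shares storage with relation_prediction, so masking the root label in place on that view (which the surrounding mask/root_idx code suggests happens downstream) would also overwrite the logits the loss is computed from. A minimal PyTorch sketch of that aliasing behaviour, with purely illustrative tensor names and shapes:

import torch

# Illustrative relation logits: (batch, words, num_labels); values/shapes are made up.
relation_prediction = torch.zeros(2, 4, 5)

aliased = relation_prediction[:, 1:]           # a view: shares memory with the logits
copied = relation_prediction[:, 1:].clone()    # an independent copy

aliased[..., 0] = float("-inf")   # leaks back into relation_prediction
copied[..., 1] = float("-inf")    # leaves relation_prediction untouched

print(torch.isinf(relation_prediction[..., 0]).any())  # tensor(True)  -> the loss would see the mask
print(torch.isinf(relation_prediction[..., 1]).any())  # tensor(False) -> the loss stays clean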