Commit c9c0e3c7 authored by Mateusz Klimaszewski

Simplified transformer word embedder and fixed validation loss.

parent 847f3a90
2 merge requests: !13 "Refactor merge develop to master", !12 "Refactor"
@@ -113,12 +113,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
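
The new call delegates freezing to the embedder itself instead of the hand-rolled loop it replaces. A minimal sketch of the two styles, assuming an allennlp 1.x PretrainedTransformerMismatchedEmbedder (the model name below is only for illustration, not from the commit):

import torch
from allennlp.modules import token_embedders

# Old style: freeze by hand after construction.
def freeze(module: torch.nn.Module) -> None:
    module.eval()  # disable dropout so the frozen encoder is deterministic
    for param in module.parameters():
        param.requires_grad = False  # exclude from optimizer updates

# New style: the constructor flag does the parameter freezing; this is
# what train_parameters=not freeze_transformer passes through.
embedder = token_embedders.PretrainedTransformerMismatchedEmbedder(
    "bert-base-cased",       # illustrative model name
    train_parameters=False,
)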
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
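Dropping the train()/eval() overrides is safe for gradients, but it is worth noting why they existed: freezing parameters and switching to eval mode are independent toggles in PyTorch, and the overrides kept the frozen transformer in eval mode regardless of the wrapper's mode. A self-contained demonstration in plain torch:

import torch

layer = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
for param in layer.parameters():
    param.requires_grad = False  # frozen: no gradients flow to the weights

layer.train()
x = torch.ones(1, 4)
print(layer(x))  # still stochastic: dropout stays active despite the freeze

layer.eval()
print(layer(x))  # deterministic: eval mode, not requires_grad, disables dropout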
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
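This one-line change is presumably the "fixed validation loss" part of the commit: relation_prediction[:, 1:] is a view sharing storage with relation_prediction, so in-place root masking applied to it also mutated the original logits used elsewhere, such as in the loss. .clone() gives the masked copy its own storage. A minimal reproduction in plain torch:

import torch

logits = torch.zeros(2, 3, 5)      # (batch, words, relation vocabulary)

view = logits[:, 1:]               # basic slicing returns a view
view[..., 0] = float("-inf")       # in-place masking on the view...
print(logits[:, 1:, 0])            # ...has silently rewritten logits too

logits = torch.zeros(2, 3, 5)
masked = logits[:, 1:].clone()     # independent copy
masked[..., 0] = float("-inf")
print(logits[:, 1:, 0])            # original is untouched: all zeros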
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             patience: int = None,
             validation_metric: str = "-loss",
             num_epochs: int = 20,
-            cuda_device: int = -1,
+            cuda_device: Optional[Union[int, torch.device]] = -1,
             grad_norm: float = None,
             grad_clipping: float = None,
             distributed: bool = None,
             world_size: int = 1,
             num_gradient_accumulation_steps: int = 1,
-            opt_level: Optional[str] = None,
             use_amp: bool = False,
-            optimizer: common.Lazy[optimizers.Optimizer] = None,
+            no_grad: List[str] = None,
+            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
             learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
             momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
             tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
             moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
+            end_callbacks: List[training.EpochCallback] = None,
+            trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
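
The new defaults replace None sentinels with allennlp's common.Lazy wrappers, so the optimizer and checkpointer are built on demand with their defaults instead of being special-cased in the method body. A sketch of how Lazy behaves, using a stand-in class for illustration (not from the commit):

from allennlp.common import Lazy

class Checkpointer:
    """Stand-in for training.Checkpointer, for illustration only."""
    def __init__(self, serialization_dir: str = "ckpt"):
        self.serialization_dir = serialization_dir

lazy_checkpointer = Lazy(Checkpointer)        # nothing constructed yet
checkpointer = lazy_checkpointer.construct()  # built here, defaults applied
print(checkpointer.serialization_dir)         # -> "ckpt"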
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
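
The newly forwarded no_grad option is a list of parameter-name regexes; allennlp's trainer factory sets requires_grad=False on every parameter whose name matches one of them. A simplified sketch of that matching logic (an assumption about the upstream behavior, not code from this commit):

import re
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
no_grad = [r"^0\."]  # freeze all parameters of the first layer

for name, param in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad):
        param.requires_grad = False  # mirror allennlp's no_grad handling

print([(name, p.requires_grad) for name, p in model.named_parameters()])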
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
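
end_callbacks and trainer_callbacks are forwarded straight through to the parent trainer. As a usage illustration, a hedged sketch of a callback that could be passed via end_callbacks, assuming the allennlp 1.x EpochCallback interface (the registered name and the callback itself are hypothetical):

from typing import Any, Dict

from allennlp import training

@training.EpochCallback.register("log_validation_loss")
class LogValidationLoss(training.EpochCallback):
    def __call__(self, trainer: "training.GradientDescentTrainer",
                 metrics: Dict[str, Any], epoch: int, is_master: bool) -> None:
        # Report the metric whose computation this commit fixes.
        print(f"epoch={epoch} validation_loss={metrics.get('validation_loss')}")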