Commit cd660089 authored by Mateusz Klimaszewski

Enable BERT fine-tuning.

parent e902b504
Pipeline #2898 passed in 3 minutes and 38 seconds
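Summary of the change, as read from the diff below: the config unfreezes the BERT embedder (freeze_transformer: false), the optimizer switches from Adam to AdamW with a dedicated parameter group for the embedder (lr 5e-5, weight decay 0.01), the learning-rate scheduler is adjusted to handle more than one parameter group, and the training script lowers the word batch size from 2500 to 1000.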
@@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     },
     token: if use_transformer then {
       type: "transformers_word_embeddings",
+      freeze_transformer: false,
       model_name: pretrained_transformer_name,
       projection_dim: projected_embedding_dim,
       tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
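The new freeze_transformer flag controls whether the transformer weights receive gradient updates during training. As a rough illustration only (not COMBO's actual implementation), here is a minimal sketch of what such a flag typically toggles, assuming it maps to requires_grad on the underlying Hugging Face model; load_transformer is a hypothetical helper:

from transformers import AutoModel

def load_transformer(model_name: str, freeze: bool):
    # Hypothetical helper: freeze=True turns the transformer into a fixed
    # feature extractor by disabling gradients; freeze=False (as set in
    # this commit) lets the optimizer update its weights.
    model = AutoModel.from_pretrained(model_name)
    for param in model.parameters():
        param.requires_grad = not freeze
    return model

# With freeze_transformer: false, BERT's weights are fine-tuned together
# with the rest of the parser.
bert = load_transformer("allegro/herbert-base-cased", freeze=False)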
@@ -401,9 +402,14 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     grad_clipping: 5.0,
     num_epochs: num_epochs,
     optimizer: {
-      type: "adam",
+      type: "adamw",
       lr: learning_rate,
       betas: [0.9, 0.9],
+      weight_decay: 0.0,
+      parameter_groups: [
+        [
+          ['_embedder'], { lr: 5e-5, weight_decay: 0.01, finetune: true, },
+        ],
+      ],
     },
     patience: 1, # it will be overwritten by callback
     epoch_callbacks: [
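In AllenNLP-style configs, parameter_groups pairs a list of name regexes with optimizer overrides: parameters whose names match '_embedder' get their own learning rate and weight decay, while everything else falls back to the top-level settings. A hedged sketch of the equivalent plain PyTorch setup (TinyModel, build_param_groups, and the base learning rate are illustrative stand-ins, not code from this repository):

import re
import torch

class TinyModel(torch.nn.Module):
    # Illustrative stand-in model; the attribute name is chosen so the
    # '_embedder' regex has something to match.
    def __init__(self):
        super().__init__()
        self._embedder = torch.nn.Embedding(100, 8)
        self.head = torch.nn.Linear(8, 2)

def build_param_groups(model: torch.nn.Module, base_lr: float):
    # Split parameters by regex, mirroring parameter_groups semantics:
    # names matching '_embedder' get lr 5e-5 and weight decay 0.01,
    # the rest keep the base learning rate and zero weight decay.
    embedder, rest = [], []
    for name, param in model.named_parameters():
        (embedder if re.search("_embedder", name) else rest).append(param)
    return [
        {"params": embedder, "lr": 5e-5, "weight_decay": 0.01},
        {"params": rest, "lr": base_lr, "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(build_param_groups(TinyModel(), base_lr=2e-3),
                              betas=(0.9, 0.9))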
@@ -7,7 +7,8 @@ from overrides import overrides
 class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper):
     def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3):
-        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda]))
+        super().__init__(lr_scheduler.LambdaLR(optimizer,
+                                               lr_lambda=[self._lr_lambda] * len(optimizer.param_groups)))
         self.threshold = threshold
         self.decreases = decreases
         self.patience = patience
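This fix is needed because torch.optim.lr_scheduler.LambdaLR accepts either a single function or a list with exactly one lambda per parameter group; with the embedder now in its own group, a one-element list would raise a ValueError. A small self-contained demonstration (decay and the dummy parameters are illustrative):

import torch
from torch.optim import lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
optimizer = torch.optim.AdamW([
    {"params": [params[0]], "lr": 5e-5},
    {"params": [params[1]], "lr": 2e-3},
])

def decay(epoch: int) -> float:
    return 0.95 ** epoch  # multiplicative factor applied to each group's lr

# One lambda per parameter group -- works for any number of groups:
scheduler = lr_scheduler.LambdaLR(
    optimizer, lr_lambda=[decay] * len(optimizer.param_groups))

# lr_lambda=[decay] alone would fail here with (roughly):
# ValueError: Expected 2 lr_lambdas, but got 1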
@@ -114,8 +114,8 @@ def run(_):
             --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
             --serialization_dir {serialization_dir}
             --cuda_device {FLAGS.cuda_device}
-            --word_batch_size 2500
-            --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+            --word_batch_size 1000
+            --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
             --notensorboard
             """