Syntactic Tools / combo / Commits
Commit 1a0f387d, authored 3 years ago by Łukasz Pszenny
learning rate test implementation
Parent: d7849a0a
Pipeline #4251 passed in 8 minutes and 37 seconds
Showing 3 changed files with 92 additions and 76 deletions:

    combo/config.graph.template.jsonnet   +13  −13
    combo/training/scheduler.py            +1   −1
    combo/training/trainer.py             +78  −62
combo/config.graph.template.jsonnet (+13 −13)
@@ -274,25 +274,25 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             local projection_dim = 512,
             cycle_loss_n: cycle_loss_n,
             head_projection_layer: {
-                in_features: hidden_size * 2,
+                in_features: char_dim + 768, # hidden_size * 2,
                 out_features: projection_dim,
                 activation: "tanh",
             },
             dependency_projection_layer: {
-                in_features: hidden_size * 2,
+                in_features: char_dim + 768, # hidden_size * 2,
                 out_features: projection_dim,
                 activation: "tanh",
             },
         },
         local projection_dim = 128,
         head_projection_layer: {
-            in_features: hidden_size * 2,
+            in_features: char_dim + 768, # hidden_size * 2,
             out_features: projection_dim,
             dropout_rate: predictors_dropout,
             activation: "tanh"
         },
         dependency_projection_layer: {
-            in_features: hidden_size * 2,
+            in_features: char_dim + 768, # hidden_size * 2,
             out_features: projection_dim,
             dropout_rate: predictors_dropout,
             activation: "tanh"
@@ -305,25 +305,25 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             local projection_dim = 512,
             cycle_loss_n: cycle_loss_n,
             head_projection_layer: {
-                in_features: hidden_size * 2,
+                in_features: char_dim + 768, # hidden_size * 2,
                 out_features: projection_dim,
                 activation: "tanh",
             },
             dependency_projection_layer: {
-                in_features: hidden_size * 2,
+                in_features: char_dim + 768, # hidden_size * 2,
                 out_features: projection_dim,
                 activation: "tanh",
             },
         },
         local projection_dim = 128,
         head_projection_layer: {
-            in_features: hidden_size * 2,
+            in_features: char_dim + 768, # hidden_size * 2,
             out_features: projection_dim,
             dropout_rate: predictors_dropout,
             activation: "tanh"
         },
         dependency_projection_layer: {
-            in_features: hidden_size * 2,
+            in_features: char_dim + 768, # hidden_size * 2,
             out_features: projection_dim,
             dropout_rate: predictors_dropout,
             activation: "tanh"
@@ -332,7 +332,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     morphological_feat: if in_targets("feats") then {
         type: "combo_morpho_from_vocab",
         vocab_namespace: "feats_labels",
-        input_dim: hidden_size * 2,
+        input_dim: char_dim + 768, # hidden_size * 2,
         hidden_dims: [128],
         activations: ["tanh", "linear"],
         dropout: [predictors_dropout, 0.0],
@@ -344,7 +344,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     lemma_vocab_namespace: "lemma_characters",
     embedding_dim: 256,
     input_projection_layer: {
-        in_features: hidden_size * 2,
+        in_features: char_dim + 768, # hidden_size * 2,
         out_features: 32,
         dropout_rate: predictors_dropout,
         activation: "tanh"
@@ -357,7 +357,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         activations: ["relu", "relu", "relu", "linear"],
     },
     upos_tagger: if in_targets("upostag") then {
-        input_dim: hidden_size * 2,
+        input_dim: char_dim + 768, # hidden_size * 2,
         hidden_dims: [64],
         activations: ["tanh", "linear"],
         dropout: [predictors_dropout, 0.0],
@@ -365,7 +365,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         vocab_namespace: "upostag_labels"
     },
     xpos_tagger: if in_targets("xpostag") then {
-        input_dim: hidden_size * 2,
+        input_dim: char_dim + 768, # hidden_size * 2,
         hidden_dims: [128],
         activations: ["tanh", "linear"],
         dropout: [predictors_dropout, 0.0],
@@ -373,7 +373,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         vocab_namespace: "xpostag_labels"
     },
     semantic_relation: if in_targets("semrel") then {
-        input_dim: hidden_size * 2,
+        input_dim: char_dim + 768, # hidden_size * 2,
         hidden_dims: [64],
         activations: ["tanh", "linear"],
         dropout: [predictors_dropout, 0.0],
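Every changed line in this file swaps a predictor input width from hidden_size * 2 (presumably the bidirectional encoder output) to char_dim + 768, i.e. a character-level embedding concatenated with a 768-dimensional transformer embedding fed to the predictors directly. A minimal PyTorch sketch of the dimension arithmetic, with hypothetical sizes (only the 768 and the char_dim + 768 sum come from the config above, the rest is illustrative):

    # Minimal sketch (not COMBO code): why in_features becomes char_dim + 768 when
    # predictors read the concatenated character + transformer embeddings directly.
    import torch

    char_dim = 64                       # hypothetical character-embedding size
    bert_dim = 768                      # hidden size of a base BERT-style encoder
    projection_dim = 512                # as in the config hunk above

    char_emb = torch.randn(1, 10, char_dim)    # (batch, tokens, char features)
    bert_emb = torch.randn(1, 10, bert_dim)    # (batch, tokens, transformer features)

    token_repr = torch.cat([char_emb, bert_emb], dim=-1)     # char_dim + 768 features
    projection = torch.nn.Linear(char_dim + bert_dim, projection_dim)
    print(projection(token_repr).shape)        # torch.Size([1, 10, 512])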
combo/training/scheduler.py (+1 −1)
@@ -7,7 +7,7 @@ from overrides import overrides
 class Scheduler(learning_rate_scheduler._PyTorchLearningRateSchedulerWrapper):

     def __init__(self, optimizer, patience: int = 6, decreases: int = 2, threshold: float = 1e-3):
-        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda]))
+        super().__init__(lr_scheduler.LambdaLR(optimizer, lr_lambda=[self._lr_lambda, self._lr_lambda]))
         self.threshold = threshold
         self.decreases = decreases
         self.patience = patience
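The one-line change duplicates self._lr_lambda because torch.optim.lr_scheduler.LambdaLR accepts either a single callable or one callable per optimizer parameter group; with the trainer now working on two parameter groups, a single-element list fails LambdaLR's length check. A minimal standalone sketch (not COMBO code):

    # Minimal sketch: LambdaLR needs one lr_lambda per parameter group when a list is passed.
    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam([
        {"params": model.weight, "lr": 1e-7},   # group 0
        {"params": model.bias, "lr": 1e-7},     # group 1
    ])

    def decay(epoch):
        return 0.95 ** epoch

    # With two param groups, lr_lambda=[decay] would raise a length-mismatch ValueError,
    # so the list must contain two entries (here the same function twice, as in the commit).
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[decay, decay])

    optimizer.step()
    scheduler.step()
    print([group["lr"] for group in optimizer.param_groups])   # both groups decay in lockstep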
combo/training/trainer.py (+78 −62)
@@ -32,11 +32,69 @@ class TransferPatienceEpochCallback(training.EpochCallback):

     def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
                  is_master: bool) -> None:
-        if trainer._learning_rate_scheduler and trainer._learning_rate_scheduler.patience is not None:
-            trainer._metric_tracker._patience = trainer._learning_rate_scheduler.patience
-            trainer._metric_tracker._epochs_with_no_improvement = 0
-        else:
-            raise checks.ConfigurationError("Learning rate scheduler isn't properly setup!")
+        #LR range test variables
+        path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
+        end_lr = 0.001
+        start_lr = 0.0000001
+        num_lr_in_test = 25
+        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test)
+        # # # # # # #
+        param_group = trainer.optimizer.param_groups
+        if epoch == 0:
+            with open(path_to_result_file, "a") as file:
+                file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
+        else:
+            with open(path_to_result_file, "a") as file:
+                file.write(str(metrics["training_loss"]) + ";" +
+                           str(metrics["validation_loss"]) + ";" +
+                           str(metrics["best_epoch"]) + ";" +
+                           #training losses
+                           str(metrics["training_partial_loss/upostag_loss"]) + ";" +
+                           str(metrics["training_partial_loss/xpostag_loss"]) + ";" +
+                           str(metrics["training_partial_loss/feats_loss"]) + ";" +
+                           str(metrics["training_partial_loss/lemma_loss"]) + ";" +
+                           str(metrics["training_partial_loss/head_loss"]) + ";" +
+                           str(metrics["training_partial_loss/deprel_loss"]) + ";" +
+                           #training acc
+                           str(metrics["training_UPOS_ACC"]) + ";" +
+                           str(metrics["training_XPOS_ACC"]) + ";" +
+                           str(metrics["training_SEMREL_ACC"]) + ";" +
+                           str(metrics["training_LEMMA_ACC"]) + ";" +
+                           str(metrics["training_FEATS_ACC"]) + ";" +
+                           str(metrics["training_LAS"]) + ";" +
+                           #validation losses
+                           str(metrics["validation_partial_loss/upostag_loss"]) + ";" +
+                           str(metrics["validation_partial_loss/xpostag_loss"]) + ";" +
+                           str(metrics["validation_partial_loss/feats_loss"]) + ";" +
+                           str(metrics["validation_partial_loss/lemma_loss"]) + ";" +
+                           str(metrics["validation_partial_loss/head_loss"]) + ";" +
+                           str(metrics["validation_partial_loss/deprel_loss"]) + ";" +
+                           # validation acc
+                           str(metrics["validation_UPOS_ACC"]) + ";" +
+                           str(metrics["validation_XPOS_ACC"]) + ";" +
+                           str(metrics["validation_SEMREL_ACC"]) + ";" +
+                           str(metrics["validation_LEMMA_ACC"]) + ";" +
+                           str(metrics["validation_FEATS_ACC"]) + ";" +
+                           str(metrics["validation_LAS"]) + ";" +
+                           # best loss
+                           str(metrics["best_validation_loss"]) + ";")
+
+                # END CONDITIONS
+                if param_group[1]["lr"] >= end_lr and param_group[0]["lr"] >= end_lr:
+                    raise Exception('End of LR test')
+
+                param_group[0]["lr"] = param_group[0]["lr"] * lr_update_factor
+                if param_group[0]["lr"] >= end_lr:
+                    param_group[0]["lr"] = start_lr
+                    param_group[1]["lr"] = param_group[1]["lr"] * lr_update_factor
+                file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
+
+        trainer.optimizer.param_groups = param_group


 @training.Trainer.register("gradient_descent_validate_n", constructor="from_partial_objects")
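The callback above sweeps the learning rate geometrically from start_lr to end_lr: group 0 is multiplied by lr_update_factor each call, and once it completes a full sweep it is reset to start_lr and group 1 is advanced one step, so the test grids over both groups. A minimal standalone sketch of the ladder implied by the constants above (not COMBO code):

    # Minimal sketch: the geometric LR ladder defined by start_lr, end_lr, num_lr_in_test.
    start_lr = 1e-7
    end_lr = 1e-3
    num_lr_in_test = 25
    factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test)   # ~1.4454 per step

    lr = start_lr
    ladder = [lr]
    for _ in range(num_lr_in_test):
        lr *= factor
        ladder.append(lr)

    print(f"factor = {factor:.4f}")
    print(f"last lr = {ladder[-1]:.1e}")   # ~1.0e-03: the sweep reaches end_lr after 25 multiplications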
@@ -65,7 +123,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                          batch_callbacks, epoch_callbacks, end_callbacks, trainer_callbacks, distributed, local_rank, world_size,
                          num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
-        self.validate_every_n = 5
+        self.validate_every_n = 1

     @overrides
     def _try_train(self) -> Dict[str, Any]:
@@ -93,26 +151,16 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         for key, value in self._metric_tracker.best_epoch_metrics.items():
             metrics["best_validation_" + key] = value

-        for callback in self._epoch_callbacks:
-            callback(self, metrics={}, epoch=-1, is_master=self._master)
-
         for epoch in range(epoch_counter, self._num_epochs):
-            epoch_start_time = time.time()
-            train_metrics = self._train_epoch(epoch)
-
-            if self._master and self._checkpointer is not None:
-                self._checkpointer.save_checkpoint(epoch, self, save_model_only=True)
-
-            # Wait for the master to finish saving the model checkpoint
-            if self._distributed:
-                dist.barrier()
-
-            # get peak of memory usage
-            for key, value in train_metrics.items():
-                if key.startswith("gpu_") and key.endswith("_memory_MB"):
-                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
-                elif key.startswith("worker_") and key.endswith("_memory_MB"):
-                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
+            epochs_to_change_lr = 15
+            # every epochs_to_change_lr epoch loads weights after 1 epoch
+            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
+                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
+
+            epoch_start_time = time.time()
+            train_metrics = self._train_epoch(epoch)

             if self._validation_data_loader is not None:
                 val_metrics = {}
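The loop above reloads the "initial.th" snapshot every epochs_to_change_lr epochs, so each learning rate in the sweep starts from the same parameters; the snapshot itself is written at the end of epoch 0 (see the last hunk below). A minimal standalone sketch of that reset pattern (hypothetical directory, not COMBO code):

    # Minimal sketch: reset a model to its epoch-0 weights every N epochs.
    import os
    import torch

    model = torch.nn.Linear(8, 2)
    serialization_dir = "/tmp/lr_range_demo"        # hypothetical directory
    os.makedirs(serialization_dir, exist_ok=True)
    epochs_to_change_lr = 15

    for epoch in range(31):
        if epoch == 0:
            torch.save(model.state_dict(), os.path.join(serialization_dir, "initial.th"))
        if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
            # triggers at epochs 16, 31, 46, ...: reload the snapshot taken after the first epoch
            model.load_state_dict(torch.load(os.path.join(serialization_dir, "initial.th")))
        # ... train one epoch here ...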
@@ -141,24 +189,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                     # Check validation metric for early stopping
                     this_epoch_val_metric = val_metrics[self._validation_metric]
+                    # self._metric_tracker.add_metric(this_epoch_val_metric)

-                    train_metrics["patience"] = self._metric_tracker._patience
-                    if self._metric_tracker.should_stop_early():
-                        logger.info("Ran out of patience. Stopping training.")
-                        break
-
-            if self._master:
-                self._tensorboard.log_metrics(
-                    train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
-                )  # +1 because tensorboard doesn't like 0
-
-            # Create overall metrics dict
-            training_elapsed_time = time.time() - training_start_time
-            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
-            metrics["training_start_epoch"] = epoch_counter
-            metrics["training_epochs"] = epochs_trained
-            metrics["epoch"] = epoch
-
             for key, value in train_metrics.items():
                 metrics["training_" + key] = value
@@ -174,10 +204,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 self._metric_tracker.best_epoch_metrics = val_metrics

-            if self._serialization_dir and self._master:
-                common_util.dump_metrics(
-                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
-                )
+            for callback in self._epoch_callbacks:
+                if ((epoch - 1) % epochs_to_change_lr == 0 and epoch > 1) or epoch == 0:
+                    callback(self, metrics=metrics, epoch=epoch, is_master=self._master)

             # The Scheduler API is agnostic to whether your schedule requires a validation metric -
             # if it doesn't, the validation metric passed here is ignored.
@@ -186,18 +215,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if self._momentum_scheduler:
                 self._momentum_scheduler.step(this_epoch_val_metric)

-            if self._master and self._checkpointer is not None:
-                self._checkpointer.save_checkpoint(
-                    epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
-                )
-
-            # Wait for the master to finish saving the checkpoint
-            if self._distributed:
-                dist.barrier()
-
-            for callback in self._epoch_callbacks:
-                callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
-
             epoch_elapsed_time = time.time() - epoch_start_time
             logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))
@@ -211,16 +228,15 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             epochs_trained += 1

+            if epoch == 0:
+                torch.save(self.model.state_dict(), os.path.join(self._serialization_dir, "initial.th"))
+
+            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
+                self._metric_tracker.best_epoch_metrics = val_metrics
+
         for callback in self._end_callbacks:
             callback(self, metrics=metrics, epoch=epoch, is_master=self._master)

-        # Load the best model state before returning
-        best_model_state = (
-            None if self._checkpointer is None else self._checkpointer.best_model_state()
-        )
-        if best_model_state:
-            self.model.load_state_dict(best_model_state)
-
         return metrics

     @classmethod
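The epoch callback appends one semicolon-separated line per tested learning rate to the result file. A short sketch for inspecting that log, assuming the column layout implied by the writes above (two LR columns followed by the metric columns, with numeric fields; not COMBO code):

    # Minimal sketch: read the semicolon-separated LR range test log.
    # Assumed line layout: lr_group0;lr_group1;training_loss;validation_loss;best_epoch;...;best_validation_loss;
    rows = []
    with open("/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt") as handle:
        for line in handle:
            fields = [field for field in line.strip().split(";") if field]
            if len(fields) >= 4:               # skip partial lines that only contain the next LR pair
                lr_group0, lr_group1 = float(fields[0]), float(fields[1])
                training_loss, validation_loss = float(fields[2]), float(fields[3])
                rows.append((lr_group0, lr_group1, training_loss, validation_loss))

    for lr0, lr1, train_loss, val_loss in rows:
        print(f"{lr0:.2e}  {lr1:.2e}  {train_loss:.4f}  {val_loss:.4f}")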