Syntactic Tools / combo · Commits · 512bd5e2

Commit 512bd5e2, authored 3 years ago by Łukasz Pszenny

    learning rate test fix

Parent: 1a0f387d
No related merge requests found.
Pipeline #4257 passed in 4 minutes and 55 seconds.

Showing 1 changed file: combo/training/trainer.py, 22 additions and 22 deletions.
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -33,20 +33,26 @@ class TransferPatienceEpochCallback(training.EpochCallback):
     def __call__(self, trainer: "training.GradientDescentTrainer", metrics: Dict[str, Any], epoch: int,
                  is_master: bool) -> None:
-        #LR range test variables
         path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
-        end_lr = 0.001
-        start_lr = 0.0000001
-        num_lr_in_test = 25
-        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_lr_in_test)
-        # # # # # # #
         param_group = trainer.optimizer.param_groups
         if epoch == 0:
             with open(path_to_result_file, "a") as file:
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
         else:
+            # LR range test variables
+            path_to_result_file = "/tmp/lustre_shared/lukasz/tmp/LR_range_3.txt"
+            end_lr = 0.001
+            start_lr = 0.0000001
+            num_lr_in_test = 25
+            epochs_to_change_lr = trainer.epochs_to_change_lr
+            lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))
+            test_number = int(epoch / epochs_to_change_lr)
+            encoder_exponent = test_number % num_lr_in_test
+            rest_exponent = int(test_number / num_lr_in_test)
+            # # # # # # #
             with open(path_to_result_file, "a") as file:
                 file.write(str(metrics["training_loss"]) + ";" +
                            str(metrics["validation_loss"]) + ";" +
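For orientation: the LR range test walks a geometric grid of learning rates between start_lr and end_lr. A minimal sketch of that arithmetic, assuming only the constants visible in the hunk above (the lr_grid list is illustrative and not part of trainer.py):

# Geometric LR grid implied by the constants above (illustrative sketch, not repository code).
start_lr = 0.0000001      # 1e-7
end_lr = 0.001            # 1e-3
num_lr_in_test = 25

# The change replaces 1.0 / num_lr_in_test with 1.0 / (num_lr_in_test - 1), so the
# 25-point grid actually reaches end_lr instead of stopping one multiplicative step short.
lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))
lr_grid = [start_lr * lr_update_factor ** i for i in range(num_lr_in_test)]

assert abs(lr_grid[0] - start_lr) < 1e-15
assert abs(lr_grid[-1] - end_lr) < 1e-9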
@@ -84,13 +90,11 @@ class TransferPatienceEpochCallback(training.EpochCallback):
                            )
                 # END CONDITIONS
-                if param_group[1]["lr"] >= end_lr and param_group[0]["lr"] >= end_lr:
+                if param_group[1]["lr"] > end_lr and param_group[0]["lr"] > end_lr:
                     raise Exception('End of LR test')
-                param_group[0]["lr"] = param_group[0]["lr"] * lr_update_factor
-                if param_group[0]["lr"] >= end_lr:
-                    param_group[0]["lr"] = start_lr
-                    param_group[1]["lr"] = param_group[1]["lr"] * lr_update_factor
+                param_group[0]["lr"] = start_lr * (lr_update_factor ** encoder_exponent)
+                param_group[1]["lr"] = start_lr * (lr_update_factor ** rest_exponent)
                 file.write("\n" + str(param_group[0]["lr"]) + ";" + str(param_group[1]["lr"]) + ";")
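The two optimizer param groups are now swept independently: param_group[0], whose exponent the diff calls encoder_exponent, advances one grid step every LR-test segment, while param_group[1] (rest_exponent) advances one step only after a full pass of num_lr_in_test segments, so together they cover a 25 × 25 grid of learning-rate pairs. A small sketch of that mapping, assuming the constants above; the helper lrs_for_epoch is hypothetical and not part of the repository:

# Hypothetical helper mirroring the exponent bookkeeping above (not repository code).
def lrs_for_epoch(epoch, epochs_to_change_lr=15, start_lr=1e-7, end_lr=1e-3, num_lr_in_test=25):
    lr_update_factor = (end_lr / start_lr) ** (1.0 / (num_lr_in_test - 1))
    test_number = epoch // epochs_to_change_lr       # which LR-test segment this epoch belongs to
    encoder_exponent = test_number % num_lr_in_test  # inner loop: moves every segment
    rest_exponent = test_number // num_lr_in_test    # outer loop: moves once per full pass
    return (start_lr * lr_update_factor ** encoder_exponent,
            start_lr * lr_update_factor ** rest_exponent)

print(lrs_for_epoch(0))        # segment 0: both groups at start_lr, (1e-07, 1e-07)
print(lrs_for_epoch(15 * 26))  # segment 26: each group one grid step up, roughly (1.47e-07, 1.47e-07)

Since the corrected grid tops out at exactly end_lr, the end condition is also loosened from >= to >, which presumably keeps the final end_lr point from tripping the abort.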
@@ -124,6 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                          num_gradient_accumulation_steps, use_amp)
         # TODO extract param to constructor (+ constructor method?)
         self.validate_every_n = 1
+        self.epochs_to_change_lr = 15
 
     @overrides
     def _try_train(self) -> Dict[str, Any]:
@@ -153,12 +158,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         for epoch in range(epoch_counter, self._num_epochs):
-            epochs_to_change_lr = 15
-            # every epochs_to_change_lr epoch loads weights after 1 epoch
-            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
-                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
             epoch_start_time = time.time()
             train_metrics = self._train_epoch(epoch)
@@ -195,7 +194,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             for key, value in val_metrics.items():
                 metrics["validation_" + key] = value
-            if self._metric_tracker.is_best_so_far():
+            if self._metric_tracker.is_best_so_far() or (epoch % self.epochs_to_change_lr == 1 and epoch > 1):
                 # Update all the best_ metrics.
                 # (Otherwise they just stay the same as they were.)
                 metrics["best_epoch"] = epoch
@@ -205,7 +204,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 self._metric_tracker.best_epoch_metrics = val_metrics
 
             for callback in self._epoch_callbacks:
-                if ((epoch - 1) % epochs_to_change_lr == 0 and epoch > 1) or epoch == 0:
+                if epoch % self.epochs_to_change_lr == 0:
                     callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
 
             # The Scheduler API is agnostic to whether your schedule requires a validation metric -
@@ -231,8 +230,9 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             if epoch == 0:
                 torch.save(self.model.state_dict(), os.path.join(self._serialization_dir, "initial.th"))
-            if (epoch - 1) % epochs_to_change_lr == 0 and epoch > 1:
-                self._metric_tracker.best_epoch_metrics = val_metrics
+            # every epochs_to_change_lr epoch loads weights after 1 epoch
+            if epoch % self.epochs_to_change_lr == 0 and epoch > 1:
+                self.model.load_state_dict(torch.load(os.path.join(self._serialization_dir, "initial.th")))
 
         for callback in self._end_callbacks:
             callback(self, metrics=metrics, epoch=epoch, is_master=self._master)
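The trainer side of the change pairs a one-time snapshot at epoch 0 with a reload at the boundary of every epochs_to_change_lr-epoch segment, so each learning-rate pair starts from the same saved weights. A self-contained sketch of that round trip, using a throwaway model and a temporary directory rather than the repository's self._serialization_dir:

# Minimal illustration of the initial.th snapshot/restore pattern (toy model, not repository code).
import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)
serialization_dir = tempfile.mkdtemp()

# Epoch 0: save the starting weights once.
torch.save(model.state_dict(), os.path.join(serialization_dir, "initial.th"))

# ... training mutates the weights ...
with torch.no_grad():
    model.weight.add_(1.0)

# Start of a new LR-test segment: restore the snapshot before switching learning rates.
model.load_state_dict(torch.load(os.path.join(serialization_dir, "initial.th")))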