Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
poldeepner2
Manage
Activity
Members
Labels
Plan
Issues
29
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Information extraction
poldeepner2
Commits
5118ae5b
Commit
5118ae5b
authored
2 years ago
by
Michał Marcińczuk
Browse files
Options
Downloads
Patches
Plain Diff
Write training arguments to the file.
parent
b8d24205
Branches
Branches containing commit
1 merge request
!41
Dev v07
Pipeline
#6242
failed with stage
in 2 minutes and 25 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
train.py
+7
-6
7 additions, 6 deletions
train.py
with
7 additions
and
6 deletions
train.py
+
7
−
6
View file @
5118ae5b
...
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function
import
argparse
import
glob
import
json
import
logging
import
os
import
sys
...
...
@@ -82,6 +83,7 @@ def train_model(args: Namespace):
"
empty.
"
%
args
.
output_dir
)
Path
(
args
.
output_dir
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
json
.
dump
(
args
.
__dict__
,
open
(
str
(
Path
(
args
.
output_dir
)
/
"
train_args.json
"
),
"
w
"
,
encoding
=
"
utf-8
"
),
indent
=
4
)
logger
=
logging
.
getLogger
(
__name__
)
for
item
in
sorted
(
config
.
items
()):
...
...
@@ -169,7 +171,7 @@ def train_model(args: Namespace):
if
args
.
freeze_model
:
logger
.
info
(
"
Freezing XLM-R model...
"
)
for
n
,
p
in
model
.
named_parameters
():
if
'
xlm
r
'
in
n
and
p
.
requires_grad
:
if
'
encode
r
'
in
n
and
p
.
requires_grad
:
logging
.
info
(
"
Parameter %s - freezed
"
%
n
)
p
.
requires_grad
=
False
else
:
...
...
@@ -253,7 +255,7 @@ def train_model(args: Namespace):
epoch_stats
[
"
epoch_training_time
"
]
=
time
.
time
()
-
time_start
if
args
.
data_tune
:
logger
.
info
(
"
\n
Testing on validation set...
"
)
logger
.
info
(
"
Testing on validation set...
"
)
time_start
=
time
.
time
()
f1
,
precision
,
recall
,
report
=
evaluate_model
(
model
,
val_data
,
label_list
,
args
.
eval_batch_size
,
device
)
time_end
=
time
.
time
()
...
...
@@ -264,8 +266,8 @@ def train_model(args: Namespace):
if
f1
>
best_val_f1
:
best_val_f1
=
f1
logger
.
info
(
"
\n
Found better f1=%.4f on validation set.
"
"
Saving model
\n
"
%
f1
)
logger
.
info
(
"
Found better f1=%.4f on validation set.
"
"
Saving model
"
%
f1
)
logger
.
info
(
"
%s
\n
"
%
report
)
model
.
save
(
args
.
output_dir
)
else
:
...
...
@@ -283,8 +285,7 @@ def train_model(args: Namespace):
logger
.
info
(
"
%s
\n
"
%
report
)
if
args
.
epoch_save_model
:
epoch_output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"
e%03d
"
%
epoch_no
)
epoch_output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"
e%03d
"
%
epoch_no
)
os
.
makedirs
(
epoch_output_dir
)
model
.
save
(
epoch_output_dir
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment