Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
b03c33b2
There was an error fetching the commit references. Please try again later.
Commit
b03c33b2
authored
1 year ago
by
Maja Jablonska
Browse files
Options
Downloads
Patches
Plain Diff
Add a tensorboard logger
parent
ad097cf0
1 merge request
!46
Merge COMBO 3.0 into master
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
combo/main.py
+8
-1
8 additions, 1 deletion
combo/main.py
with
8 additions
and
1 deletion
combo/main.py
+
8
−
1
View file @
b03c33b2
...
@@ -8,6 +8,7 @@ from typing import Dict
...
@@ -8,6 +8,7 @@ from typing import Dict
import
torch
import
torch
from
absl
import
app
,
flags
from
absl
import
app
,
flags
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.loggers
import
TensorBoardLogger
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
combo.training.trainable_combo
import
TrainableCombo
from
combo.training.trainable_combo
import
TrainableCombo
...
@@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None,
...
@@ -67,6 +68,8 @@ flags.DEFINE_string(name="serialization_dir", default=None,
help
=
"
Model serialization directory (default - system temp dir).
"
)
help
=
"
Model serialization directory (default - system temp dir).
"
)
flags
.
DEFINE_boolean
(
name
=
"
tensorboard
"
,
default
=
False
,
flags
.
DEFINE_boolean
(
name
=
"
tensorboard
"
,
default
=
False
,
help
=
"
When provided model will log tensorboard metrics.
"
)
help
=
"
When provided model will log tensorboard metrics.
"
)
flags
.
DEFINE_string
(
name
=
"
tensorboard_name
"
,
default
=
"
combo
"
,
help
=
"
Name of the model in TensorBoard logs.
"
)
flags
.
DEFINE_string
(
name
=
"
config_path
"
,
default
=
str
(
pathlib
.
Path
(
__file__
).
parent
/
"
config.json
"
),
flags
.
DEFINE_string
(
name
=
"
config_path
"
,
default
=
str
(
pathlib
.
Path
(
__file__
).
parent
/
"
config.json
"
),
help
=
"
Config file path.
"
)
help
=
"
Config file path.
"
)
...
@@ -208,10 +211,14 @@ def run(_):
...
@@ -208,10 +211,14 @@ def run(_):
n_cuda_devices
=
"
auto
"
if
FLAGS
.
n_cuda_devices
==
-
1
else
FLAGS
.
n_cuda_devices
n_cuda_devices
=
"
auto
"
if
FLAGS
.
n_cuda_devices
==
-
1
else
FLAGS
.
n_cuda_devices
tensorboard_logger
=
TensorBoardLogger
(
os
.
path
.
join
(
serialization_dir
,
'
tensorboard_logs
'
),
name
=
FLAGS
.
tensorboard_name
)
if
FLAGS
.
tensorboard
else
None
trainer
=
pl
.
Trainer
(
max_epochs
=
FLAGS
.
num_epochs
,
trainer
=
pl
.
Trainer
(
max_epochs
=
FLAGS
.
num_epochs
,
default_root_dir
=
serialization_dir
,
default_root_dir
=
serialization_dir
,
gradient_clip_val
=
5
,
gradient_clip_val
=
5
,
devices
=
n_cuda_devices
)
devices
=
n_cuda_devices
,
logger
=
tensorboard_logger
)
try
:
try
:
trainer
.
fit
(
model
=
nlp
,
train_dataloaders
=
train_data_loader
,
val_dataloaders
=
validation_data_loader
)
trainer
.
fit
(
model
=
nlp
,
train_dataloaders
=
train_data_loader
,
val_dataloaders
=
validation_data_loader
)
except
Exception
as
e
:
except
Exception
as
e
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment