Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
d6a8e3eb
Commit
d6a8e3eb
authored
1 year ago
by
Maja Jablonska
Browse files
Options
Downloads
Patches
Plain Diff
Fixed paths in archival.py
parent
82ba2f40
1 merge request
!46
Merge COMBO 3.0 into master
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
combo/modules/archival.py
+72
-25
72 additions, 25 deletions
combo/modules/archival.py
combo/modules/model.py
+0
-4
0 additions, 4 deletions
combo/modules/model.py
combo/utils/logging.py
+1
-1
1 addition, 1 deletion
combo/utils/logging.py
with
73 additions
and
30 deletions
combo/modules/archival.py
+
72
−
25
View file @
d6a8e3eb
import
os
import
os
import
shutil
import
tempfile
from
os
import
PathLike
from
os
import
PathLike
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Union
,
NamedTuple
,
Optional
from
typing
import
Any
,
Dict
,
Union
,
NamedTuple
,
Optional
...
@@ -14,11 +16,19 @@ from combo.config import resolve
...
@@ -14,11 +16,19 @@ from combo.config import resolve
from
combo.data.dataset_loaders
import
DataLoader
from
combo.data.dataset_loaders
import
DataLoader
from
combo.data.dataset_readers
import
DatasetReader
from
combo.data.dataset_readers
import
DatasetReader
from
combo.modules.model
import
Model
from
combo.modules.model
import
Model
from
combo.utils
import
ConfigurationError
from
contextlib
import
contextmanager
import
logging
from
combo.utils
import
ComboLogger
logging
.
setLoggerClass
(
ComboLogger
)
logger
=
logging
.
getLogger
(
__name__
)
CACHE_ROOT
=
Path
(
os
.
getenv
(
"
COMBO_CACHE_ROOT
"
,
Path
.
home
()
/
"
.combo
"
))
CACHE_ROOT
=
Path
(
os
.
getenv
(
"
COMBO_CACHE_ROOT
"
,
Path
.
home
()
/
"
.combo
"
))
CACHE_DIRECTORY
=
str
(
CACHE_ROOT
/
"
cache
"
)
CACHE_DIRECTORY
=
str
(
CACHE_ROOT
/
"
cache
"
)
PREFIX
=
'
Loading archive
'
class
Archive
(
NamedTuple
):
class
Archive
(
NamedTuple
):
model
:
Model
model
:
Model
...
@@ -82,36 +92,73 @@ def archive(model: Model,
...
@@ -82,36 +92,73 @@ def archive(model: Model,
return
serialization_dir
return
serialization_dir
@contextmanager
def
extracted_archive
(
resolved_archive_file
,
cleanup
=
True
):
tempdir
=
None
try
:
tempdir
=
tempfile
.
mkdtemp
(
dir
=
CACHE_DIRECTORY
)
with
tarfile
.
open
(
resolved_archive_file
)
as
archive
:
subdir_and_files
=
[
tarinfo
for
tarinfo
in
archive
.
getmembers
()
if
(
any
([
tarinfo
.
name
.
endswith
(
f
)
for
f
in
[
'
config.json
'
,
'
weights.th
'
]])
or
'
vocabulary
'
in
tarinfo
.
name
)
]
for
f
in
subdir_and_files
:
if
'
vocabulary
'
in
f
.
name
and
not
f
.
name
.
endswith
(
'
vocabulary
'
):
f
.
name
=
os
.
path
.
join
(
'
vocabulary
'
,
os
.
path
.
basename
(
f
.
name
))
else
:
f
.
name
=
os
.
path
.
basename
(
f
.
name
)
archive
.
extractall
(
path
=
tempdir
,
members
=
subdir_and_files
)
yield
tempdir
finally
:
if
tempdir
is
not
None
and
cleanup
:
shutil
.
rmtree
(
tempdir
,
ignore_errors
=
True
)
def
load_archive
(
url_or_filename
:
Union
[
PathLike
,
str
],
def
load_archive
(
url_or_filename
:
Union
[
PathLike
,
str
],
cache_dir
:
Union
[
PathLike
,
str
]
=
None
,
cache_dir
:
Union
[
PathLike
,
str
]
=
None
,
cuda_device
:
int
=
-
1
)
->
Archive
:
cuda_device
:
int
=
-
1
)
->
Archive
:
archive_file
=
cached_path
.
cached_path
(
rarchive_file
=
cached_path
.
cached_path
(
url_or_filename
,
url_or_filename
,
cache_dir
=
cache_dir
or
CACHE_DIRECTORY
,
cache_dir
=
cache_dir
or
CACHE_DIRECTORY
,
extract_archive
=
True
)
)
model
=
Model
.
load
(
archive_file
,
cuda_device
=
cuda_device
)
with
extracted_archive
(
rarchive_file
)
as
archive_file
:
config_path
=
os
.
path
.
join
(
archive_file
,
'
config.json
'
)
model
=
Model
.
load
(
archive_file
,
cuda_device
=
cuda_device
)
if
not
os
.
path
.
exists
(
config_path
):
config_path
=
os
.
path
.
join
(
archive_file
,
'
model/config.json
'
)
config_path
=
os
.
path
.
join
(
archive_file
,
'
config.json
'
)
if
not
os
.
path
.
exists
(
config_path
):
with
open
(
config_path
,
'
r
'
)
as
f
:
raise
ConfigurationError
(
"
config.json is not stored in
"
+
str
(
archive_file
)
+
"
or
"
+
str
(
archive_file
)
+
"
/model
"
)
config
=
json
.
load
(
f
)
with
open
(
config_path
,
'
r
'
)
as
f
:
config
=
json
.
load
(
f
)
data_loader
,
validation_data_loader
,
dataset_reader
=
None
,
None
,
None
pass_down_parameters
=
{}
data_loader
,
validation_data_loader
,
dataset_reader
=
None
,
None
,
None
if
config
.
get
(
"
model_name
"
):
pass_down_parameters
=
{}
pass_down_parameters
=
{
"
model_name
"
:
config
.
get
(
"
model_name
"
)}
if
config
.
get
(
"
model_name
"
):
pass_down_parameters
=
{
"
model_name
"
:
config
.
get
(
"
model_name
"
)}
if
'
data_loader
'
in
config
:
if
'
data_loader
'
in
config
:
try
:
data_loader
=
resolve
(
config
[
'
data_loader
'
],
pass_down_parameters
=
pass_down_parameters
)
data_loader
=
resolve
(
config
[
'
data_loader
'
],
if
'
validation_data_loader
'
in
config
:
pass_down_parameters
=
pass_down_parameters
)
validation_data_loader
=
resolve
(
config
[
'
validation_data_loader
'
],
pass_down_parameters
=
pass_down_parameters
)
except
Exception
as
e
:
if
'
dataset_reader
'
in
config
:
logger
.
warning
(
f
'
Error while loading Training Data Loader:
{
str
(
e
)
}
. Setting Data Loader to None
'
,
dataset_reader
=
resolve
(
config
[
'
dataset_reader
'
],
pass_down_parameters
=
pass_down_parameters
)
prefix
=
PREFIX
)
if
'
validation_data_loader
'
in
config
:
try
:
validation_data_loader
=
resolve
(
config
[
'
validation_data_loader
'
],
pass_down_parameters
=
pass_down_parameters
)
except
Exception
as
e
:
logger
.
warning
(
f
'
Error while loading Validation Data Loader:
{
str
(
e
)
}
. Setting Data Loader to None
'
,
prefix
=
PREFIX
)
if
'
dataset_reader
'
in
config
:
try
:
dataset_reader
=
resolve
(
config
[
'
dataset_reader
'
],
pass_down_parameters
=
pass_down_parameters
)
except
Exception
as
e
:
logger
.
warning
(
f
'
Error while loading Dataset Reader:
{
str
(
e
)
}
. Setting Dataset Reader to None
'
,
prefix
=
PREFIX
)
return
Archive
(
model
=
model
,
return
Archive
(
model
=
model
,
config
=
config
,
config
=
config
,
data_loader
=
data_loader
,
data_loader
=
data_loader
,
...
...
This diff is collapsed.
Click to expand it.
combo/modules/model.py
+
0
−
4
View file @
d6a8e3eb
...
@@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters):
...
@@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters):
# Load vocabulary from file
# Load vocabulary from file
vocab_dir
=
os
.
path
.
join
(
serialization_dir
,
"
vocabulary
"
)
vocab_dir
=
os
.
path
.
join
(
serialization_dir
,
"
vocabulary
"
)
if
not
os
.
path
.
exists
(
vocab_dir
):
vocab_dir
=
os
.
path
.
join
(
serialization_dir
,
"
model/vocabulary
"
)
if
not
os
.
path
.
exists
(
vocab_dir
):
raise
ConfigurationError
(
"
Vocabulary not saved in
"
+
serialization_dir
+
"
or
"
+
serialization_dir
+
"
/model
"
)
# If the config specifies a vocabulary subclass, we need to use it.
# If the config specifies a vocabulary subclass, we need to use it.
vocab_params
=
config
.
get
(
"
vocabulary
"
)
vocab_params
=
config
.
get
(
"
vocabulary
"
)
if
vocab_params
[
'
type
'
]
==
'
from_files_vocabulary
'
:
if
vocab_params
[
'
type
'
]
==
'
from_files_vocabulary
'
:
...
...
This diff is collapsed.
Click to expand it.
combo/utils/logging.py
+
1
−
1
View file @
d6a8e3eb
...
@@ -27,7 +27,7 @@ class ComboLogger(logging.Logger):
...
@@ -27,7 +27,7 @@ class ComboLogger(logging.Logger):
self
.
log
(
level
=
logging
.
INFO
,
msg
=
msg
,
prefix
=
prefix
)
self
.
log
(
level
=
logging
.
INFO
,
msg
=
msg
,
prefix
=
prefix
)
@overrides
(
check_signature
=
False
)
@overrides
(
check_signature
=
False
)
def
warn
(
self
,
msg
:
str
,
prefix
:
str
=
None
):
def
warn
ing
(
self
,
msg
:
str
,
prefix
:
str
=
None
):
self
.
log
(
level
=
logging
.
WARN
,
msg
=
msg
,
prefix
=
prefix
)
self
.
log
(
level
=
logging
.
WARN
,
msg
=
msg
,
prefix
=
prefix
)
@overrides
(
check_signature
=
False
)
@overrides
(
check_signature
=
False
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment