Syntactic Tools / combo · Commits · cbe6d0a4

Commit cbe6d0a4, authored Apr 6, 2023 by Maja Jabłońska

Add archival.py

Parent: 229d48d2
Merge request: !46 Merge COMBO 3.0 into master

Showing 1 changed file: combo/models/archival.py (new file, 0 → 100644) with 261 additions and 0 deletions
"""
Helper functions for archiving models and restoring archived models.
"""
from
os
import
PathLike
from
pathlib
import
Path
from
typing
import
NamedTuple
,
Union
,
Dict
,
Any
,
List
,
Optional
import
logging
import
os
import
tempfile
import
tarfile
import
shutil
from
contextlib
import
contextmanager
import
glob
from
torch.nn
import
Module
from
combo.common.params
import
Params
from
combo.data.dataset
import
DatasetReader
from
combo.models.model
import
Model
from
combo.utils
import
ConfigurationError
from
combo.utils.file_utils
import
cached_path
logger
=
logging
.
getLogger
(
__name__
)
class Archive(NamedTuple):
    """
    An archive comprises a Model and its experimental config.
    """

    model: Model
    config: Params
    dataset_reader: DatasetReader
    validation_dataset_reader: DatasetReader

    def extract_module(self, path: str, freeze: bool = True) -> Module:
        """
        This method can be used to load a module from the pretrained model archive.

        It is also used implicitly in FromParams-based construction, so instead of using standard
        params to construct a module, you can load a pretrained module from the model
        archive directly. For example, instead of using params like {"type": "module_type", ...},
        you can use the following template::

            {
                "_pretrained": {
                    "archive_file": "../path/to/model.tar.gz",
                    "path": "path.to.module.in.model",
                    "freeze": False
                }
            }

        If you use this feature with FromParams, be aware of the following caveat: a call to
        initializer(self) at the end of the model initializer can potentially wipe the
        transferred parameters by reinitializing them. This can happen if you have set up an
        initializer regex that also matches parameters of the transferred module. To safeguard
        against this, either update your initializer regex to prevent the conflicting match or
        add an extra initializer::

            [
                [".*transferred_module_name.*", "prevent"]
            ]

        # Parameters

        path : `str`, required
            Path of target module to be loaded from the model.
            Eg. "_textfield_embedder.token_embedder_tokens"
        freeze : `bool`, optional (default=`True`)
            Whether to freeze the module parameters or not.
        """
        modules_dict = {path: module for path, module in self.model.named_modules()}
        module = modules_dict.get(path)

        if not module:
            raise ConfigurationError(
                f"You asked to transfer module at path {path} from "
                f"the model {type(self.model)}. But it's not present."
            )
        if not isinstance(module, Module):
            raise ConfigurationError(
                f"The transferred object from model {type(self.model)} at path "
                f"{path} is not a PyTorch Module."
            )

        for parameter in module.parameters():  # type: ignore
            parameter.requires_grad_(not freeze)
        return module
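
A minimal usage sketch of extract_module (illustrative only, not part of the committed file; the archive filename and the "encoder" module path are hypothetical):

    # Load an archive (see load_archive below) and pull out one frozen submodule.
    archive = load_archive("model.tar.gz")
    encoder = archive.extract_module("encoder", freeze=True)
    # With freeze=True, every parameter of the submodule has requires_grad=False.
    assert all(not p.requires_grad for p in encoder.parameters())
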
# We archive a model by creating a tar.gz file with its weights, config, and vocabulary.
#
# These constants are the *known names* under which we archive them.
CONFIG_NAME = "config.json"
_DEFAULT_WEIGHTS = "best.th"
_WEIGHTS_NAME = "weights.th"
def archive_model(
    serialization_dir: Union[str, PathLike],
    weights: str = _DEFAULT_WEIGHTS,
    archive_path: Union[str, PathLike] = None,
    include_in_archive: Optional[List[str]] = None,
) -> str:
    """
    Archive the model weights, its training configuration, and its vocabulary to `model.tar.gz`.

    # Parameters

    serialization_dir : `str`
        The directory where the weights and vocabulary are written out.
    weights : `str`, optional (default=`_DEFAULT_WEIGHTS`)
        Which weights file to include in the archive. The default is `best.th`.
    archive_path : `str`, optional, (default = `None`)
        A full path to serialize the model to. The default is "model.tar.gz" inside the
        serialization_dir. If you pass a directory here, we'll serialize the model
        to "model.tar.gz" inside the directory.
    include_in_archive : `List[str]`, optional, (default = `None`)
        Paths relative to `serialization_dir` that should be archived in addition to the default ones.

    # Returns

    The final archive path.
    """
    extra_copy_of_weights_just_for_mypy = Path(weights)
    if extra_copy_of_weights_just_for_mypy.is_absolute():
        weights_file = extra_copy_of_weights_just_for_mypy
    else:
        weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy

    if not os.path.exists(weights_file):
        err_msg = f"weights file '{weights_file}' does not exist, unable to archive model"
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    config_file = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(config_file):
        err_msg = f"config file '{config_file}' does not exist, unable to archive model"
        logger.error(err_msg)
        raise RuntimeError(err_msg)

    if archive_path is not None:
        archive_file = archive_path
        if os.path.isdir(archive_file):
            archive_file = os.path.join(archive_file, "model.tar.gz")
    else:
        archive_file = os.path.join(serialization_dir, "model.tar.gz")

    logger.info("archiving weights and vocabulary to %s", archive_file)
    with tarfile.open(archive_file, "w:gz") as archive:
        archive.add(config_file, arcname=CONFIG_NAME)
        archive.add(weights_file, arcname=_WEIGHTS_NAME)
        archive.add(os.path.join(serialization_dir, "vocabulary"), arcname="vocabulary")

        if include_in_archive is not None:
            for archival_target in include_in_archive:
                archival_target_path = os.path.join(serialization_dir, archival_target)
                for path in glob.glob(archival_target_path):
                    if os.path.exists(path):
                        arcname = path[len(os.path.join(serialization_dir, "")):]
                        archive.add(path, arcname=arcname)

    return str(archive_file)
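
A minimal usage sketch of archive_model (illustrative, not part of the committed file; "runs/parser" is a hypothetical training directory assumed to contain config.json, best.th, and a vocabulary/ subdirectory):

    # Bundle the run into runs/parser/model.tar.gz; extra files can be pulled in
    # with glob patterns given relative to the serialization directory.
    archive_file = archive_model("runs/parser", include_in_archive=["*.log"])
    print(archive_file)  # -> runs/parser/model.tar.gz
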
def load_archive(
    archive_file: Union[str, PathLike],
    cuda_device: int = -1,
    overrides: Union[str, Dict[str, Any]] = "",
    weights_file: str = None,
) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.

    # Parameters

    archive_file : `Union[str, PathLike]`
        The archive file to load the model from.
    cuda_device : `int`, optional (default = `-1`)
        If `cuda_device` is >= 0, the model will be loaded onto the
        corresponding GPU. Otherwise it will be loaded onto the CPU.
    overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`)
        JSON overrides to apply to the unarchived `Params` object.
    weights_file : `str`, optional (default = `None`)
        The weights file to use. If unspecified, weights.th in the archive_file will be used.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(f"loading archive file {archive_file} from cache at {resolved_archive_file}")

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            with extracted_archive(resolved_archive_file, cleanup=False) as tempdir:
                serialization_dir = tempdir

        if weights_file:
            weights_path = weights_file
        else:
            weights_path = get_weights_path(serialization_dir)

        # Load config
        config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME), overrides)

        # Instantiate model and dataset readers. Use a duplicate of the config, as it will get consumed.
        dataset_reader, validation_dataset_reader = _load_dataset_readers(
            config.duplicate(), serialization_dir
        )
        model = _load_model(config.duplicate(), weights_path, serialization_dir, cuda_device)
    finally:
        if tempdir is not None:
            logger.info(f"removing temporary unarchived model dir at {tempdir}")
            shutil.rmtree(tempdir, ignore_errors=True)

    return Archive(
        model=model,
        config=config,
        dataset_reader=dataset_reader,
        validation_dataset_reader=validation_dataset_reader,
    )
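
And the matching sketch for load_archive (again illustrative; the archive path is an assumption carried over from the previous example):

    # Load the archive produced above; cuda_device=-1 (the default) keeps it on CPU.
    archive = load_archive("runs/parser/model.tar.gz", cuda_device=-1)
    model = archive.model
    reader = archive.dataset_reader
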
def _load_dataset_readers(config, serialization_dir):
    dataset_reader_params = config.get("dataset_reader")

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    validation_dataset_reader_params = config.get(
        "validation_dataset_reader", dataset_reader_params.duplicate()
    )

    dataset_reader = DatasetReader.from_params(
        dataset_reader_params, serialization_dir=serialization_dir
    )
    validation_dataset_reader = DatasetReader.from_params(
        validation_dataset_reader_params, serialization_dir=serialization_dir
    )

    return dataset_reader, validation_dataset_reader
def _load_model(config, weights_path, serialization_dir, cuda_device):
    return Model.load(
        config,
        weights_file=weights_path,
        serialization_dir=serialization_dir,
        cuda_device=cuda_device,
    )
def get_weights_path(serialization_dir):
    weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)
    # Fallback for serialization directories.
    if not os.path.exists(weights_path):
        weights_path = os.path.join(serialization_dir, _DEFAULT_WEIGHTS)
    return weights_path
@contextmanager
def extracted_archive(resolved_archive_file, cleanup=True):
    tempdir = None
    try:
        tempdir = tempfile.mkdtemp()
        logger.info(f"extracting archive file {resolved_archive_file} to temp dir {tempdir}")
        with tarfile.open(resolved_archive_file, "r:gz") as archive:
            archive.extractall(tempdir)
        yield tempdir
    finally:
        if tempdir is not None and cleanup:
            logger.info(f"removing temporary unarchived model dir at {tempdir}")
            shutil.rmtree(tempdir, ignore_errors=True)
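
A short sketch of the extracted_archive context manager (hypothetical archive path; with the default cleanup=True the temporary directory is removed when the block exits):

    # Inspect an archive's contents without instantiating the model.
    with extracted_archive("runs/parser/model.tar.gz") as tempdir:
        print(os.listdir(tempdir))  # e.g. ['config.json', 'weights.th', 'vocabulary']
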