Commit d6a8e3eb authored by Maja Jablonska

Fixed paths in archival.py

parent 82ba2f40
1 merge request: !46 Merge COMBO 3.0 into master
 import os
+import shutil
+import tempfile
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, Union, NamedTuple, Optional
@@ -14,11 +16,19 @@ from combo.config import resolve
 from combo.data.dataset_loaders import DataLoader
 from combo.data.dataset_readers import DatasetReader
 from combo.modules.model import Model
-from combo.utils import ConfigurationError
+from contextlib import contextmanager
+
+import logging
+from combo.utils import ComboLogger
+
+logging.setLoggerClass(ComboLogger)
+logger = logging.getLogger(__name__)
 
 CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
 CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
+PREFIX = 'Loading archive'
 
 
 class Archive(NamedTuple):
     model: Model
@@ -82,36 +92,73 @@ def archive(model: Model,
     return serialization_dir
 
 
+@contextmanager
+def extracted_archive(resolved_archive_file, cleanup=True):
+    tempdir = None
+    try:
+        tempdir = tempfile.mkdtemp(dir=CACHE_DIRECTORY)
+        with tarfile.open(resolved_archive_file) as archive:
+            subdir_and_files = [
+                tarinfo for tarinfo in archive.getmembers()
+                if (any([tarinfo.name.endswith(f) for f in ['config.json', 'weights.th']])
+                    or 'vocabulary' in tarinfo.name)
+            ]
+            for f in subdir_and_files:
+                if 'vocabulary' in f.name and not f.name.endswith('vocabulary'):
+                    f.name = os.path.join('vocabulary', os.path.basename(f.name))
+                else:
+                    f.name = os.path.basename(f.name)
+            archive.extractall(path=tempdir, members=subdir_and_files)
+        yield tempdir
+    finally:
+        if tempdir is not None and cleanup:
+            shutil.rmtree(tempdir, ignore_errors=True)
+
+
 def load_archive(url_or_filename: Union[PathLike, str],
                  cache_dir: Union[PathLike, str] = None,
                  cuda_device: int = -1) -> Archive:
-    archive_file = cached_path.cached_path(
+    rarchive_file = cached_path.cached_path(
        url_or_filename,
        cache_dir=cache_dir or CACHE_DIRECTORY,
-        extract_archive=True
    )
 
-    model = Model.load(archive_file, cuda_device=cuda_device)
-
-    config_path = os.path.join(archive_file, 'config.json')
-    if not os.path.exists(config_path):
-        config_path = os.path.join(archive_file, 'model/config.json')
-        if not os.path.exists(config_path):
-            raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model")
-    with open(config_path, 'r') as f:
-        config = json.load(f)
-
-    data_loader, validation_data_loader, dataset_reader = None, None, None
-    pass_down_parameters = {}
-    if config.get("model_name"):
-        pass_down_parameters = {"model_name": config.get("model_name")}
-
-    if 'data_loader' in config:
-        data_loader = resolve(config['data_loader'], pass_down_parameters=pass_down_parameters)
-    if 'validation_data_loader' in config:
-        validation_data_loader = resolve(config['validation_data_loader'], pass_down_parameters=pass_down_parameters)
-    if 'dataset_reader' in config:
-        dataset_reader = resolve(config['dataset_reader'], pass_down_parameters=pass_down_parameters)
+    with extracted_archive(rarchive_file) as archive_file:
+        model = Model.load(archive_file, cuda_device=cuda_device)
+
+        config_path = os.path.join(archive_file, 'config.json')
+        with open(config_path, 'r') as f:
+            config = json.load(f)
+
+        data_loader, validation_data_loader, dataset_reader = None, None, None
+        pass_down_parameters = {}
+        if config.get("model_name"):
+            pass_down_parameters = {"model_name": config.get("model_name")}
+
+        if 'data_loader' in config:
+            try:
+                data_loader = resolve(config['data_loader'],
+                                      pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Training Data Loader: {str(e)}. Setting Data Loader to None',
+                               prefix=PREFIX)
+
+        if 'validation_data_loader' in config:
+            try:
+                validation_data_loader = resolve(config['validation_data_loader'],
+                                                 pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Data Loader to None',
+                               prefix=PREFIX)
+
+        if 'dataset_reader' in config:
+            try:
+                dataset_reader = resolve(config['dataset_reader'],
+                                         pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None',
+                               prefix=PREFIX)
 
     return Archive(model=model,
                    config=config,
                    data_loader=data_loader,
...
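As a quick illustration of the new flow, here is a minimal usage sketch of load_archive after this change; the import path and the local archive filename are assumptions for the example, not part of the commit.

# Hypothetical usage sketch: load a packaged model through the rewritten loader.
# The module path and archive filename below are assumptions for illustration.
from combo.modules.archival import load_archive  # assumed import path

archive = load_archive("model.tar.gz", cuda_device=-1)  # any local path or URL
model = archive.model               # weights loaded from the temporary extraction dir
config = archive.config            # parsed from config.json inside the archive
data_loader = archive.data_loader  # None if the config has no 'data_loader' entry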
@@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters):
         # Load vocabulary from file
         vocab_dir = os.path.join(serialization_dir, "vocabulary")
-        if not os.path.exists(vocab_dir):
-            vocab_dir = os.path.join(serialization_dir, "model/vocabulary")
-            if not os.path.exists(vocab_dir):
-                raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model")
         # If the config specifies a vocabulary subclass, we need to use it.
         vocab_params = config.get("vocabulary")
         if vocab_params['type'] == 'from_files_vocabulary':
...
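With the fallback removed, Model.load relies on the flat layout that extracted_archive produces: vocabulary/ sits directly under the serialization directory, with no model/ prefix. A small sketch of that assumed layout, built by hand purely for illustration:

# Build, by hand, the directory shape Model.load now expects (illustration only).
import os
import tempfile

serialization_dir = tempfile.mkdtemp()                      # stands in for the extracted archive
os.makedirs(os.path.join(serialization_dir, "vocabulary"))  # vocabulary/ at the top level
open(os.path.join(serialization_dir, "config.json"), "w").close()
open(os.path.join(serialization_dir, "weights.th"), "w").close()

vocab_dir = os.path.join(serialization_dir, "vocabulary")
assert os.path.exists(vocab_dir)  # the single path checked after this commit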
@@ -27,7 +27,7 @@ class ComboLogger(logging.Logger):
         self.log(level=logging.INFO, msg=msg, prefix=prefix)
 
     @overrides(check_signature=False)
-    def warn(self, msg: str, prefix: str = None):
+    def warning(self, msg: str, prefix: str = None):
         self.log(level=logging.WARN, msg=msg, prefix=prefix)
 
     @overrides(check_signature=False)
...
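The rename from warn to warning matters because the exception handlers added in load_archive call logger.warning(..., prefix=PREFIX), and only an override named warning intercepts that call on a standard logging.Logger. A minimal sketch, assuming ComboLogger is importable as in the diff above; the logger name is arbitrary:

import logging
from combo.utils import ComboLogger

logging.setLoggerClass(ComboLogger)          # same setup as the lines added to archival.py
logger = logging.getLogger("archival-demo")  # hypothetical logger name

# Dispatches to ComboLogger.warning, so the extra prefix keyword is accepted.
logger.warning("Error while loading Dataset Reader: ... Setting Dataset Reader to None",
               prefix="Loading archive")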