Skip to content
Snippets Groups Projects
Commit d6a8e3eb authored by Maja Jablonska
Browse files

Fixed paths in archival.py

parent 82ba2f40
No related branches found
No related tags found
1 merge request: !46 "Merge COMBO 3.0 into master"
import os
import shutil
import tempfile
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Union, NamedTuple, Optional
......@@ -14,11 +16,19 @@ from combo.config import resolve
from combo.data.dataset_loaders import DataLoader
from combo.data.dataset_readers import DatasetReader
from combo.modules.model import Model
from combo.utils import ConfigurationError
from contextlib import contextmanager
import logging
from combo.utils import ComboLogger
# Register ComboLogger as the logger class before creating this module's
# logger, so that logging calls below can pass the extra ``prefix`` keyword
# that ComboLogger's methods accept.
logging.setLoggerClass(ComboLogger)
logger = logging.getLogger(__name__)
# Root directory for cached data: taken from the COMBO_CACHE_ROOT environment
# variable when it is set, otherwise ``~/.combo``.
CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
# Stored as a plain string. Used as the parent of the temporary extraction
# directories and as the default ``cache_dir`` when loading an archive.
CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
# Passed as ``prefix`` to the warnings logged while loading an archive.
PREFIX = 'Loading archive'
class Archive(NamedTuple):
model: Model
......@@ -82,21 +92,42 @@ def archive(model: Model,
return serialization_dir
@contextmanager
def extracted_archive(resolved_archive_file, cleanup=True):
    """Extract the model files of a tar archive into a temporary directory.

    Only members whose name ends with ``config.json`` or ``weights.th``, or
    whose name contains ``vocabulary``, are extracted. The archive's directory
    layout is flattened while doing so: ``config.json`` and ``weights.th`` end
    up directly in the temporary directory, and vocabulary files end up in a
    single ``vocabulary`` sub-directory, whatever folder they were stored
    under inside the archive.

    :param resolved_archive_file: path of a local tar archive (anything
        ``tarfile.open`` accepts as a file name).
    :param cleanup: when true (the default) the temporary directory is removed
        on exit from the ``with`` block, including when an error occurs.
    :return: context manager yielding the path of the temporary directory,
        which is created under ``CACHE_DIRECTORY``.
    :raises tarfile.TarError: if the file cannot be read as a tar archive, or
        if a member is rejected by the ``data`` extraction filter.
    """
    tempdir = None
    try:
        # mkdtemp() does not create its parent directory; make sure the cache
        # directory exists so a first run on a clean machine does not fail.
        os.makedirs(CACHE_DIRECTORY, exist_ok=True)
        tempdir = tempfile.mkdtemp(dir=CACHE_DIRECTORY)
        # Named ``tar`` rather than ``archive`` to avoid shadowing the
        # module-level archive() function.
        with tarfile.open(resolved_archive_file) as tar:
            members = [
                member for member in tar.getmembers()
                if member.name.endswith(('config.json', 'weights.th'))
                or 'vocabulary' in member.name
            ]
            # Rewrite member names so the extracted layout no longer depends
            # on the folder the files were archived under.
            for member in members:
                base_name = os.path.basename(member.name)
                if 'vocabulary' in member.name and not member.name.endswith('vocabulary'):
                    # A file stored inside the vocabulary directory.
                    member.name = os.path.join('vocabulary', base_name)
                else:
                    # config.json, weights.th or the vocabulary directory itself.
                    member.name = base_name
            # The archive may have been downloaded, so use the "data"
            # extraction filter where this Python version provides it: it
            # refuses links pointing outside the destination and special files.
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
            tar.extractall(path=tempdir, members=members, **extract_kwargs)
        yield tempdir
    finally:
        if tempdir is not None and cleanup:
            shutil.rmtree(tempdir, ignore_errors=True)
def load_archive(url_or_filename: Union[PathLike, str],
cache_dir: Union[PathLike, str] = None,
cuda_device: int = -1) -> Archive:
archive_file = cached_path.cached_path(
rarchive_file = cached_path.cached_path(
url_or_filename,
cache_dir=cache_dir or CACHE_DIRECTORY,
extract_archive=True
)
with extracted_archive(rarchive_file) as archive_file:
model = Model.load(archive_file, cuda_device=cuda_device)
config_path = os.path.join(archive_file, 'config.json')
if not os.path.exists(config_path):
config_path = os.path.join(archive_file, 'model/config.json')
if not os.path.exists(config_path):
raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model")
with open(config_path, 'r') as f:
config = json.load(f)
......@@ -105,12 +136,28 @@ def load_archive(url_or_filename: Union[PathLike, str],
if config.get("model_name"):
pass_down_parameters = {"model_name": config.get("model_name")}
if 'data_loader' in config:
data_loader = resolve(config['data_loader'], pass_down_parameters=pass_down_parameters)
try:
data_loader = resolve(config['data_loader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Training Data Loader: {str(e)}. Setting Data Loader to None',
prefix=PREFIX)
if 'validation_data_loader' in config:
validation_data_loader = resolve(config['validation_data_loader'], pass_down_parameters=pass_down_parameters)
try:
validation_data_loader = resolve(config['validation_data_loader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Data Loader to None',
prefix=PREFIX)
if 'dataset_reader' in config:
dataset_reader = resolve(config['dataset_reader'], pass_down_parameters=pass_down_parameters)
try:
dataset_reader = resolve(config['dataset_reader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None',
prefix=PREFIX)
return Archive(model=model,
config=config,
......
......@@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters):
# Load vocabulary from file
vocab_dir = os.path.join(serialization_dir, "vocabulary")
if not os.path.exists(vocab_dir):
vocab_dir =os.path.join(serialization_dir, "model/vocabulary")
if not os.path.exists(vocab_dir):
raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model")
# If the config specifies a vocabulary subclass, we need to use it.
vocab_params = config.get("vocabulary")
if vocab_params['type'] == 'from_files_vocabulary':
......
......@@ -27,7 +27,7 @@ class ComboLogger(logging.Logger):
self.log(level=logging.INFO, msg=msg, prefix=prefix)
@overrides(check_signature=False)
def warn(self, msg: str, prefix: str = None):
def warning(self, msg: str, prefix: str = None):
self.log(level=logging.WARN, msg=msg, prefix=prefix)
@overrides(check_signature=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment