From d6a8e3eb2d14e7706b96f388ae5b359e4ae53efa Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Tue, 21 Nov 2023 23:13:40 +1100
Subject: [PATCH] Fixed paths in archival.py

---
 combo/modules/archival.py | 97 +++++++++++++++++++++++++++++----------
 combo/modules/model.py    |  4 --
 combo/utils/logging.py    |  2 +-
 3 files changed, 73 insertions(+), 30 deletions(-)

diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index 56a293e..496753f 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -1,4 +1,6 @@
 import os
+import shutil
+import tempfile
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, Union, NamedTuple, Optional
@@ -14,11 +16,19 @@ from combo.config import resolve
 from combo.data.dataset_loaders import DataLoader
 from combo.data.dataset_readers import DatasetReader
 from combo.modules.model import Model
-from combo.utils import ConfigurationError
+from contextlib import contextmanager
+
+import logging
+from combo.utils import ComboLogger
+
+logging.setLoggerClass(ComboLogger)
+logger = logging.getLogger(__name__)
+
 
 CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
 CACHE_DIRECTORY = str(CACHE_ROOT / "cache")
 
+PREFIX = 'Loading archive'
 
 class Archive(NamedTuple):
     model: Model
@@ -82,36 +92,73 @@ def archive(model: Model,
     return serialization_dir
 
 
+@contextmanager
+def extracted_archive(resolved_archive_file, cleanup=True):
+    tempdir = None
+    try:
+        tempdir = tempfile.mkdtemp(dir=CACHE_DIRECTORY)
+        with tarfile.open(resolved_archive_file) as archive:
+            subdir_and_files = [
+                tarinfo for tarinfo in archive.getmembers()
+                if (any([tarinfo.name.endswith(f) for f in ['config.json', 'weights.th']])
+                    or 'vocabulary' in tarinfo.name)
+            ]
+            for f in subdir_and_files:
+                if 'vocabulary' in f.name and not f.name.endswith('vocabulary'):
+                    f.name = os.path.join('vocabulary', os.path.basename(f.name))
+                else:
+                    f.name = os.path.basename(f.name)
+            archive.extractall(path=tempdir, members=subdir_and_files)
+            yield tempdir
+    finally:
+        if tempdir is not None and cleanup:
+            shutil.rmtree(tempdir, ignore_errors=True)
+
+
 def load_archive(url_or_filename: Union[PathLike, str],
                  cache_dir: Union[PathLike, str] = None,
                  cuda_device: int = -1) -> Archive:
-    archive_file = cached_path.cached_path(
+
+    rarchive_file = cached_path.cached_path(
             url_or_filename,
             cache_dir=cache_dir or CACHE_DIRECTORY,
-            extract_archive=True
         )
-    model = Model.load(archive_file, cuda_device=cuda_device)
-
-    config_path = os.path.join(archive_file, 'config.json')
-    if not os.path.exists(config_path):
-        config_path = os.path.join(archive_file, 'model/config.json')
-    if not os.path.exists(config_path):
-        raise ConfigurationError("config.json is not stored in " + str(archive_file) + " or " + str(archive_file) + "/model")
-    with open(config_path, 'r') as f:
-        config = json.load(f)
-
-    data_loader, validation_data_loader, dataset_reader = None, None, None
-    pass_down_parameters = {}
-    if config.get("model_name"):
-        pass_down_parameters = {"model_name": config.get("model_name")}
-
-    if 'data_loader' in config:
-        data_loader = resolve(config['data_loader'], pass_down_parameters=pass_down_parameters)
-    if 'validation_data_loader' in config:
-        validation_data_loader = resolve(config['validation_data_loader'], pass_down_parameters=pass_down_parameters)
-    if 'dataset_reader' in config:
-        dataset_reader = resolve(config['dataset_reader'], pass_down_parameters=pass_down_parameters)
-    
+
+    with extracted_archive(rarchive_file) as archive_file:
+        model = Model.load(archive_file, cuda_device=cuda_device)
+
+        config_path = os.path.join(archive_file, 'config.json')
+        with open(config_path, 'r') as f:
+            config = json.load(f)
+
+        data_loader, validation_data_loader, dataset_reader = None, None, None
+        pass_down_parameters = {}
+        if config.get("model_name"):
+            pass_down_parameters = {"model_name": config.get("model_name")}
+
+
+        if 'data_loader' in config:
+            try:
+                data_loader = resolve(config['data_loader'],
+                                      pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Training Data Loader: {str(e)}. Setting Data Loader to None',
+                               prefix=PREFIX)
+        if 'validation_data_loader' in config:
+            try:
+                validation_data_loader = resolve(config['validation_data_loader'],
+                                                 pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Data Loader to None',
+                               prefix=PREFIX)
+        if 'dataset_reader' in config:
+            try:
+                dataset_reader = resolve(config['dataset_reader'],
+                                         pass_down_parameters=pass_down_parameters)
+            except Exception as e:
+                logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None',
+                               prefix=PREFIX)
+
     return Archive(model=model,
                    config=config,
                    data_loader=data_loader,
diff --git a/combo/modules/model.py b/combo/modules/model.py
index 5f2beb6..223bacc 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -349,10 +349,6 @@ class Model(Module, pl.LightningModule, FromParameters):
 
         # Load vocabulary from file
         vocab_dir = os.path.join(serialization_dir, "vocabulary")
-        if not os.path.exists(vocab_dir):
-            vocab_dir =os.path.join(serialization_dir, "model/vocabulary")
-        if not os.path.exists(vocab_dir):
-            raise ConfigurationError("Vocabulary not saved in " + serialization_dir + " or " + serialization_dir + "/model")
         # If the config specifies a vocabulary subclass, we need to use it.
         vocab_params = config.get("vocabulary")
         if vocab_params['type'] == 'from_files_vocabulary':
diff --git a/combo/utils/logging.py b/combo/utils/logging.py
index cd04b6a..2d8d99a 100644
--- a/combo/utils/logging.py
+++ b/combo/utils/logging.py
@@ -27,7 +27,7 @@ class ComboLogger(logging.Logger):
         self.log(level=logging.INFO, msg=msg, prefix=prefix)
 
     @overrides(check_signature=False)
-    def warn(self, msg: str, prefix: str = None):
+    def warning(self, msg: str, prefix: str = None):
         self.log(level=logging.WARN, msg=msg, prefix=prefix)
 
     @overrides(check_signature=False)
-- 
GitLab