From ad2a1908992bf8663d2641c418d09ee861e9a948 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 8 Nov 2023 22:35:59 +1100
Subject: [PATCH] Fixes to from_pretrained

---
 combo/modules/archival.py            |  8 ++++----
 combo/nn/regularizers/regularizer.py |  4 +---
 combo/predict.py                     | 10 +++++-----
 combo/utils/download.py              | 12 +++---------
 4 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index 29a6d70..1476939 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -10,9 +10,9 @@ import tarfile
 from io import BytesIO
 from tempfile import TemporaryDirectory
 
-from config import resolve
-from data.dataset_loaders import DataLoader
-from modules.model import Model
+from combo.config import resolve
+from combo.data.dataset_loaders import DataLoader
+from combo.modules.model import Model
 
 
 CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
@@ -75,7 +75,7 @@ def load_archive(url_or_filename: Union[PathLike, str],
         )
     model = Model.load(archive_file, cuda_device=cuda_device)
 
-    with open(os.path.join(archive_file, 'model/config.json'), 'r') as f:
+    with open(os.path.join(archive_file, 'config.json'), 'r') as f:
         config = json.load(f)
 
     data_loader, validation_data_loader = None, None
diff --git a/combo/nn/regularizers/regularizer.py b/combo/nn/regularizers/regularizer.py
index 5e62437..b27dabf 100644
--- a/combo/nn/regularizers/regularizer.py
+++ b/combo/nn/regularizers/regularizer.py
@@ -7,9 +7,7 @@ from combo.config import FromParameters, Registry
 from combo.config.from_parameters import register_arguments, resolve
 from combo.nn.regularizers import Regularizer
 
-from overrides import overrides
-
-from utils import ConfigurationError
+from combo.utils import ConfigurationError
 
 
 @Registry.register('base_regularizer')
diff --git a/combo/predict.py b/combo/predict.py
index 9a3a42f..9a0818d 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -18,7 +18,9 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict
 from combo.predictors import PredictorModule
 from combo.utils import download, graph
-from modules.model import Model
+from combo.modules.model import Model
+from combo.modules.archival import load_archive
+from combo.default_model import default_ud_dataset_reader
 
 logger = logging.getLogger(__name__)
 
@@ -262,9 +264,7 @@ class COMBO(PredictorModule):
                 logger.error(e)
                 raise e
 
-        archive = models.load_archive(model_path, cuda_device=cuda_device)
+        archive = load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = resolve(
-            archive.config["dataset_reader"]
-        )
+        dataset_reader = default_ud_dataset_reader()
         return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/combo/utils/download.py b/combo/utils/download.py
index 5c7ce6f..ff5ed9b 100644
--- a/combo/utils/download.py
+++ b/combo/utils/download.py
@@ -9,21 +9,15 @@ from requests import adapters, exceptions
 
 logger = logging.getLogger(__name__)
 
-DATA_TO_PATH = {
-    "enhanced" : "iwpt_2020",
-    "iwpt2021" : "iwpt_2021",
-    "ud25" : "ud_25",
-    "ud27" : "ud_27",
-    "ud29" : "ud_29"}
-_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
+_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz"
 _HOME_DIR = os.getenv("HOME", os.curdir)
 _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
 
 
 def download_file(model_name, force=False):
     _make_cache_dir()
-    data = model_name.split("-")[-1]
-    url = _URL.format(model=model_name, data=DATA_TO_PATH[data])
+    url = _URL.format(model=model_name)
+    print('URL', url)
     local_filename = url.split("/")[-1]
     location = os.path.join(_CACHE_DIR, local_filename)
     if os.path.exists(location) and not force:
-- 
GitLab