From ad2a1908992bf8663d2641c418d09ee861e9a948 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Wed, 8 Nov 2023 22:35:59 +1100 Subject: [PATCH] Fixes to from_pretrained --- combo/modules/archival.py | 8 ++++---- combo/nn/regularizers/regularizer.py | 4 +--- combo/predict.py | 10 +++++----- combo/utils/download.py | 12 +++--------- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/combo/modules/archival.py b/combo/modules/archival.py index 29a6d70..1476939 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -10,9 +10,9 @@ import tarfile from io import BytesIO from tempfile import TemporaryDirectory -from config import resolve -from data.dataset_loaders import DataLoader -from modules.model import Model +from combo.config import resolve +from combo.data.dataset_loaders import DataLoader +from combo.modules.model import Model CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) @@ -75,7 +75,7 @@ def load_archive(url_or_filename: Union[PathLike, str], ) model = Model.load(archive_file, cuda_device=cuda_device) - with open(os.path.join(archive_file, 'model/config.json'), 'r') as f: + with open(os.path.join(archive_file, 'config.json'), 'r') as f: config = json.load(f) data_loader, validation_data_loader = None, None diff --git a/combo/nn/regularizers/regularizer.py b/combo/nn/regularizers/regularizer.py index 5e62437..b27dabf 100644 --- a/combo/nn/regularizers/regularizer.py +++ b/combo/nn/regularizers/regularizer.py @@ -7,9 +7,7 @@ from combo.config import FromParameters, Registry from combo.config.from_parameters import register_arguments, resolve from combo.nn.regularizers import Regularizer -from overrides import overrides - -from utils import ConfigurationError +from combo.utils import ConfigurationError @Registry.register('base_regularizer') diff --git a/combo/predict.py b/combo/predict.py index 9a3a42f..9a0818d 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -18,7 +18,9 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict from combo.predictors import PredictorModule from combo.utils import download, graph -from modules.model import Model +from combo.modules.model import Model +from combo.modules.archival import load_archive +from combo.default_model import default_ud_dataset_reader logger = logging.getLogger(__name__) @@ -262,9 +264,7 @@ class COMBO(PredictorModule): logger.error(e) raise e - archive = models.load_archive(model_path, cuda_device=cuda_device) + archive = load_archive(model_path, cuda_device=cuda_device) model = archive.model - dataset_reader = resolve( - archive.config["dataset_reader"] - ) + dataset_reader = default_ud_dataset_reader() return cls(model, dataset_reader, tokenizer, batch_size) diff --git a/combo/utils/download.py b/combo/utils/download.py index 5c7ce6f..ff5ed9b 100644 --- a/combo/utils/download.py +++ b/combo/utils/download.py @@ -9,21 +9,15 @@ from requests import adapters, exceptions logger = logging.getLogger(__name__) -DATA_TO_PATH = { - "enhanced" : "iwpt_2020", - "iwpt2021" : "iwpt_2021", - "ud25" : "ud_25", - "ud27" : "ud_27", - "ud29" : "ud_29"} -_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz" +_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz" _HOME_DIR = os.getenv("HOME", os.curdir) _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo")) def download_file(model_name, force=False): _make_cache_dir() - data = model_name.split("-")[-1] - url = _URL.format(model=model_name, data=DATA_TO_PATH[data]) + url = _URL.format(model=model_name) + print('URL', url) local_filename = url.split("/")[-1] location = os.path.join(_CACHE_DIR, local_filename) if os.path.exists(location) and not force: -- GitLab