diff --git a/combo/predict.py b/combo/predict.py index 72e97f671ef2a999fee08a438722db954fc4bf0d..42c55a220eb7b5e7eac44b8a462c4b764dd04406 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -7,7 +7,7 @@ import numpy as np import torch from overrides import overrides -from combo import data, models +from combo import data from combo.common import util from combo.config import Registry from combo.config.from_parameters import register_arguments @@ -16,9 +16,10 @@ from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict from combo.default_model import default_ud_dataset_reader +from combo.modules.archival import load_archive from combo.predictors import PredictorModule from combo.utils import download, graph -from modules.model import Model +from combo.modules.model import Model logger = logging.getLogger(__name__) @@ -244,7 +245,7 @@ class COMBO(PredictorModule): return tree, predictions["sentence_embedding"], embeddings @classmethod - def with_spacy_tokenizer(cls, model: models.Model, + def with_spacy_tokenizer(cls, model: Model, dataset_reader: DatasetReader): return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) @@ -262,7 +263,7 @@ class COMBO(PredictorModule): logger.error(e) raise e - archive = models.load_archive(model_path, cuda_device=cuda_device) + archive = load_archive(model_path, cuda_device=cuda_device) model = archive.model dataset_reader = default_ud_dataset_reader() return cls(model, dataset_reader, tokenizer, batch_size) diff --git a/combo/utils/download.py b/combo/utils/download.py index 53fd7413dba8d91c86bb0fa5094c90015120c18b..7a928d2f81387d91945f6a82fa2d058ec195154a 100644 --- a/combo/utils/download.py +++ b/combo/utils/download.py @@ -10,12 +10,9 @@ from requests import adapters, exceptions logger = logging.getLogger(__name__) DATA_TO_PATH = { - "enhanced" : "iwpt_2020", - "iwpt2021" : "iwpt_2021", - "ud25" : "ud_25", - "ud27" : "ud_27", - "ud29" : "ud_29"} -_PROTOTYPE_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz" + "prototype": "prototype", + "ud213": "ud_213" +} _URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz" _HOME_DIR = os.getenv("HOME", os.curdir) _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo")) @@ -59,10 +56,10 @@ def _make_cache_dir(): def _requests_retry_session( - retries=3, - backoff_factor=0.3, - status_forcelist=(404, 500, 502, 504), - session=None, + retries=3, + backoff_factor=0.3, + status_forcelist=(404, 500, 502, 504), + session=None, ): """Source: https://www.peterbe.com/plog/best-practice-with-retries-with-requests""" session = session or requests.Session()