diff --git a/combo/modules/archival.py b/combo/modules/archival.py index 27901f83f1c6f37bd79d77bf7f8ecc1da0056d40..ecf22f78517f465b056e44904d7122de331bbce1 100644 --- a/combo/modules/archival.py +++ b/combo/modules/archival.py @@ -10,9 +10,9 @@ import tarfile from io import BytesIO from tempfile import TemporaryDirectory -from config import resolve -from data.dataset_loaders import DataLoader -from modules.model import Model +from combo.config import resolve +from combo.data.dataset_loaders import DataLoader +from combo.modules.model import Model CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) diff --git a/combo/modules/model.py b/combo/modules/model.py index 97ccc88aa44586dc63a41244ca0ed5e99997eeb7..c87cc0ad71423d45e1b631d58c54d3df56ea235e 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -28,8 +28,8 @@ from combo.utils import ConfigurationError import pytorch_lightning as pl -from training import Scheduler -from training.optimizer import Adam +from combo.training import Scheduler +from combo.training.optimizer import Adam logger = logging.getLogger(__name__) diff --git a/combo/modules/text_field_embedders/basic_text_field_embedder.py b/combo/modules/text_field_embedders/basic_text_field_embedder.py index b6d174a6167e3c44809cbb6f7e889c38722697da..dea5f8accf965f9ce53137eb1e28a0b24338fb57 100644 --- a/combo/modules/text_field_embedders/basic_text_field_embedder.py +++ b/combo/modules/text_field_embedders/basic_text_field_embedder.py @@ -15,7 +15,7 @@ from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbe from combo.modules.token_embedders import EmptyEmbedder from combo.modules.token_embedders.token_embedder import TokenEmbedder from combo.utils import ConfigurationError -from models.base import TimeDistributed +from combo.models.base import TimeDistributed @Registry.register("base_text_field_embedder") diff --git a/combo/modules/token_embedders/token_embedder.py b/combo/modules/token_embedders/token_embedder.py index ec253ded5bef9d9ee643f0390f5b0a34499bf1ec..0417713e3348ab157443bc4429dff62c35f1a953 100644 --- a/combo/modules/token_embedders/token_embedder.py +++ b/combo/modules/token_embedders/token_embedder.py @@ -12,7 +12,7 @@ from combo.data import Vocabulary from combo.nn.utils import tiny_value_of_dtype, uncombine_initial_dims, combine_initial_dims from combo.modules.module import Module from combo.utils import ConfigurationError -from models.base import TimeDistributed +from combo.models.base import TimeDistributed class TokenEmbedder(Module, FromParameters): diff --git a/combo/nn/regularizers/regularizer.py b/combo/nn/regularizers/regularizer.py index a3955bdbb1abbea4b0ae780002c9c8ba0295437f..d28853a870dc7a0ec7045bc83a97aee771fa349a 100644 --- a/combo/nn/regularizers/regularizer.py +++ b/combo/nn/regularizers/regularizer.py @@ -6,10 +6,7 @@ import torch from combo.config import FromParameters, Registry from combo.config.from_parameters import register_arguments, resolve from combo.nn.regularizers import Regularizer - -from overrides import overrides - -from utils import ConfigurationError +from combo.utils.checks import ConfigurationError @Registry.register('base_regularizer') diff --git a/combo/predict.py b/combo/predict.py index 1140750ed07b62a8a690292077d38b0daeadbab1..c97b26a9e78bb31e4fa6d16b1de358815e7344da 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -18,8 +18,8 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict from combo.predictors import PredictorModule from combo.utils import download, graph -from modules.archival import load_archive -from modules.model import Model +from combo.modules.archival import load_archive +from combo.modules.model import Model logger = logging.getLogger(__name__) diff --git a/combo/utils/download.py b/combo/utils/download.py index 5c7ce6f951147e48f7ac793ce7b4e59817da9c89..ff5ed9b5e8ef9424a2dbfb64c0ac90985e4a3d98 100644 --- a/combo/utils/download.py +++ b/combo/utils/download.py @@ -9,21 +9,15 @@ from requests import adapters, exceptions logger = logging.getLogger(__name__) -DATA_TO_PATH = { - "enhanced" : "iwpt_2020", - "iwpt2021" : "iwpt_2021", - "ud25" : "ud_25", - "ud27" : "ud_27", - "ud29" : "ud_29"} -_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz" +_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz" _HOME_DIR = os.getenv("HOME", os.curdir) _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo")) def download_file(model_name, force=False): _make_cache_dir() - data = model_name.split("-")[-1] - url = _URL.format(model=model_name, data=DATA_TO_PATH[data]) + url = _URL.format(model=model_name) + print('URL', url) local_filename = url.split("/")[-1] location = os.path.join(_CACHE_DIR, local_filename) if os.path.exists(location) and not force: