Skip to content
Snippets Groups Projects
Commit a7532106 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Fix from_pretrained path

parent d7e1399f
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
......@@ -10,9 +10,9 @@ import tarfile
from io import BytesIO
from tempfile import TemporaryDirectory
from config import resolve
from data.dataset_loaders import DataLoader
from modules.model import Model
from combo.config import resolve
from combo.data.dataset_loaders import DataLoader
from combo.modules.model import Model
CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
......
......@@ -28,8 +28,8 @@ from combo.utils import ConfigurationError
import pytorch_lightning as pl
from training import Scheduler
from training.optimizer import Adam
from combo.training import Scheduler
from combo.training.optimizer import Adam
logger = logging.getLogger(__name__)
......
......@@ -15,7 +15,7 @@ from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbe
from combo.modules.token_embedders import EmptyEmbedder
from combo.modules.token_embedders.token_embedder import TokenEmbedder
from combo.utils import ConfigurationError
from models.base import TimeDistributed
from combo.models.base import TimeDistributed
@Registry.register("base_text_field_embedder")
......
......@@ -12,7 +12,7 @@ from combo.data import Vocabulary
from combo.nn.utils import tiny_value_of_dtype, uncombine_initial_dims, combine_initial_dims
from combo.modules.module import Module
from combo.utils import ConfigurationError
from models.base import TimeDistributed
from combo.models.base import TimeDistributed
class TokenEmbedder(Module, FromParameters):
......
......@@ -6,10 +6,7 @@ import torch
from combo.config import FromParameters, Registry
from combo.config.from_parameters import register_arguments, resolve
from combo.nn.regularizers import Regularizer
from overrides import overrides
from utils import ConfigurationError
from combo.utils.checks import ConfigurationError
@Registry.register('base_regularizer')
......
......@@ -18,8 +18,8 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict
from combo.predictors import PredictorModule
from combo.utils import download, graph
from modules.archival import load_archive
from modules.model import Model
from combo.modules.archival import load_archive
from combo.modules.model import Model
logger = logging.getLogger(__name__)
......
......@@ -9,21 +9,15 @@ from requests import adapters, exceptions
logger = logging.getLogger(__name__)
DATA_TO_PATH = {
"enhanced" : "iwpt_2020",
"iwpt2021" : "iwpt_2021",
"ud25" : "ud_25",
"ud27" : "ud_27",
"ud29" : "ud_29"}
_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz"
_HOME_DIR = os.getenv("HOME", os.curdir)
_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
def download_file(model_name, force=False):
_make_cache_dir()
data = model_name.split("-")[-1]
url = _URL.format(model=model_name, data=DATA_TO_PATH[data])
url = _URL.format(model=model_name)
print('URL', url)
local_filename = url.split("/")[-1]
location = os.path.join(_CACHE_DIR, local_filename)
if os.path.exists(location) and not force:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment