Skip to content
Snippets Groups Projects
Commit ad2a1908 authored by Maja Jablonska
Browse files

Fixes to from_pretrained

parent 4a2f945f
No related branches found
No related tags found
1 merge request: !46 "Merge COMBO 3.0 into master"
...@@ -10,9 +10,9 @@ import tarfile ...@@ -10,9 +10,9 @@ import tarfile
from io import BytesIO from io import BytesIO
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from config import resolve from combo.config import resolve
from data.dataset_loaders import DataLoader from combo.data.dataset_loaders import DataLoader
from modules.model import Model from combo.modules.model import Model
CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo"))
...@@ -75,7 +75,7 @@ def load_archive(url_or_filename: Union[PathLike, str], ...@@ -75,7 +75,7 @@ def load_archive(url_or_filename: Union[PathLike, str],
) )
model = Model.load(archive_file, cuda_device=cuda_device) model = Model.load(archive_file, cuda_device=cuda_device)
with open(os.path.join(archive_file, 'model/config.json'), 'r') as f: with open(os.path.join(archive_file, 'config.json'), 'r') as f:
config = json.load(f) config = json.load(f)
data_loader, validation_data_loader = None, None data_loader, validation_data_loader = None, None
......
...@@ -7,9 +7,7 @@ from combo.config import FromParameters, Registry ...@@ -7,9 +7,7 @@ from combo.config import FromParameters, Registry
from combo.config.from_parameters import register_arguments, resolve from combo.config.from_parameters import register_arguments, resolve
from combo.nn.regularizers import Regularizer from combo.nn.regularizers import Regularizer
from overrides import overrides from combo.utils import ConfigurationError
from utils import ConfigurationError
@Registry.register('base_regularizer') @Registry.register('base_regularizer')
......
...@@ -18,7 +18,9 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader ...@@ -18,7 +18,9 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict from combo.data.instance import JsonDict
from combo.predictors import PredictorModule from combo.predictors import PredictorModule
from combo.utils import download, graph from combo.utils import download, graph
from modules.model import Model from combo.modules.model import Model
from combo.modules.archival import load_archive
from combo.default_model import default_ud_dataset_reader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -262,9 +264,7 @@ class COMBO(PredictorModule): ...@@ -262,9 +264,7 @@ class COMBO(PredictorModule):
logger.error(e) logger.error(e)
raise e raise e
archive = models.load_archive(model_path, cuda_device=cuda_device) archive = load_archive(model_path, cuda_device=cuda_device)
model = archive.model model = archive.model
dataset_reader = resolve( dataset_reader = default_ud_dataset_reader()
archive.config["dataset_reader"]
)
return cls(model, dataset_reader, tokenizer, batch_size) return cls(model, dataset_reader, tokenizer, batch_size)
...@@ -9,21 +9,15 @@ from requests import adapters, exceptions ...@@ -9,21 +9,15 @@ from requests import adapters, exceptions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DATA_TO_PATH = { _URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz"
"enhanced" : "iwpt_2020",
"iwpt2021" : "iwpt_2021",
"ud25" : "ud_25",
"ud27" : "ud_27",
"ud29" : "ud_29"}
_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
_HOME_DIR = os.getenv("HOME", os.curdir) _HOME_DIR = os.getenv("HOME", os.curdir)
_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo")) _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
def download_file(model_name, force=False): def download_file(model_name, force=False):
_make_cache_dir() _make_cache_dir()
data = model_name.split("-")[-1] url = _URL.format(model=model_name)
url = _URL.format(model=model_name, data=DATA_TO_PATH[data]) print('URL', url)
local_filename = url.split("/")[-1] local_filename = url.split("/")[-1]
location = os.path.join(_CACHE_DIR, local_filename) location = os.path.join(_CACHE_DIR, local_filename)
if os.path.exists(location) and not force: if os.path.exists(location) and not force:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment