Skip to content
Snippets Groups Projects
Commit df3ec793 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add module for downloading pretrained models from mozart.

parent 3e11893a
No related merge requests found
......@@ -5,6 +5,11 @@ Clone this repository and run:
python setup.py develop
```
### Problems & solutions
* **jsonnet** fails to install via pip: install it from conda-forge instead,
using `conda install -c conda-forge jsonnet=0.15.0`
## Training
Command:
......
import collections
import errno
import logging
import os
import time
from typing import List
import conllu
import requests
import tqdm
from allennlp import data as allen_data, common, models
from allennlp.common import util
from allennlp.data import tokenizers
......@@ -11,6 +15,7 @@ from allennlp.predictors import predictor
from overrides import overrides
from combo import data
from combo.utils import download
logger = logging.getLogger(__name__)
......@@ -154,7 +159,17 @@ class SemanticMultitaskPredictor(predictor.Predictor):
util.import_module_and_submodules("combo.commands")
util.import_module_and_submodules("combo.models")
util.import_module_and_submodules("combo.training")
model = models.Model.from_archive(path)
if os.path.exists(path):
model_path = path
else:
try:
model_path = download.download_file(path)
except Exception as e:
logger.error(e)
raise e
model = models.Model.from_archive(model_path)
dataset_reader = allen_data.DatasetReader.from_params(
models.load_archive(path).config["dataset_reader"])
models.load_archive(model_path).config["dataset_reader"])
return cls(model, dataset_reader, tokenizer)
import errno
import logging
import math
import os
import requests
import tqdm
import urllib3
from requests import adapters, exceptions
logger = logging.getLogger(__name__)
# URL template for pretrained model archives; `{name}` is the model name
# passed to `download_file`.
_URL = "http://mozart.ipipan.waw.pl/~mklimaszewski/models/{name}.tar.gz"
# Local cache root: $COMBO_DIR when set, otherwise "$HOME/.combo" (falls back
# to the current directory when $HOME is unset, e.g. on Windows).
_HOME_DIR = os.getenv("HOME", os.curdir)
_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
def download_file(model_name, force=False):
    """Download a pretrained model archive into the local cache.

    Builds the archive URL from ``_URL`` and ``model_name``, streams it into
    ``_CACHE_DIR`` and returns the local path. A previously downloaded copy
    is reused unless ``force`` is set.

    Args:
        model_name: Name of the model; substituted into ``_URL`` as
            ``{name}.tar.gz``.
        force: When True, re-download even if a cached copy exists.

    Returns:
        Path to the archive inside ``_CACHE_DIR``.

    Raises:
        ConnectionError: When the archive cannot be found or downloaded
            after the configured retries.
    """
    _make_cache_dir()
    url = _URL.format(name=model_name)
    local_filename = url.split("/")[-1]
    location = os.path.join(_CACHE_DIR, local_filename)
    if os.path.exists(location) and not force:
        logger.debug("Using cached model.")
        return location
    chunk_size = 1024 * 10
    logger.info(url)
    # Stream into a temporary file first so an interrupted transfer never
    # leaves a truncated archive at `location` (which would otherwise be
    # served from cache on every later call).
    tmp_location = location + ".part"
    try:
        with _requests_retry_session(retries=2).get(url, stream=True) as r:
            # 404/5xx are retried by the session; anything else that still
            # signals failure (e.g. 403) must not be written to the cache.
            r.raise_for_status()
            # content-length may be absent (chunked transfer); fall back to
            # an indeterminate progress bar instead of crashing on int(None).
            content_length = r.headers.get("content-length")
            total_length = (math.ceil(int(content_length) / chunk_size)
                            if content_length else None)
            with open(tmp_location, "wb") as f:
                with tqdm.tqdm(total=total_length) as pbar:
                    # decode_content=False keeps the raw (compressed) bytes so
                    # the on-disk size matches content-length.
                    for chunk in r.raw.stream(chunk_size, decode_content=False):
                        if chunk:
                            f.write(chunk)
                            f.flush()
                            pbar.update(1)
        # Atomic publish: the final cache path only ever holds complete files.
        os.replace(tmp_location, location)
    except exceptions.RetryError:
        raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. "
                              "Check if model name is correct or try again later!")
    finally:
        # Best-effort cleanup of a leftover partial download.
        if os.path.exists(tmp_location):
            try:
                os.remove(tmp_location)
            except OSError:
                pass
    return location
def _make_cache_dir():
    """Create the model cache directory ``_CACHE_DIR`` if it does not exist.

    Silently ignores an already-existing directory; any other OSError
    (e.g. permission denied) propagates to the caller.
    """
    try:
        os.makedirs(_CACHE_DIR)
        logger.info(f"Making cache dir {_CACHE_DIR}")
    except FileExistsError:
        # Python 3 replacement for the `errno.EEXIST` check: precise and
        # race-free, and only logs when the directory was actually created.
        pass
def _requests_retry_session(
        retries=3,
        backoff_factor=0.3,
        status_forcelist=(404, 500, 502, 504),
        session=None,
):
    """Build a ``requests.Session`` that retries failed HTTP(S) requests.

    Args:
        retries: Maximum number of total/read/connect retries.
        backoff_factor: urllib3 exponential backoff factor between attempts.
        status_forcelist: HTTP status codes that trigger a retry
            (404 included so a not-yet-published model is retried too).
        session: Optional existing session to configure; a fresh one is
            created when omitted.

    Returns:
        The configured session.
    """
    if session is None:
        session = requests.Session()
    retry_policy = urllib3.Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    http_adapter = adapters.HTTPAdapter(max_retries=retry_policy)
    for scheme in ("http://", "https://"):
        session.mount(scheme, http_adapter)
    return session
......@@ -7,10 +7,12 @@ REQUIREMENTS = [
'conllu==2.3.2',
'joblib==0.14.1',
'jsonnet==0.15.0',
'requests==2.23.0',
'overrides==3.0.0',
'tensorboard==2.1.0',
'torch==1.5.0',
'torchvision==0.6.0',
'tqdm==4.43.0',
'transformers==2.9.1',
]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment