diff --git a/README.md b/README.md
index d67918d9c7f8e156835f9b61873599f03f733a0d..0ef57069216a10ae8e4392c4f201ac2cc7883910 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,11 @@ Clone this repository and run:
 python setup.py develop
 ```
 
+### Problems & solutions
+* **jsonnet** installation error
+
+use `conda install -c conda-forge jsonnet=0.15.0`
+
 ## Training
 
 Command:
diff --git a/combo/predict.py b/combo/predict.py
index c25bb755409165b62dcd1866c558efd33e9dad23..edcb81787084733f2a769b6dc33a2224ff85af5b 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,9 +1,13 @@
 import collections
+import errno
 import logging
+import os
 import time
 from typing import List
 
 import conllu
+import requests
+import tqdm
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -11,6 +15,7 @@ from allennlp.predictors import predictor
 from overrides import overrides
 
 from combo import data
+from combo.utils import download
 
 logger = logging.getLogger(__name__)
 
@@ -154,7 +159,17 @@
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
-        model = models.Model.from_archive(path)
+
+        if os.path.exists(path):
+            model_path = path
+        else:
+            try:
+                model_path = download.download_file(path)
+            except Exception as e:
+                logger.error(e)
+                raise e
+
+        model = models.Model.from_archive(model_path)
         dataset_reader = allen_data.DatasetReader.from_params(
-            models.load_archive(path).config["dataset_reader"])
+            models.load_archive(model_path).config["dataset_reader"])
         return cls(model, dataset_reader, tokenizer)
diff --git a/combo/utils/download.py b/combo/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..d464dbd62722fd505bee9d9b643b65368981b5a4
--- /dev/null
+++ b/combo/utils/download.py
@@ -0,0 +1,71 @@
+import errno
+import logging
+import math
+import os
+
+import requests
+import tqdm
+import urllib3
+from requests import adapters, exceptions
+
+logger = logging.getLogger(__name__)
+
+_URL = "http://mozart.ipipan.waw.pl/~mklimaszewski/models/{name}.tar.gz"
+_HOME_DIR = os.getenv("HOME", os.curdir)
+_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
+
+
+def download_file(model_name, force=False):
+    _make_cache_dir()
+    url = _URL.format(name=model_name)
+    local_filename = url.split("/")[-1]
+    location = os.path.join(_CACHE_DIR, local_filename)
+    if os.path.exists(location) and not force:
+        logger.debug("Using cached model.")
+        return location
+    chunk_size = 1024 * 10
+    logger.info(url)
+    try:
+        with _requests_retry_session(retries=2).get(url, stream=True) as r:
+            total_length = math.ceil(int(r.headers.get("content-length")) / chunk_size)
+            with open(location, "wb") as f:
+                with tqdm.tqdm(total=total_length) as pbar:
+                    for chunk in r.raw.stream(chunk_size, decode_content=False):
+                        if chunk:
+                            f.write(chunk)
+                            f.flush()
+                            pbar.update(1)
+    except exceptions.RetryError:
+        raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. "
+                              "Check if model name is correct or try again later!")
+
+    return location
+
+
+def _make_cache_dir():
+    try:
+        os.makedirs(_CACHE_DIR)
+        logger.info(f"Making cache dir {_CACHE_DIR}")
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def _requests_retry_session(
+        retries=3,
+        backoff_factor=0.3,
+        status_forcelist=(404, 500, 502, 504),
+        session=None,
+):
+    session = session or requests.Session()
+    retry = urllib3.Retry(
+        total=retries,
+        read=retries,
+        connect=retries,
+        backoff_factor=backoff_factor,
+        status_forcelist=status_forcelist,
+    )
+    adapter = adapters.HTTPAdapter(max_retries=retry)
+    session.mount('http://', adapter)
+    session.mount('https://', adapter)
+    return session
diff --git a/setup.py b/setup.py
index 44a26199d0f4a53c7008f14cc00e8e07aeb140a9..695e9cf0f70db498480ff8956d97c9a567eba9e7 100644
--- a/setup.py
+++ b/setup.py
@@ -7,10 +7,12 @@ REQUIREMENTS = [
     'conllu==2.3.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
+    'requests==2.23.0',
     'overrides==3.0.0',
     'tensorboard==2.1.0',
     'torch==1.5.0',
     'torchvision==0.6.0',
+    'tqdm==4.43.0',
     'transformers==2.9.1',
 ]
 