Skip to content
Snippets Groups Projects
Select Git revision
  • 0521c672de29d2c29fb4e2b3bfa03ce6a0b02931
  • main default protected
  • ud_training_script
  • fix_seed
  • merged-with-ner
  • multiword_fix_transformer
  • transformer_encoder
  • combo3
  • save_deprel_matrix_to_npz
  • master protected
  • combo-lambo
  • lambo-sent-attributes
  • adding_lambo
  • develop
  • update_allenlp2
  • develop_tmp
  • tokens_truncation
  • LR_test
  • eud_iwpt
  • iob
  • eud_iwpt_shared_task_bert_finetuning
  • 3.3.1
  • list
  • 3.2.1
  • 3.0.3
  • 3.0.1
  • 3.0.0
  • v1.0.6
  • v1.0.5
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.1
  • v1.0.0
34 results

download.py

Blame
  • download.py 2.38 KiB
    import errno
    import logging
    import os
    
    import requests
    import tqdm
    import urllib3
    from requests import adapters, exceptions
    
    logger = logging.getLogger(__name__)
    
    # Maps the data-set suffix of a model name (text after the last "-")
    # to the remote directory component used in _URL.
    DATA_TO_PATH = {
        "enhanced" : "iwpt_2020",
        "iwpt2021" : "iwpt_2021",
        "ud25" : "ud_25",
        "ud27" : "ud_27",
        "ud29" : "ud_29"}
    # Remote archive location template; filled with the DATA_TO_PATH value and model name.
    _URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
    # Fall back to the current directory when HOME is unset (e.g. some containers).
    _HOME_DIR = os.getenv("HOME", os.curdir)
    # Download cache directory; overridable via the COMBO_DIR environment variable.
    _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
    
    
    def download_file(model_name, force=False):
        """Download ``<model_name>.tar.gz`` into the cache dir and return its local path.

        The data-set suffix of *model_name* (text after the last ``-``) selects the
        remote directory via ``DATA_TO_PATH``. An already-downloaded archive is
        reused unless *force* is true.

        Raises:
            ConnectionError: when retries are exhausted (wrong model name or the
                server is unavailable).
            requests.HTTPError: on an error status not covered by the retry policy.
            KeyError: when the model name's suffix is not in ``DATA_TO_PATH``.
        """
        _make_cache_dir()
        data = model_name.split("-")[-1]
        url = _URL.format(model=model_name, data=DATA_TO_PATH[data])
        local_filename = url.split("/")[-1]
        location = os.path.join(_CACHE_DIR, local_filename)
        if os.path.exists(location) and not force:
            logger.debug("Using cached model.")
            return location
        chunk_size = 1024
        logger.info(url)
        # Stream into a temporary file first so an interrupted download never
        # leaves a truncated archive that later calls would treat as cached.
        tmp_location = location + ".part"
        try:
            with _requests_retry_session(retries=2).get(url, stream=True) as r:
                # Fail fast on error statuses the retry policy does not cover
                # (e.g. 403); otherwise the error page would be saved as the model.
                r.raise_for_status()
                # Content-Length may be absent; tqdm accepts total=None (unknown).
                total = r.headers.get("content-length")
                pbar = tqdm.tqdm(unit="B", total=int(total) if total else None,
                                 unit_divisor=chunk_size, unit_scale=True)
                with open(tmp_location, "wb") as f, pbar:
                    for chunk in r.iter_content(chunk_size):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
        except exceptions.RetryError as e:
            raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. "
                                  "Check if model name is correct or try again later!") from e
        # Atomically publish the completed download under the cached name.
        os.replace(tmp_location, location)
        return location
    
    
    def _make_cache_dir():
        """Create the model cache directory if it does not exist yet.

        Idempotent: an existing directory is fine; other OS errors (e.g.
        permission denied, or a regular file occupying the path) propagate.
        """
        if not os.path.exists(_CACHE_DIR):
            # exist_ok guards against a concurrent creation between the check
            # and the call; it replaces the legacy errno.EEXIST dance.
            os.makedirs(_CACHE_DIR, exist_ok=True)
            logger.info(f"Making cache dir {_CACHE_DIR}")
    
    
    def _requests_retry_session(
        retries=3,
        backoff_factor=0.3,
        status_forcelist=(404, 500, 502, 504),
        session=None,
    ):
        """Return a ``requests`` session configured to retry failed requests.

        Retries up to *retries* times on connect/read failures and on the HTTP
        statuses in *status_forcelist*, with exponential backoff scaled by
        *backoff_factor*. A fresh session is created when *session* is None.

        Source: https://www.peterbe.com/plog/best-practice-with-retries-with-requests
        """
        if session is None:
            session = requests.Session()
        retry_policy = urllib3.Retry(
            total=retries,
            connect=retries,
            read=retries,
            backoff_factor=backoff_factor,
            status_forcelist=status_forcelist,
        )
        http_adapter = adapters.HTTPAdapter(max_retries=retry_policy)
        for scheme in ("http://", "https://"):
            session.mount(scheme, http_adapter)
        return session