diff --git a/combo/utils/download.py b/combo/utils/download.py index b2c6b2e32579dc91e2adb50298e2b71bf3936536..003b64c1101aa6f82375c1b600851d1642e7a616 100644 --- a/combo/utils/download.py +++ b/combo/utils/download.py @@ -1,6 +1,5 @@ import errno import logging -import math import os import requests @@ -23,18 +22,18 @@ def download_file(model_name, force=False): if os.path.exists(location) and not force: logger.debug("Using cached model.") return location - chunk_size = 1024 * 10 + chunk_size = 1024 logger.info(url) try: with _requests_retry_session(retries=2).get(url, stream=True) as r: - total_length = math.ceil(int(r.headers.get("content-length")) / chunk_size) + pbar = tqdm.tqdm(unit="B", total=int(r.headers.get("content-length")), + unit_divisor=chunk_size, unit_scale=True) with open(location, "wb") as f: - with tqdm.tqdm(total=total_length) as pbar: - for chunk in r.raw.stream(chunk_size, decode_content=False): + with pbar: + for chunk in r.iter_content(chunk_size): if chunk: f.write(chunk) - f.flush() - pbar.update(1) + pbar.update(len(chunk)) except exceptions.RetryError: raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. " "Check if model name is correct or try again later!")