Source code for deepSTRF.utils.data_download

"""Auto-download utilities for deepSTRF datasets.

Public surface:
    - ``default_cache_dir(dataset_name) -> Path``       — platformdirs-based default
    - ``stream_download(url, dest_path)``               — atomic streaming download (``.part`` swap, no resume)
    - ``unzip(zip_path, dest_dir, strip_root=)``        — unzip with overwrite, optional root stripping
    - ``untar(tar_path, dest_dir, strip_components=)``  — tar.gz / .tar / .tar.bz2 unpack
    - ``osf_download(guid, dest)``                      — public OSF storage files
    - ``github_raw_download(repo, path, dest, ref=)``   — public GitHub raw files
    - ``zenodo_download(record_id, filename, dest)``    — public Zenodo records
    - ``figshare_download(article_id, dest_dir, filename=)`` — public figshare articles
    - ``crcns_download(file_path, dest, username=, password=)`` — CRCNS (free account)
"""

from __future__ import annotations

import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import Optional, Union

# requests + platformdirs are runtime deps (cf. pyproject.toml). Imported here
# so this module can be used standalone outside dataset constructors too.
import requests
from platformdirs import user_cache_dir


def default_cache_dir(dataset_name: str) -> Path:
    """Locate the cache directory used for ``dataset_name``.

    Resolution order:

    1. ``$DEEPSTRF_DATA_DIR/<dataset_name>`` when the environment variable is
       set to a non-empty value (a leading ``~`` is expanded).
    2. ``platformdirs.user_cache_dir('deepSTRF') / <dataset_name>`` otherwise.

    The environment override is the escape hatch for users on shared storage
    or scratch filesystems. The directory itself is not created here.
    """
    override = os.environ.get("DEEPSTRF_DATA_DIR")
    root = Path(override).expanduser() if override else Path(user_cache_dir("deepSTRF"))
    return root / dataset_name
def stream_download(
    url: str,
    dest_path: Union[str, Path],
    *,
    chunk_size: int = 1 << 20,
    progress: bool = True,
) -> Path:
    """Stream-download ``url`` to ``dest_path``. Atomic via a ``.part`` swap.

    Already-existing destination paths are returned unchanged (no-op) — the
    caller is responsible for cache invalidation. The payload is first written
    to ``<dest>.part`` and only renamed onto ``dest`` once the whole body has
    been received; if the transfer fails or is interrupted, the partial file
    is removed and the exception propagates.

    Parameters
    ----------
    url : str
    dest_path : str | Path
    chunk_size : int, default 1 MiB
    progress : bool, default True
        Show a tqdm progress bar if available; falls back silently otherwise.

    Returns
    -------
    Path
        The destination path (resolved).

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    """
    dest = Path(dest_path).expanduser().resolve()
    if dest.exists():
        return dest
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_suffix(dest.suffix + ".part")

    # (connect, read): 30s to establish, then 10 min between recvs. The read
    # timeout has to be generous because slow upstream mirrors (CRCNS / NERSC,
    # OSF on a busy day) can pause for tens of seconds between chunks while
    # the file is fetched from cold storage.
    with requests.get(url, stream=True, timeout=(30, 600), allow_redirects=True) as resp:
        resp.raise_for_status()
        total = int(resp.headers.get("Content-Length") or 0)

        bar = None
        if progress:
            try:
                from tqdm.auto import tqdm

                bar = tqdm(total=total or None, unit="B", unit_scale=True,
                           desc=f"download {dest.name}")
            except ImportError:
                bar = None

        try:
            with open(tmp, "wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    if bar is not None:
                        bar.update(len(chunk))
        except BaseException:
            # Network error, disk error or Ctrl-C: do not leave a truncated
            # ``.part`` file lying around next to the (missing) destination.
            if tmp.exists():
                tmp.unlink()
            raise
        finally:
            if bar is not None:
                bar.close()

        # Only reached after a complete transfer; the rename is atomic on the
        # same filesystem, so ``dest`` never exists in a half-written state.
        tmp.replace(dest)
    return dest
def osf_download(file_guid: str, dest_path: Union[str, Path], **kwargs) -> Path:
    """Fetch one public OSF storage file, addressed by its short GUID.

    The GUID is turned into ``https://osf.io/download/<guid>/`` (the URL the
    OSF API exposes as the ``download`` link in each file's metadata) and the
    transfer is delegated to ``stream_download``; any extra keyword arguments
    are forwarded to it untouched.

    Example
    -------
    >>> osf_download("gdwyd", "MetadataSHEnCneurons.mat")
    """
    download_url = f"https://osf.io/download/{file_guid}/"
    return stream_download(download_url, dest_path, **kwargs)
def crcns_download(
    file_path: str,
    dest_path: Union[str, Path],
    *,
    username: Optional[str] = None,
    password: Optional[str] = None,
    chunk_size: int = 1 << 20,
    progress: bool = True,
) -> Path:
    """Download a single file from the CRCNS NERSC mirror with form auth.

    The CRCNS download portal at
    ``https://portal.nersc.gov/project/crcns/download/<file_path>`` serves an
    HTML login form to anonymous GETs. To actually fetch the file, the form
    must be POSTed to the same URL with ``username`` / ``password`` / ``fn`` /
    ``submit`` fields. There is no persistent session cookie — auth is
    per-request, so the same pattern works equally well whether you fetch one
    file or many.

    As with ``stream_download``, an already-existing ``dest_path`` is returned
    unchanged. That check happens *before* credentials are looked up, so a
    cached file can be reused without a CRCNS account being configured.

    Parameters
    ----------
    file_path : str
        Path under ``/download/``, e.g. ``"aa-1/crcns-aa1.zip"`` or
        ``"aa-4/BlaBro09xxF.tar.gz"``.
    dest_path : path-like
    username, password : str, optional
        Default to ``$CRCNS_USERNAME`` / ``$CRCNS_PASSWORD``. Account is free
        at https://crcns.org/register.
    chunk_size, progress
        As ``stream_download``.

    Returns
    -------
    Path
        Resolved destination.

    Raises
    ------
    RuntimeError
        If the file is not cached and credentials are missing, or if the
        response body still looks like the login form (auth failed silently —
        the portal returns 200 OK with the login HTML rather than 401 on bad
        credentials).

    Notes
    -----
    Status: experimental. The auth + URL conventions were reverse-engineered
    from probing the public NERSC mirror; we do not have a contract from CRCNS
    that they'll stay stable. If the portal layout changes, this helper
    breaks. Verified 2026-04-25 against the AA1 archive
    (``aa-1/crcns-aa1.zip``).

    Example
    -------
    >>> import os
    >>> os.environ["CRCNS_USERNAME"] = "..."
    >>> os.environ["CRCNS_PASSWORD"] = "..."
    >>> crcns_download("aa-1/crcns-aa1.zip", "/tmp/aa1.zip")
    """
    # Cache hit first: no network access, hence no credentials needed.
    dest = Path(dest_path).expanduser().resolve()
    if dest.exists():
        return dest

    username = username or os.environ.get("CRCNS_USERNAME")
    password = password or os.environ.get("CRCNS_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "CRCNS credentials missing. Pass username/password explicitly, or set "
            "the CRCNS_USERNAME / CRCNS_PASSWORD env vars. Free account at "
            "https://crcns.org/register."
        )

    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_suffix(dest.suffix + ".part")

    remote = file_path.lstrip("/")
    url = f"https://portal.nersc.gov/project/crcns/download/{remote}"
    form = {
        "fn": remote,
        "username": username,
        "password": password,
        "submit": "Login",
    }
    # (connect, read): generous read timeout — NERSC is regularly slow to
    # start streaming a CRCNS archive (cold-storage fetch).
    with requests.post(url, data=form, stream=True, timeout=(30, 600),
                       allow_redirects=True) as resp:
        resp.raise_for_status()

        # One iterator for the whole body: the first chunk is peeked at for
        # the auth sniff, the remainder is consumed from the same generator.
        chunks = resp.iter_content(chunk_size=chunk_size)
        first = next(chunks, b"")

        # NERSC returns 200 + the login form HTML on auth failure (no 401).
        # The real file is binary; the form is small HTML.
        head = first[:4096]
        if b"<form" in head and b"password" in head:
            raise RuntimeError(
                f"CRCNS auth failed for {file_path!r} (server returned the login form). "
                f"Check $CRCNS_USERNAME / $CRCNS_PASSWORD."
            )

        total = int(resp.headers.get("Content-Length") or 0)
        bar = None
        if progress:
            try:
                from tqdm.auto import tqdm

                bar = tqdm(total=total or None, unit="B", unit_scale=True,
                           desc=f"download {dest.name}")
            except ImportError:
                bar = None

        try:
            with open(tmp, "wb") as f:
                if first:
                    f.write(first)
                    if bar is not None:
                        bar.update(len(first))
                for chunk in chunks:
                    if not chunk:
                        continue
                    f.write(chunk)
                    if bar is not None:
                        bar.update(len(chunk))
        except BaseException:
            # Interrupted / failed transfer: drop the truncated ``.part`` file.
            if tmp.exists():
                tmp.unlink()
            raise
        finally:
            if bar is not None:
                bar.close()

        # Atomic swap, only after the full body has been written.
        tmp.replace(dest)
    return dest
def zenodo_download(
    record_id: Union[int, str],
    filename: str,
    dest_path: Union[str, Path],
    **kwargs,
) -> Path:
    """Fetch ``filename`` from the public Zenodo record ``record_id``.

    The file is requested from
    ``https://zenodo.org/api/records/<record_id>/files/<filename>/content``
    and streamed to ``dest_path`` by ``stream_download``; extra keyword
    arguments are passed through to it. No authentication is sent, so the
    record has to be public.

    Example
    -------
    >>> zenodo_download(8044773, "A1_NAT4_ozgf.fs100.ch18.tgz", "/tmp/X.tgz")
    """
    content_url = f"https://zenodo.org/api/records/{record_id}/files/{filename}/content"
    return stream_download(content_url, dest_path, **kwargs)
def figshare_download(
    article_id: Union[int, str],
    dest_dir: Union[str, Path],
    *,
    filename: Optional[str] = None,
    **kwargs,
) -> Path:
    """Fetch a single file belonging to a public figshare article.

    The article's file listing is read from the public REST endpoint
    ``https://api.figshare.com/v2/articles/<id>``; the selected entry is then
    streamed into ``dest_dir`` under the name reported by the API. When
    ``filename`` is omitted the article has to hold exactly one file;
    otherwise the entry whose ``name`` equals ``filename`` is picked.

    Parameters
    ----------
    article_id : int | str
        Numeric figshare article id (the trailing component of the DOI
        ``10.6084/m9.figshare.<id>``, e.g. ``29203457``).
    dest_dir : path-like
        Directory receiving the file. Created if missing.
    filename : str, optional
        Exact (case-sensitive) name of the file to fetch. Needed as soon as
        the article lists more than one file.
    **kwargs
        Forwarded to ``stream_download``.

    Returns
    -------
    Path
        Location of the downloaded file inside ``dest_dir``.

    Raises
    ------
    RuntimeError
        The article lists no files.
    ValueError
        ``filename`` was not given and the article has several files.
    FileNotFoundError
        No file in the article is called ``filename``.

    Example
    -------
    >>> figshare_download(29203457, "/tmp/le2025")
    PosixPath('/tmp/le2025/zebf-auditory-restoration-1.zip')
    """
    meta = requests.get(
        f"https://api.figshare.com/v2/articles/{article_id}",
        timeout=(30, 60),
    )
    meta.raise_for_status()
    entries = meta.json().get("files", [])
    if not entries:
        raise RuntimeError(f"figshare article {article_id} lists no files")

    if filename is not None:
        # Explicit request: first entry whose name matches exactly.
        selected = None
        for entry in entries:
            if entry.get("name") == filename:
                selected = entry
                break
        if selected is None:
            available = ", ".join(entry["name"] for entry in entries)
            raise FileNotFoundError(
                f"file {filename!r} not in figshare article {article_id} (have: {available})"
            )
    elif len(entries) > 1:
        # Implicit request is only unambiguous for single-file articles.
        available = ", ".join(entry["name"] for entry in entries)
        raise ValueError(
            f"figshare article {article_id} has {len(entries)} files; "
            f"pass `filename=` to disambiguate (candidates: {available})"
        )
    else:
        selected = entries[0]

    target_dir = Path(dest_dir).expanduser().resolve()
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / selected["name"]
    return stream_download(selected["download_url"], target, **kwargs)
def github_raw_download(
    repo: str,
    path_in_repo: str,
    dest_path: Union[str, Path],
    *,
    ref: str = "HEAD",
    **kwargs,
) -> Path:
    """Fetch one file from a GitHub repository via its raw-content URL.

    Handy for paper-companion repositories that ship small datasets or model
    artefacts next to the code (e.g. DNet hosts ``test_data_5ms.mat`` for the
    Rahman et al. 2018 NS1 reanalysis at https://github.com/monzilur/DNet).

    Parameters
    ----------
    repo : str
        ``"<owner>/<name>"``, e.g. ``"monzilur/DNet"``.
    path_in_repo : str
        Location of the file inside the repository, e.g.
        ``"test_data_5ms.mat"``. Leading slashes are ignored.
    dest_path : path-like
    ref : str, default "HEAD"
        Branch / tag / commit. ``"HEAD"`` resolves the default branch.
    **kwargs
        Forwarded to ``stream_download``.

    Notes
    -----
    The request goes to the ``raw.githubusercontent.com`` CDN rather than to
    the GitHub REST API.
    """
    relative = path_in_repo.lstrip("/")
    raw_url = f"https://raw.githubusercontent.com/{repo}/{ref}/{relative}"
    return stream_download(raw_url, dest_path, **kwargs)
def unzip(zip_path: Union[str, Path], dest_dir: Union[str, Path], *, strip_root: bool = False) -> Path:
    """Unzip ``zip_path`` into ``dest_dir``. Idempotent (overwrites existing files).

    Parameters
    ----------
    zip_path, dest_dir : path-like
    strip_root : bool, default False
        If True and the archive contains a single top-level directory, strip
        it from the extracted layout (so ``foo/a/b`` -> ``a/b``). Mirrors the
        common ``--strip-components=1`` tar idiom. Archives whose single
        top-level entry is a plain *file* are extracted unchanged.

    Returns
    -------
    Path
        The destination directory.

    Raises
    ------
    ValueError
        If a member name would land outside ``dest_dir`` (absolute path or
        ``..`` components). Nothing is written for that member.
    """
    zip_path = Path(zip_path)
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    # Lexical (symlink-agnostic) anchor for the containment check below.
    base = os.path.abspath(str(dest_dir))

    with zipfile.ZipFile(zip_path) as zf:
        members = zf.namelist()

        prefix = ""
        if strip_root and members:
            roots = {m.split("/", 1)[0] for m in members if m}
            if len(roots) == 1:
                candidate = next(iter(roots)) + "/"
                # Only strip when the single root really is a directory, i.e.
                # every entry lives under ``<root>/``. A lone top-level file
                # ("data.csv") must not be mistaken for a root directory.
                if all(m.startswith(candidate) for m in members if m):
                    prefix = candidate

        for m in members:
            if not m.startswith(prefix):
                continue
            target = m[len(prefix):]
            if not target:
                # The root directory entry itself (or an empty name).
                continue

            # Zip-slip guard: archives come from the network, so refuse names
            # that resolve outside the destination directory.
            resolved = os.path.abspath(os.path.join(base, target))
            try:
                inside = os.path.commonpath([base, resolved]) == base
            except ValueError:
                # E.g. paths on different drives on Windows.
                inside = False
            if not inside:
                raise ValueError(
                    f"refusing to extract {m!r} from {zip_path}: path escapes {dest_dir}"
                )

            out = dest_dir / target
            if m.endswith("/"):
                out.mkdir(parents=True, exist_ok=True)
                continue
            out.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(m) as src, open(out, "wb") as dst:
                shutil.copyfileobj(src, dst)
    return dest_dir
def untar(tar_path: Union[str, Path], dest_dir: Union[str, Path], *, strip_components: int = 0) -> Path:
    """Extract a tar / tar.gz / tar.bz2 archive into ``dest_dir``.

    Mirrors GNU ``tar --strip-components=N``: drops the first ``N`` path
    components from every member. Useful when an archive wraps everything in
    nested directories that aren't part of the dataset's own layout — e.g.
    CRCNS-AA2 archives all wrap content in ``crcns/aa2/`` (strip 2).

    Only directories, regular files and links are handled. Links are not
    recreated as links: the content of the link's target inside the archive is
    written as a regular file. Links whose target is not present in the
    archive are skipped.

    Parameters
    ----------
    tar_path, dest_dir : path-like
    strip_components : int, default 0
        How many leading path components to drop. Members that have fewer
        components than this are silently skipped.

    Returns
    -------
    Path
        The destination directory.

    Raises
    ------
    ValueError
        If a member name would land outside ``dest_dir`` (absolute path or
        ``..`` components). Nothing is written for that member.
    """
    tar_path = Path(tar_path)
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    # Lexical (symlink-agnostic) anchor for the containment check below.
    base = os.path.abspath(str(dest_dir))

    with tarfile.open(tar_path, "r:*") as tf:
        for member in tf.getmembers():
            parts = member.name.split("/")
            # GNU tar's --strip-components silently skips members with too
            # few components (e.g. the root dir entry itself).
            if strip_components and len(parts) <= strip_components:
                continue
            stripped = "/".join(parts[strip_components:])
            if not stripped:
                continue

            # Tar-slip guard: archives come from the network, so refuse names
            # that resolve outside the destination directory.
            resolved = os.path.abspath(os.path.join(base, stripped))
            try:
                inside = os.path.commonpath([base, resolved]) == base
            except ValueError:
                # E.g. paths on different drives on Windows.
                inside = False
            if not inside:
                raise ValueError(
                    f"refusing to extract {member.name!r} from {tar_path}: "
                    f"path escapes {dest_dir}"
                )

            out = dest_dir / stripped
            if member.isdir():
                out.mkdir(parents=True, exist_ok=True)
                continue
            if not (member.isfile() or member.islnk() or member.issym()):
                # Devices, FIFOs, ... have no payload to copy.
                continue

            try:
                extracted = tf.extractfile(member)
            except KeyError:
                # (Sym)link whose target is not part of the archive: there is
                # nothing to copy, and one dangling link should not abort the
                # whole extraction.
                continue
            if extracted is None:
                # Link that points at a non-file member (e.g. a directory).
                continue

            out.parent.mkdir(parents=True, exist_ok=True)
            with extracted as src, open(out, "wb") as dst:
                shutil.copyfileobj(src, dst)
    return dest_dir