# Source code for tiatoolbox.data

# skipcq: PTC-W6004  # noqa: ERA001
"""Package to define datasets available to download via TIAToolbox."""

from __future__ import annotations

import importlib.resources as importlib_resources
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

from huggingface_hub import hf_hub_download

from tiatoolbox import read_registry_files
from tiatoolbox.utils import imread

if TYPE_CHECKING:  # pragma: no cover
    import numpy as np

from tiatoolbox.utils import unzip_data

# Registry of downloadable sample files: the "files" entry of whatever
# `read_registry_files("data/remote_samples.yaml")` returns. Entries are
# looked up by name in `_fetch_remote_sample`, which reads their
# "hf_repo_id", "filename", "subfolder" and optional "extract" fields.
SAMPLE_FILES = read_registry_files("data/remote_samples.yaml")["files"]

# Names exported by `from tiatoolbox.data import *`.
# NOTE(review): `small_svs` is defined in this module without a leading
# underscore but is not listed here - confirm whether that is intentional.
__all__ = ["stain_norm_target"]


def _fetch_remote_sample(
    key: str,
    tmp_path: str | Path | None = None,
) -> Path:
    """Return the local path of a named remote sample, downloading it if needed.

    The sample is looked up by name in `SAMPLE_FILES` (populated from
    `tiatoolbox/data/remote_samples.yaml`) and retrieved from a Hugging
    Face dataset repository with `hf_hub_download`.

    Note: The download keeps the repository layout, so a file stored in
    the "sample_wsis" subfolder of the repository ends up at
    `tmp_path/sample_wsis/filename` rather than directly inside
    `tmp_path`.

    Args:
        key (str):
            Name of the sample, as defined in remote_samples.yaml.
        tmp_path (str | Path | None):
            Existing directory used as the local download location. When
            omitted (or otherwise falsy), the OS temporary directory
            reported by `tempfile.gettempdir` is used. In tests, pass a
            dedicated directory such as one created with
            `tmp_path_factory.mktemp()`.

    Returns:
        Path:
            Path to the downloaded file, including the repository
            subfolder (e.g., `tmp_path/subfolder/filename`). If the
            registry entry sets `extract`, the file is passed to
            `unzip_data` (the archive is kept) and the sibling directory
            named after the file's stem is returned instead.

    Raises:
        ValueError:
            If the resolved download location is not an existing
            directory.
        KeyError:
            If `key` is not present in the sample registry.

    Examples:
        >>> from pathlib import Path
        >>> from tiatoolbox.data import _fetch_remote_sample
        >>> # Fetch into the OS temporary directory
        >>> sample_path = _fetch_remote_sample("svs-1-small")
        >>> # -> tmp_dir/sample_wsis/CMU-1-Small-Region.svs
        >>> # Fetch into a directory of your choice
        >>> target_dir = Path("/path/to/data")
        >>> sample_path = _fetch_remote_sample("svs-1-small", target_dir)
        >>> # -> /path/to/data/sample_wsis/CMU-1-Small-Region.svs

    """
    # Resolve the download directory first so an invalid location is
    # reported before the registry is consulted or anything is downloaded.
    if tmp_path:
        cache_dir = Path(tmp_path)
    else:
        cache_dir = Path(tempfile.gettempdir())

    if not cache_dir.is_dir():
        msg = "tmp_path must be a directory."
        raise ValueError(msg)

    entry = SAMPLE_FILES[key]
    downloaded = Path(
        hf_hub_download(
            repo_id=entry["hf_repo_id"],
            filename=entry["filename"],
            subfolder=entry["subfolder"],
            local_dir=cache_dir,
            repo_type="dataset",
        ),
    )

    # Plain files are returned as downloaded.
    if not entry.get("extract", False):
        return downloaded

    # Archives are unpacked next to the download, into a directory named
    # after the archive without its final suffix; the archive is kept.
    extracted_dir = downloaded.parent / downloaded.stem
    unzip_data(downloaded, extracted_dir, del_zip=False)
    return extracted_dir


def _local_sample_path(path: str | Path) -> Path:
    """Get the path to a data file bundled with the package.

    The file is resolved relative to the `data` directory of the
    `tiatoolbox` package using `importlib.resources`.

    Args:
        path (str or Path):
            Relative path to the package data file.

    Returns:
        Path:
            File system path to the data file. When the package lives on
            the regular file system this is the path inside the package.
            Otherwise (e.g. a zip import) it is a temporary extracted
            copy that remains available until the interpreter exits.

    Example:
        >>> # Get the path to the sample target image used for
        >>> # stain normalization.
        >>> from tiatoolbox.data import _local_sample_path
        >>> target_path = _local_sample_path("target_image.png")

    """
    import atexit
    from contextlib import ExitStack

    # Resource names are "/"-separated on every platform, so join with
    # POSIX separators rather than the OS-specific `str(Path(...))`.
    relative_name = Path("data", path).as_posix()
    resource = importlib_resources.files("tiatoolbox") / relative_name

    # Package on the real file system: the resource already is a concrete
    # path, which is exactly what `as_file` would hand back.
    if isinstance(resource, Path):
        return resource

    # Otherwise `as_file` extracts to a temporary file that is deleted as
    # soon as its context exits. Returning from inside a `with` block would
    # hand the caller a path that no longer exists, so keep the context
    # open and defer the clean-up to interpreter shutdown.
    stack = ExitStack()
    atexit.register(stack.close)
    return stack.enter_context(importlib_resources.as_file(resource))

def stain_norm_target() -> np.ndarray:
    """Target image for stain normalization.

    Reads `target_image.png` from the data files bundled with the
    package.

    Returns:
        np.ndarray:
            The image as returned by `imread`.

    Example:
        >>> from tiatoolbox.data import stain_norm_target
        >>> img = stain_norm_target()

    """
    return imread(_local_sample_path("target_image.png"))


def small_svs() -> Path:
    """Small SVS file for testing.

    Fetches the "svs-1-small" remote sample using the default download
    location of `_fetch_remote_sample` (the OS temporary directory).

    Returns:
        Path:
            Local path to the sample file.

    """
    return _fetch_remote_sample("svs-1-small")