Source code for tiatoolbox.models.engine.deep_feature_extractor
"""Deep Feature Extraction Engine for Digital Pathology.
This module defines the `DeepFeatureExtractor` class, which extends
`SemanticSegmentor` to extract intermediate feature representations
from whole slide images (WSIs) or image patches. Unlike segmentation
or classification engines, this extractor focuses on generating feature
embeddings for downstream tasks such as clustering, visualization, or
training other machine learning models.
Key Components:
---------------
Functions:
- save_to_cache:
Utility to spill intermediate feature and coordinate arrays to
disk using Zarr for memory-efficient processing.
Classes:
- DeepFeatureExtractor:
Core engine for extracting deep features from WSIs or patches.
Supports memory-aware caching and outputs in Zarr format.
Features:
---------
- Handles large-scale WSIs with memory-aware caching.
- Outputs feature maps and patch coordinates for downstream analysis.
- Compatible with TIAToolbox pretrained models and custom PyTorch models.
- Supports both patch-based and WSI-based workflows.
Example:
--------
>>> from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
>>> extractor = DeepFeatureExtractor(model="resnet18")
>>> wsis = ["slide1.svs", "slide2.svs"]
>>> output = extractor.run(wsis, patch_mode=False, output_type="zarr")
>>> print(output)
... '/path/to/output.zarr'
"""
from __future__ import annotations
import gc
from typing import TYPE_CHECKING
import dask.array as da
import psutil
import zarr
from dask import compute
from tqdm.auto import tqdm
from typing_extensions import Unpack
from tiatoolbox.utils.misc import update_tqdm_desc
from .patch_predictor import PatchPredictor, PredictorRunParams
if TYPE_CHECKING: # pragma: no cover
import os
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from tiatoolbox.annotation import AnnotationStore
from tiatoolbox.models.engine.io_config import IOPatchPredictorConfig
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.type_hints import IntPair, Resolution, Units
from tiatoolbox.wsicore import WSIReader
[docs]
class DeepFeatureExtractor(PatchPredictor):
r"""Generic deep feature extractor for digital pathology images.
This class extends :class:`PatchPredictor` to extract deep features from
whole slide images (WSIs) or image patches using a deep learning model.
It is designed for use cases where the goal is to obtain intermediate
feature representations (e.g., embeddings) rather than final classification
or segmentation outputs.
The extracted features are returned or saved in Zarr format for downstream
analysis, such as clustering, visualization, or training other machine learning
models.
Args:
model (str | ModelABC):
A PyTorch model instance or the name of a pretrained model from TIAToolbox.
batch_size (int):
Number of image patches processed per forward pass. Default is 8.
num_workers (int):
Number of workers for data loading. Default is 0.
weights (str | Path | None):
Path to model weights. If None, default weights are used.
device (str):
Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
verbose (bool):
Whether to enable verbose logging. Default is True.
Attributes:
process_prediction_per_batch (bool):
Flag to control whether predictions are processed per batch.
Default is False.
"""
def __init__(
self: DeepFeatureExtractor,
model: str | ModelABC,
batch_size: int = 8,
num_workers: int = 0,
weights: str | Path | None = None,
*,
device: str = "cpu",
verbose: bool = True,
) -> None:
"""Initialize :class:`DeepFeatureExtractor`.
Args:
model (str | ModelABC):
A PyTorch model instance or the name of a pretrained model from
TIAToolbox. If a string is provided, the corresponding pretrained
weights will be downloaded unless overridden via `weights`.
batch_size (int):
Number of image patches processed per forward pass. Default is 8.
num_workers (int):
Number of workers for data loading. Default is 0.
weights (str | Path | None):
Path to model weights. If None, default weights are used.
device (str):
Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
verbose (bool):
Whether to enable verbose logging. Default is True.
"""
super().__init__(
model=model,
batch_size=batch_size,
num_workers=num_workers,
weights=weights,
device=device,
verbose=verbose,
)
self.process_prediction_per_batch = False
[docs]
def infer_wsi(
self: DeepFeatureExtractor,
dataloader: DataLoader,
save_path: Path,
**kwargs: Unpack[PredictorRunParams],
) -> dict[str, da.Array]:
"""Perform model inference on a whole slide image (WSI).
This method processes a WSI using the provided DataLoader and extracts
deep features from each patch using the model. It supports memory-aware
caching by spilling intermediate results to disk when memory usage exceeds
a specified threshold. The final output includes feature maps and their
corresponding spatial coordinates.
Args:
dataloader (DataLoader):
PyTorch DataLoader configured for WSI processing.
save_path (Path):
Path to save intermediate Zarr output. Used for caching.
**kwargs (PredictorRunParams):
Additional runtime parameters to configure prediction.
Optional Keys:
auto_get_mask (bool):
Automatically generate segmentation masks using
`wsireader.tissue_mask()` during processing.
batch_size (int):
Number of image patches per forward pass.
class_dict (dict):
Mapping of classification outputs to class names.
device (str):
Device to run the model on (e.g., "cpu", "cuda").
labels (list):
Optional labels for input images. Only a single label per image
is supported.
memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
num_workers (int):
Number of workers for DataLoader and post-processing.
output_file (str):
Filename for saving output (e.g., ".zarr" or ".db").
return_labels (bool):
Whether to return labels with predictions.
return_probabilities (bool):
Whether to return per-class probabilities in the output.
If False, only predicted labels are returned.
scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp).
Used to convert coordinates to baseline resolution.
stride_shape (tuple[int, int]):
Stride used during WSI processing.
Defaults to `patch_input_shape` if not provided.
verbose (bool):
Whether to enable verbose logging.
Returns:
dict[str, dask.array.Array]:
Dictionary containing:
- "probabilities": Extracted feature maps from the model.
- "coordinates": Patch coordinates corresponding to the features.
"""
# Default Memory threshold percentage is 80.
memory_threshold = kwargs.get("memory_threshold", 80)
vm = psutil.virtual_memory()
keys = ["probabilities", "coordinates"]
probabilities, coordinates = [], []
# Main output dictionary
raw_predictions = dict(
zip(keys, [da.empty(shape=(0, 0))] * len(keys), strict=False)
)
# Inference loop
tqdm_loop = tqdm(
dataloader,
leave=False,
desc="Inferring Patches",
disable=not self.verbose,
)
probabilities_zarr, coordinates_zarr = None, None
probabilities_used_percent = 0
infer_batch = self._get_model_attr("infer_batch")
for batch_data in tqdm_loop:
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
)
probabilities.append(da.from_array(batch_output[0]))
coordinates.append(
da.from_array(
self._get_coordinates(batch_data),
)
)
used_percent = vm.percent
probabilities_used_percent = (
probabilities_used_percent + (probabilities[-1].nbytes / vm.free) * 100
)
if (
used_percent > memory_threshold
or probabilities_used_percent > memory_threshold
):
used_percent = (
probabilities_used_percent
if (probabilities_used_percent > memory_threshold)
else used_percent
)
msg = (
f"Current Memory usage: {used_percent} % "
f"exceeds specified threshold: {memory_threshold}. "
f"Saving intermediate results to disk."
)
update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg)
# Flush data in Memory and clear dask graph
probabilities_zarr, coordinates_zarr = save_to_cache(
probabilities,
coordinates,
probabilities_zarr,
coordinates_zarr,
save_path=save_path,
)
probabilities, coordinates = [], []
probabilities_used_percent = 0
gc.collect()
update_tqdm_desc(tqdm_loop=tqdm_loop, desc="Inferring patches")
if probabilities_zarr is not None:
probabilities_zarr, coordinates_zarr = save_to_cache(
probabilities,
coordinates,
probabilities_zarr,
coordinates_zarr,
save_path=save_path,
)
# Wrap zarr in dask array
raw_predictions["probabilities"] = da.from_zarr(
probabilities_zarr, chunks=probabilities_zarr.chunks
)
raw_predictions["coordinates"] = da.from_zarr(
coordinates_zarr, chunks=coordinates_zarr.chunks
)
else:
raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0)
raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0)
return raw_predictions
[docs]
def post_process_patches(
self: DeepFeatureExtractor,
raw_predictions: dict[str, da.Array],
**kwargs: Unpack[PredictorRunParams],
) -> dict[str, da.Array]:
"""Post-process raw patch predictions from model inference.
This method overrides the base implementation to return raw feature maps
without applying any additional processing. It is intended for use cases
where intermediate features are required as output.
Args:
raw_predictions (dask.array.Array):
Raw model predictions as a Dask array.
**kwargs (PredictorRunParams):
Additional runtime parameters to configure prediction.
Optional Keys:
auto_get_mask (bool):
Automatically generate segmentation masks using
`wsireader.tissue_mask()` during processing.
batch_size (int):
Number of image patches per forward pass.
class_dict (dict):
Mapping of classification outputs to class names.
device (str):
Device to run the model on (e.g., "cpu", "cuda").
labels (list):
Optional labels for input images. Only a single label per image
is supported.
memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
num_workers (int):
Number of workers for DataLoader and post-processing.
output_file (str):
Filename for saving output (e.g., ".zarr" or ".db").
return_labels (bool):
Whether to return labels with predictions.
return_probabilities (bool):
Whether to return per-class probabilities in the output.
If False, only predicted labels are returned.
scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp).
Used to convert coordinates to baseline resolution.
stride_shape (tuple[int, int]):
Stride used during WSI processing.
Defaults to `patch_input_shape` if not provided.
verbose (bool):
Whether to enable verbose logging.
Returns:
dask.array.Array:
Unmodified raw predictions.
"""
_ = kwargs.get("return_probabilities")
return raw_predictions
[docs]
def save_predictions(
self: DeepFeatureExtractor,
processed_predictions: dict,
output_type: str,
save_path: Path | None = None,
**kwargs: Unpack[PredictorRunParams],
) -> dict | Path:
"""Save patch-level feature predictions to disk or return them in memory.
This method saves the extracted deep features in the specified output format.
Only the "zarr" format is supported for this engine. The method disables
saving the "predictions" key, as it is not relevant for feature extraction.
Args:
processed_predictions (dict):
Dictionary containing processed model outputs.
output_type (str):
Desired output format. Must be "zarr".
save_path (Path | None):
Path to save the output file. Required for "zarr" format.
**kwargs (PredictorRunParams):
Additional runtime parameters to configure prediction.
Optional Keys:
auto_get_mask (bool):
Automatically generate segmentation masks using
`wsireader.tissue_mask()` during processing.
batch_size (int):
Number of image patches per forward pass.
class_dict (dict):
Mapping of classification outputs to class names.
device (str):
Device to run the model on (e.g., "cpu", "cuda").
labels (list):
Optional labels for input images. Only a single label per image
is supported.
memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
num_workers (int):
Number of workers for DataLoader and post-processing.
output_file (str):
Filename for saving output (e.g., ".zarr" or ".db").
return_labels (bool):
Whether to return labels with predictions.
return_probabilities (bool):
Whether to return per-class probabilities in the output.
If False, only predicted labels are returned.
scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp).
Used to convert coordinates to baseline resolution.
stride_shape (tuple[int, int]):
Stride used during WSI processing.
Defaults to `patch_input_shape` if not provided.
verbose (bool):
Whether to enable verbose logging.
Returns:
dict | Path:
- If `output_type` is "zarr": returns the path to the saved Zarr file.
- If `output_type` is "dict": returns predictions as a dictionary.
Raises:
ValueError:
If an unsupported output format is provided.
"""
# no need to compute predictions
self.drop_keys.append("predictions")
processed_predictions["features"] = processed_predictions.pop("probabilities")
return super().save_predictions(
processed_predictions, output_type, save_path=save_path, **kwargs
)
def _update_run_params(
self: DeepFeatureExtractor,
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
masks: list[os.PathLike | Path] | np.ndarray | None = None,
input_resolutions: list[dict[Units, Resolution]] | None = None,
patch_input_shape: IntPair | None = None,
save_dir: os.PathLike | Path | None = None,
ioconfig: IOPatchPredictorConfig | None = None,
output_type: str = "dict",
*,
overwrite: bool = False,
patch_mode: bool,
**kwargs: Unpack[PredictorRunParams],
) -> Path | None:
"""Update runtime parameters for the DeepFeatureExtractor engine.
This method sets internal attributes such as caching, batch size,
IO configuration, and output format based on user input and keyword arguments.
It also validates that the output format is supported.
Args:
images (list[PathLike | WSIReader] | np.ndarray):
Input images or patches.
masks (list[PathLike] | np.ndarray | None):
Optional masks for WSI processing.
input_resolutions (list[dict[Units, Resolution]] | None):
Resolution settings for input heads. Supported units are `level`,
`power`, and `mpp`. Keys should be "units" and "resolution", e.g.,
[{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for
details.
patch_input_shape (IntPair | None):
Shape of input patches (height, width), requested at read
resolution. Must be positive.
save_dir (PathLike | None):
Directory to save output files. Required for WSI mode.
ioconfig (IOPatchPredictorConfig | None):
IO configuration for patch extraction and resolution.
output_type (str):
Desired output format. Must be "zarr" or "dict".
overwrite (bool):
Whether to overwrite existing output files. Default is False.
patch_mode (bool):
Whether to treat input as patches (`True`) or WSIs (`False`).
**kwargs (PredictorRunParams):
Additional runtime parameters to configure prediction.
Optional Keys:
auto_get_mask (bool):
Automatically generate segmentation masks using
`wsireader.tissue_mask()` during processing.
batch_size (int):
Number of image patches per forward pass.
class_dict (dict):
Mapping of classification outputs to class names.
device (str):
Device to run the model on (e.g., "cpu", "cuda").
labels (list):
Optional labels for input images. Only a single label per image
is supported.
memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
num_workers (int):
Number of workers for DataLoader and post-processing.
output_file (str):
Filename for saving output (e.g., ".zarr" or ".db").
return_labels (bool):
Whether to return labels with predictions.
return_probabilities (bool):
Whether to return per-class probabilities in the output.
If False, only predicted labels are returned.
scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp).
Used to convert coordinates to baseline resolution.
stride_shape (tuple[int, int]):
Stride used during WSI processing.
Defaults to `patch_input_shape` if not provided.
verbose (bool):
Whether to enable verbose logging.
Returns:
Path | None:
Path to the save directory if applicable, otherwise None.
Raises:
ValueError:
If `output_type` is not "zarr" or "dict", which are the
only supported formats.
"""
if output_type not in ["zarr", "dict"]:
msg = (
f"output_type: `{output_type}` is not supported for "
f"`DeepFeatureExtractor` engine."
)
raise ValueError(msg)
return super()._update_run_params(
images=images,
masks=masks,
input_resolutions=input_resolutions,
patch_input_shape=patch_input_shape,
save_dir=save_dir,
ioconfig=ioconfig,
overwrite=overwrite,
patch_mode=patch_mode,
output_type=output_type,
**kwargs,
)
[docs]
def run(
self: DeepFeatureExtractor,
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
*,
masks: list[os.PathLike | Path] | np.ndarray | None = None,
input_resolutions: list[dict[Units, Resolution]] | None = None,
patch_input_shape: IntPair | None = None,
ioconfig: IOPatchPredictorConfig | None = None,
patch_mode: bool = True,
save_dir: os.PathLike | Path | None = None,
overwrite: bool = False,
output_type: str = "dict",
**kwargs: Unpack[PredictorRunParams],
) -> AnnotationStore | Path | str | dict | list[Path]:
"""Run the DeepFeatureExtractor engine on input images.
This method orchestrates the full inference pipeline, including preprocessing,
model inference, and saving of extracted deep features. It supports both
patch-level and whole slide image (WSI) modes. The output is returned or saved
in Zarr format.
Note:
The `return_probabilities` flag is always set to True for this engine,
as it is designed to extract intermediate feature maps.
Args:
images (list[PathLike | WSIReader] | np.ndarray):
Input images or patches. Can be a list of file paths, WSIReader objects,
or a NumPy array of image patches.
masks (list[PathLike] | np.ndarray | None):
Optional masks for WSI processing. Only used when `patch_mode` is False.
input_resolutions (list[dict[Units, Resolution]] | None):
Resolution settings for input heads. Supported units are `level`,
`power`, and `mpp`. Keys should be "units" and "resolution", e.g.,
[{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for
details.
patch_input_shape (IntPair | None):
Shape of input patches (height, width), requested at read
resolution. Must be positive.
ioconfig (IOPatchPredictorConfig | None):
IO configuration for patch extraction and resolution.
patch_mode (bool):
Whether to treat input as patches (`True`) or WSIs (`False`).
Default is True.
save_dir (PathLike | None):
Directory to save output files. Required for WSI mode.
overwrite (bool):
Whether to overwrite existing output files. Default is False.
output_type (str):
Desired output format. Must be "zarr" or "dict".
**kwargs (PredictorRunParams):
Additional runtime parameters to configure prediction.
Optional Keys:
auto_get_mask (bool):
Automatically generate segmentation masks using
`wsireader.tissue_mask()` during processing.
batch_size (int):
Number of image patches per forward pass.
class_dict (dict):
Mapping of classification outputs to class names.
device (str):
Device to run the model on (e.g., "cpu", "cuda").
labels (list):
Optional labels for input images. Only a single label per image
is supported.
memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
num_workers (int):
Number of workers for DataLoader and post-processing.
output_file (str):
Filename for saving output (e.g., ".zarr" or ".db").
return_labels (bool):
Whether to return labels with predictions.
return_probabilities (bool):
Whether to return per-class probabilities in the output.
If False, only predicted labels are returned.
scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp).
Used to convert coordinates to baseline resolution.
stride_shape (tuple[int, int]):
Stride used during WSI processing.
Defaults to `patch_input_shape` if not provided.
verbose (bool):
Whether to enable verbose logging.
Returns:
AnnotationStore | Path | str | dict | list[Path]:
- If `patch_mode` is True: returns predictions or path to saved output.
- If `patch_mode` is False: returns a dictionary mapping each WSI
to its output path.
Raises:
ValueError:
If `output_type` is not "zarr" or "dict".
"""
# return_probabilities is always True for FeatureExtractor.
kwargs["return_probabilities"] = True
return super().run(
images=images,
masks=masks,
input_resolutions=input_resolutions,
patch_input_shape=patch_input_shape,
ioconfig=ioconfig,
patch_mode=patch_mode,
save_dir=save_dir,
overwrite=overwrite,
output_type=output_type,
**kwargs,
)
[docs]
def save_to_cache(
probabilities: list[da.Array],
coordinates: list[da.Array],
probabilities_zarr: zarr.Array | None,
coordinates_zarr: zarr.Array | None,
save_path: str | Path = "temp.zarr",
) -> tuple[zarr.Array, zarr.Array]:
"""Save computed feature and coordinate arrays to Zarr cache.
This function computes the given Dask arrays (`probabilities` and `coordinates`),
resizes the corresponding Zarr datasets to accommodate the new data, and appends
the results. If the Zarr datasets do not exist, it initializes them within the
specified Zarr group.
Args:
probabilities (list[dask.array.Array]):
List of Dask arrays representing extracted feature maps.
coordinates (list[dask.array.Array]):
List of Dask arrays representing patch coordinates.
probabilities_zarr (zarr.Array | None):
Existing Zarr dataset for feature maps. If None, a new one is created.
coordinates_zarr (zarr.Array | None):
Existing Zarr dataset for coordinates. If None, a new one is created.
save_path (str | Path):
Path to the Zarr group for saving datasets. Defaults to "temp.zarr".
Returns:
tuple[zarr.Array, zarr.Array]:
Updated Zarr datasets for feature maps and coordinates.
"""
if len(probabilities) == 0:
return probabilities_zarr, coordinates_zarr
coordinates = da.concatenate(coordinates, axis=0)
probabilities = da.concatenate(probabilities, axis=0)
computed_values = compute(*[probabilities, coordinates])
probabilities_computed, coordinates_computed = computed_values
chunk_shape = tuple(chunk[0] for chunk in probabilities.chunks)
if probabilities_zarr is None:
zarr_group = zarr.open(str(save_path), mode="w")
probabilities_zarr = zarr_group.create_dataset(
name="canvas",
shape=(0, *probabilities_computed.shape[1:]),
chunks=(chunk_shape[0], *probabilities_computed.shape[1:]),
dtype=probabilities_computed.dtype,
overwrite=True,
)
coordinates_zarr = zarr_group.create_dataset(
name="count",
shape=(0, *coordinates_computed.shape[1:]),
dtype=coordinates_computed.dtype,
chunks=(chunk_shape[0], *coordinates_computed.shape[1:]),
overwrite=True,
)
probabilities_zarr.resize(
(
probabilities_zarr.shape[0] + probabilities_computed.shape[0],
*probabilities_zarr.shape[1:],
)
)
probabilities_zarr[-probabilities_computed.shape[0] :] = probabilities_computed
coordinates_zarr.resize(
(
coordinates_zarr.shape[0] + coordinates_computed.shape[0],
*coordinates_zarr.shape[1:],
)
)
coordinates_zarr[-coordinates_computed.shape[0] :] = coordinates_computed
return probabilities_zarr, coordinates_zarr