PatchPredictor

class PatchPredictor(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=True)[source]

Patch-level prediction engine for digital histology images.

This class extends EngineABC to support patch-based inference using pretrained or custom models from TIAToolbox. It supports both patch and whole slide image (WSI) modes, and provides utilities for post-processing and saving predictions.

Supported Models:
PatchPredictor performance on the Kather100K dataset [1].

Model name                      F1 score
alexnet-kather100k              0.965
resnet18-kather100k             0.990
resnet34-kather100k             0.991
resnet50-kather100k             0.989
resnet101-kather100k            0.989
resnext50_32x4d-kather100k      0.992
resnext101_32x8d-kather100k     0.991
wide_resnet50_2-kather100k      0.989
wide_resnet101_2-kather100k     0.990
densenet121-kather100k          0.993
densenet161-kather100k          0.992
densenet169-kather100k          0.992
densenet201-kather100k          0.991
mobilenet_v2-kather100k         0.990
mobilenet_v3_large-kather100k   0.991
mobilenet_v3_small-kather100k   0.992
googlenet-kather100k            0.992

PatchPredictor performance on the PCam dataset [2].

Model name                      F1 score
alexnet-pcam                    0.840
resnet18-pcam                   0.888
resnet34-pcam                   0.889
resnet50-pcam                   0.892
resnet101-pcam                  0.888
resnext50_32x4d-pcam            0.900
resnext101_32x8d-pcam           0.892
wide_resnet50_2-pcam            0.901
wide_resnet101_2-pcam           0.898
densenet121-pcam                0.897
densenet161-pcam                0.893
densenet169-pcam                0.895
densenet201-pcam                0.891
mobilenet_v2-pcam               0.899
mobilenet_v3_large-pcam         0.895
mobilenet_v3_small-pcam         0.890
googlenet-pcam                  0.867

Parameters:
  • model (str | ModelABC) – A PyTorch model instance or the name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights are downloaded unless overridden via weights. Valid names are given in the list of pretrained models in the toolbox model zoo.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) –

    Path to model weights. If None, default weights are used.

    >>> engine = PatchPredictor(
    ...    model="pretrained-model",
    ...    weights="/path/to/pretrained-local-weights.pth"
    ... )
    

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.

  • verbose (bool) – Whether to enable verbose logging. Default is True.
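
To illustrate what batch_size controls, the sketch below groups a sequence of patches into batches, one batch per forward pass. It is not TIAToolbox code: iter_batches is a hypothetical helper, and the engine's own data loading may differ in detail.

```python
from collections.abc import Iterator, Sequence


def iter_batches(patches: Sequence, batch_size: int = 8) -> Iterator[Sequence]:
    """Yield consecutive slices of at most ``batch_size`` items (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(patches), batch_size):
        yield patches[start : start + batch_size]


# 20 dummy patches with the default batch_size of 8 give batches of 8, 8 and 4.
batch_sizes = [len(batch) for batch in iter_batches(list(range(20)), batch_size=8)]
print(batch_sizes)
```

A larger batch_size means fewer forward passes but more memory per pass.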

images

Input image patches or WSI paths.

Type:

list[str | Path] | np.ndarray

masks

Optional tissue masks for WSI processing. These are only utilized when patch_mode is False. If not provided, then a tissue mask will be automatically generated for whole slide images.

Type:

list[str | Path] | np.ndarray
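
The effect of a mask on patch extraction can be sketched as follows. This is an illustration only: keep_masked_patches is a hypothetical helper, and the rule used here (keep a patch if it overlaps at least one foreground pixel) is an assumption, not necessarily the criterion TIAToolbox applies.

```python
def keep_masked_patches(mask, top_lefts, patch_shape):
    """Keep patches whose area overlaps at least one non-zero mask pixel.

    ``mask`` is a 2D list of 0/1 values, ``top_lefts`` holds (y, x) patch
    origins and ``patch_shape`` is (height, width).
    """
    patch_h, patch_w = patch_shape
    kept = []
    for y, x in top_lefts:
        rows = mask[y : y + patch_h]
        if any(any(row[x : x + patch_w]) for row in rows):
            kept.append((y, x))
    return kept


# Toy 4x4 mask with tissue only in the bottom-right quadrant.
mask = [
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
top_lefts = [(0, 0), (0, 2), (2, 0), (2, 2)]
kept = keep_masked_patches(mask, top_lefts, patch_shape=(2, 2))
print(kept)
```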

patch_mode

Whether input is treated as patches (True) or WSIs (False).

Type:

bool

model

Loaded PyTorch model.

Type:

ModelABC

ioconfig

IO configuration for patch extraction and resolution.

Type:

IOPatchPredictorConfig

return_labels

Whether to include labels in the output.

Type:

bool

input_resolutions

Resolution settings for model input. Supported units are level, power and mpp. Keys should be “units” and “resolution” e.g., [{“units”: “mpp”, “resolution”: 0.25}]. Please see WSIReader for details.

Type:

list[dict]
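
A minimal sketch of the expected structure of input_resolutions, based only on the keys and units listed above. is_valid_resolution and SUPPORTED_UNITS are hypothetical names introduced for illustration; the toolbox performs its own validation.

```python
# Units named in the description above.
SUPPORTED_UNITS = {"level", "power", "mpp"}


def is_valid_resolution(entry: dict) -> bool:
    """Return True if ``entry`` has exactly the documented keys and a supported unit."""
    return set(entry) == {"units", "resolution"} and entry["units"] in SUPPORTED_UNITS


input_resolutions = [{"units": "mpp", "resolution": 0.25}]
all_valid = all(is_valid_resolution(entry) for entry in input_resolutions)
print(all_valid)
```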

patch_input_shape

Shape of input patches (height, width). Patches are at requested read resolution, not with respect to level 0, and must be positive.

Type:

tuple[int, int]

stride_shape

Stride used during patch extraction. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape=patch_input_shape.

Type:

tuple[int, int]
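
The interaction between patch_input_shape and stride_shape can be sketched as a simple grid computation. This is an assumption-laden illustration, not the toolbox's extractor: patch_top_lefts is a hypothetical helper that keeps only patches lying fully inside the image, whereas the real extractor may handle image borders differently.

```python
def patch_top_lefts(image_shape, patch_input_shape, stride_shape=None):
    """Return (y, x) origins of patches that fit fully inside the image.

    All shapes are (height, width) at the requested read resolution.
    """
    if stride_shape is None:
        # Documented default: the stride equals the patch shape (no overlap).
        stride_shape = patch_input_shape
    height, width = image_shape
    patch_h, patch_w = patch_input_shape
    stride_h, stride_w = stride_shape
    if min(patch_h, patch_w, stride_h, stride_w) <= 0:
        raise ValueError("patch and stride shapes must be positive")
    return [
        (y, x)
        for y in range(0, height - patch_h + 1, stride_h)
        for x in range(0, width - patch_w + 1, stride_w)
    ]


non_overlapping = patch_top_lefts((4, 6), (2, 2))
overlapping = patch_top_lefts((4, 6), (2, 2), stride_shape=(1, 1))
print(len(non_overlapping), len(overlapping))
```

A stride smaller than the patch shape yields overlapping patches and therefore more forward passes for the same image.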

labels

Optional labels for input images. Only a single label per image is supported.

Type:

list | None

drop_keys

Keys to exclude from model output.

Type:

list
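
As a rough illustration of drop_keys, the sketch below removes selected entries from an output dictionary. drop_output_keys and the key names used in the example are hypothetical; the actual keys depend on the model output.

```python
def drop_output_keys(output: dict, drop_keys: list) -> dict:
    """Return a copy of ``output`` without the entries named in ``drop_keys``."""
    excluded = set(drop_keys)
    return {key: value for key, value in output.items() if key not in excluded}


# Hypothetical output keys, for illustration only.
output = {"predictions": [1, 0], "probabilities": [[0.3, 0.7], [0.9, 0.1]]}
trimmed = drop_output_keys(output, drop_keys=["probabilities"])
print(trimmed)
```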

output_type

Format of output (“dict”, “zarr”, “qupath”, “annotationstore”).

Type:

str

Example

>>> import numpy as np
>>> # array of 2 image patches as input; img1 and img2 are
>>> # placeholder RGB arrays of identical shape
>>> data = np.array([img1, img2])
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(data, patch_mode=True)
>>> # list of 2 image patch files as input
>>> data = ['path/img1.png', 'path/img2.png']
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(data, patch_mode=True)
>>> # list of 2 image tile files as input (WSI mode requires save_dir)
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(tile_file, patch_mode=False, save_dir='path/tile_output')
>>> # list of 2 WSI files as input (WSI mode requires save_dir)
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(wsi_file, patch_mode=False, save_dir='path/wsi_output')

References

[1] Kather, Jakob Nikolas, et al. “Predicting survival from colorectal cancer histology slides using deep learning: A retrospective multicenter study.” PLoS medicine 16.1 (2019): e1002730.

[2] Veeling, Bastiaan S., et al. “Rotation equivariant CNNs for digital pathology.” International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2018.

Initialize the PatchPredictor engine.

Parameters:
  • model (str | ModelABC) – A PyTorch model instance or name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights will be downloaded unless overridden via weights.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) – Path to model weights. If None, default weights are used.

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.

  • verbose (bool) – Whether to enable verbose logging. Default is True.

Methods

post_process_patches

Post-process raw patch predictions from model inference.

post_process_wsi

Post-process predictions from whole slide image (WSI) inference.

run

Run the PatchPredictor engine on input images.

post_process_patches(raw_predictions, **kwargs)[source]

Post-process raw patch predictions from model inference.

This method applies the model’s post-processing function to the raw predictions obtained from infer_patches(). The output is wrapped in a Dask array for efficient computation and memory handling.

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • **kwargs (PredictorRunParams) –

    Additional runtime parameters to configure prediction.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    return_labels (bool):

    Whether to return labels with predictions.

    return_probabilities (bool):

    Whether to return per-class probabilities in the output. If False, only predicted labels are returned.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.


Returns:

Post-processed predictions as a dictionary of Dask arrays.

Return type:

dict[str, da.Array]
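
The roles of class_dict and return_probabilities can be sketched with plain Python lists. This is a simplified illustration that assumes the model emits per-class probabilities and that the predicted label is the index of the largest one; summarise_predictions, its output keys, and the class names are hypothetical and do not reproduce the engine's Dask-based implementation.

```python
def summarise_predictions(probabilities, class_dict=None, return_probabilities=False):
    """Turn per-class probabilities into predicted labels (illustrative only)."""
    # Index of the largest probability in each row (arg-max).
    predictions = [max(range(len(row)), key=row.__getitem__) for row in probabilities]
    output = {"predictions": predictions}
    if class_dict is not None:
        # Map numeric outputs to human-readable class names.
        output["class_names"] = [class_dict[index] for index in predictions]
    if return_probabilities:
        output["probabilities"] = probabilities
    return output


probabilities = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
class_dict = {0: "background", 1: "tumour", 2: "stroma"}  # hypothetical names
summary = summarise_predictions(probabilities, class_dict=class_dict)
print(summary)
```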

post_process_wsi(raw_predictions, save_path, **kwargs)[source]

Post-process predictions from whole slide image (WSI) inference.

This method refines the raw patch-level predictions obtained from WSI inference. It typically applies spatial smoothing or other contextual operations using neighboring patch information. Internally, it delegates to post_process_patches().

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • save_path (Path) – Path to save the intermediate output. The intermediate output is saved in a zarr file.

  • **kwargs (PredictorRunParams) –

    Additional runtime parameters to configure prediction.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    return_labels (bool):

    Whether to return labels with predictions.

    return_probabilities (bool):

    Whether to return per-class probabilities in the output. If False, only predicted labels are returned.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.


Returns:

Post-processed predictions as a Dask array.

Return type:

dask.array.Array
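
The scale_factor described above (model_mpp / slide_mpp) can be illustrated with a short calculation. to_baseline is a hypothetical helper, and the (x, y) ordering of coordinates and mpp values is an assumption made for this sketch.

```python
def to_baseline(coords, model_mpp, slide_mpp):
    """Convert (x, y) coordinates at model resolution to baseline resolution."""
    scale_x = model_mpp[0] / slide_mpp[0]
    scale_y = model_mpp[1] / slide_mpp[1]
    return [(x * scale_x, y * scale_y) for x, y in coords]


# A model reading at 0.5 mpp on a slide scanned at 0.25 mpp gives a scale
# factor of 2: each model pixel spans two baseline pixels per axis.
baseline_coords = to_baseline([(100, 50)], model_mpp=(0.5, 0.5), slide_mpp=(0.25, 0.25))
print(baseline_coords)
```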

run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)[source]

Run the PatchPredictor engine on input images.

This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch and whole slide image (WSI) modes.

Parameters:
  • images (list[PathLike | WSIReader] | np.ndarray) – Input images or patches. In patch mode, the input must be a list of images, a list of image file paths, or a NumPy array of stacked images.

  • masks (list[PathLike] | np.ndarray | None) – Optional masks for WSI processing. Only used when patch_mode is False. Patches are generated only within the masked area. If not provided, a tissue mask is generated automatically for whole slide images.

  • input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are level, power, and mpp. Keys should be “units” and “resolution”, e.g., [{“units”: “mpp”, “resolution”: 0.25}]. See WSIReader for details.

  • patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.

  • ioconfig (IOPatchPredictorConfig | None) – IO configuration for patch extraction and resolution.

  • patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False).

  • save_dir (PathLike | None) – Directory to save output files. Required for WSI mode.

  • overwrite (bool) – Whether to overwrite existing output files. Default is False.

  • output_type (str) – Desired output format: “dict”, “zarr”, “qupath” or “annotationstore”. Default is “dict”.

  • **kwargs (PredictorRunParams) –

    Additional runtime parameters to configure prediction.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    return_labels (bool):

    Whether to return labels with predictions.

    return_probabilities (bool):

    Whether to return per-class probabilities in the output. If False, only predicted labels are returned.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.


Returns:

  • If patch_mode is True: returns predictions or path to saved output.

  • If patch_mode is False: returns a dictionary mapping each WSI to its output path.

Return type:

AnnotationStore | Path | str | dict
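
The shape of the WSI-mode return value (a dictionary mapping each input to its output path) can be sketched as below. This is a guess at the layout for illustration only: map_outputs and the naming scheme are hypothetical, and the suffixes are taken from the output_file description (“.zarr” or “.db”).

```python
from pathlib import PurePosixPath

# Suffixes assumed from the output_file description; other output types are omitted.
SUFFIXES = {"zarr": ".zarr", "annotationstore": ".db"}


def map_outputs(wsi_paths, save_dir, output_type):
    """Map each input path to a hypothetical output path inside ``save_dir``."""
    suffix = SUFFIXES[output_type]
    return {
        wsi: str(PurePosixPath(save_dir) / (PurePosixPath(wsi).stem + suffix))
        for wsi in wsi_paths
    }


outputs = map_outputs(["wsi1.svs", "wsi2.svs"], "/path/to", "annotationstore")
print(outputs["wsi1.svs"])
```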

Examples

>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> # img1 and img2 are placeholder RGB patches (np.ndarray).
>>> image_patches = [img1, img2]
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> # Output paths shown below are illustrative.
>>> output = predictor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="annotationstore",
...     save_dir="/path/to",
... )
>>> output
'/path/to/Output.db'
>>> output = predictor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="zarr",
...     save_dir="/path/to",
... )
>>> output
'/path/to/Output.zarr'
>>> # WSI mode requires save_dir.
>>> output = predictor.run(
...     wsis,
...     patch_mode=False,
...     output_type="annotationstore",
...     save_dir="/path/to",
... )
>>> list(output.keys())
['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
'/path/to/wsi1.db'