PatchPredictor
- class PatchPredictor(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=True)
Patch-level prediction engine for digital histology images.
This class extends EngineABC to support patch-based inference using either pretrained TIAToolbox models or custom models. It supports both patch and whole slide image (WSI) modes, and provides utilities for post-processing and saving predictions.
- Supported Models:
PatchPredictor performance on the Kather100K dataset [1].
Model name                      F1 score
alexnet-kather100k              0.965
resnet18-kather100k             0.990
resnet34-kather100k             0.991
resnet50-kather100k             0.989
resnet101-kather100k            0.989
resnext50_32x4d-kather100k      0.992
resnext101_32x8d-kather100k     0.991
wide_resnet50_2-kather100k      0.989
wide_resnet101_2-kather100k     0.990
densenet121-kather100k          0.993
densenet161-kather100k          0.992
densenet169-kather100k          0.992
densenet201-kather100k          0.991
mobilenet_v2-kather100k         0.990
mobilenet_v3_large-kather100k   0.991
mobilenet_v3_small-kather100k   0.992
googlenet-kather100k            0.992
PatchPredictor performance on the PCam dataset [2].
Model name                      F1 score
alexnet-pcam                    0.840
resnet18-pcam                   0.888
resnet34-pcam                   0.889
resnet50-pcam                   0.892
resnet101-pcam                  0.888
resnext50_32x4d-pcam            0.900
resnext101_32x8d-pcam           0.892
wide_resnet50_2-pcam            0.901
wide_resnet101_2-pcam           0.898
densenet121-pcam                0.897
densenet161-pcam                0.893
densenet169-pcam                0.895
densenet201-pcam                0.891
mobilenet_v2-pcam               0.899
mobilenet_v3_large-pcam         0.895
mobilenet_v3_small-pcam         0.890
googlenet-pcam                  0.867
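The pretrained names in both tables follow an "<architecture>-<dataset>" pattern. The sketch below, which uses only the standard library, builds such a name and picks the highest-scoring entry from a few rows copied from the Kather100K table. The helpers pretrained_name and best_model are hypothetical and are not part of TIAToolbox.

```python
from __future__ import annotations

# F1 scores copied from a subset of rows in the Kather100K table above.
KATHER100K_F1 = {
    "alexnet-kather100k": 0.965,
    "resnet18-kather100k": 0.990,
    "resnext50_32x4d-kather100k": 0.992,
    "densenet121-kather100k": 0.993,
    "mobilenet_v3_small-kather100k": 0.992,
}


def pretrained_name(architecture: str, dataset: str) -> str:
    """Build a model-zoo style name using the '<architecture>-<dataset>' pattern."""
    return f"{architecture}-{dataset}"


def best_model(scores: dict[str, float]) -> str:
    """Return the model name with the highest reported F1 score."""
    return max(scores, key=scores.get)


print(pretrained_name("resnet18", "kather100k"))  # resnet18-kather100k
print(best_model(KATHER100K_F1))  # densenet121-kather100k
```

The resulting string is what would be passed as the model argument, e.g. PatchPredictor(model="resnet18-kather100k").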
- Parameters:
model (str | ModelABC) – A PyTorch model instance or the name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights are downloaded by default unless overridden via weights. The list of pretrained models available in the toolbox model zoo can be found at this link.
batch_size (int) – Number of image patches processed per forward pass. Default is 8.
num_workers (int) – Number of workers for data loading. Default is 0.
weights (str | Path | None) –
Path to model weights. If None, default weights are used.
>>> from tiatoolbox.models.engine.patch_predictor import PatchPredictor
>>> engine = PatchPredictor(
...     model="pretrained-model",
...     weights="/path/to/pretrained-local-weights.pth",
... )
device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.
verbose (bool) – Whether to enable verbose logging. Default is True.
- masks
Optional tissue masks for WSI processing. These are only utilized when patch_mode is False. If not provided, then a tissue mask will be automatically generated for whole slide images.
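The following is a minimal sketch of how a tissue mask can restrict patch extraction. It assumes a binary mask at the same resolution as the patches and a hypothetical rule that a patch is kept when at least half of its pixels are tissue. The helper patches_in_mask and the 0.5 threshold are illustrative assumptions, not TIAToolbox's actual criterion.

```python
from __future__ import annotations


def patches_in_mask(coords, mask, patch_shape, min_fraction=0.5):
    """Keep patch top-left corners whose area is mostly tissue.

    coords: list of (x, y) top-left corners.
    mask: 2D list of 0/1 values indexed as mask[row][col].
    patch_shape: (height, width), matching patch_input_shape.
    """
    height, width = patch_shape
    kept = []
    for x, y in coords:
        # Slice the mask window covered by this patch.
        window = [row[x:x + width] for row in mask[y:y + height]]
        tissue = sum(sum(row) for row in window)
        if tissue / (height * width) >= min_fraction:
            kept.append((x, y))
    return kept


# A 4x4 mask where only the left half is tissue.
mask = [[1, 1, 0, 0] for _ in range(4)]
coords = [(0, 0), (2, 0), (0, 2), (2, 2)]
print(patches_in_mask(coords, mask, patch_shape=(2, 2)))  # [(0, 0), (0, 2)]
```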
- ioconfig
IO configuration for patch extraction and resolution.
- Type:
IOPatchPredictorConfig | None
- input_resolutions
Resolution settings for model input. Supported units are level, power, and mpp. Keys should be “units” and “resolution”, e.g., [{“units”: “mpp”, “resolution”: 0.25}]. See WSIReader for details.
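A hypothetical validator for the input_resolutions format described above is sketched below. It checks only the key names and the three supported units; validate_resolutions is an illustrative helper, and TIAToolbox's own checks may differ.

```python
from __future__ import annotations

SUPPORTED_UNITS = {"level", "power", "mpp"}


def validate_resolutions(resolutions: list[dict]) -> None:
    """Raise ValueError unless every entry looks like {"units": ..., "resolution": ...}."""
    for entry in resolutions:
        if set(entry) != {"units", "resolution"}:
            raise ValueError(
                f"Expected keys 'units' and 'resolution', got {sorted(entry)}"
            )
        if entry["units"] not in SUPPORTED_UNITS:
            raise ValueError(f"Unsupported units: {entry['units']!r}")


# Passes silently: one input head read at 0.25 microns per pixel.
validate_resolutions([{"units": "mpp", "resolution": 0.25}])
```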
- patch_input_shape
Shape of input patches (height, width). The shape is specified at the requested read resolution, not with respect to level 0, and both dimensions must be positive.
- stride_shape
Stride used during patch extraction. The stride is specified at the requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape defaults to patch_input_shape.
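Assuming that only patches lying fully inside the image are kept (no edge padding), the patch grid implied by patch_input_shape and stride_shape can be sketched as follows. patch_grid is a hypothetical helper, not the TIAToolbox implementation, which may handle image borders differently.

```python
from __future__ import annotations


def patch_grid(image_shape, patch_input_shape, stride_shape=None):
    """Return (x, y) top-left corners of patches that fit fully in the image.

    All shapes are (height, width) at the requested read resolution.
    stride_shape defaults to patch_input_shape (non-overlapping patches).
    """
    if stride_shape is None:
        stride_shape = patch_input_shape
    if min(*patch_input_shape, *stride_shape) <= 0:
        raise ValueError("patch_input_shape and stride_shape must be positive")
    img_h, img_w = image_shape
    patch_h, patch_w = patch_input_shape
    stride_h, stride_w = stride_shape
    return [
        (x, y)
        for y in range(0, img_h - patch_h + 1, stride_h)
        for x in range(0, img_w - patch_w + 1, stride_w)
    ]


print(patch_grid((448, 448), (224, 224)))
# [(0, 0), (224, 0), (0, 224), (224, 224)]
print(len(patch_grid((448, 448), (224, 224), stride_shape=(112, 112))))  # 9
```

Halving the stride makes neighboring patches overlap by half a patch, which here raises the count from 4 to 9.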
- labels
Optional labels for input images. Only a single label per image is supported.
- Type:
list | None
Examples
>>> import numpy as np
>>> from tiatoolbox.models.engine.patch_predictor import PatchPredictor
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> # list of 2 WSI file paths as input
>>> data = ['path/img.svs', 'path/img.svs']
>>> output = predictor.run(data, patch_mode=False, save_dir='path/to/output1')
>>> # NumPy array of 2 image patches as input (img1 and img2 are image arrays)
>>> data = np.array([img1, img2])
>>> output = predictor.run(data, patch_mode=True)
>>> # list of 2 image patch files as input
>>> data = ['path/img.png', 'path/img.png']
>>> output = predictor.run(data, patch_mode=True)
>>> # list of 2 image tile files as input
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> output = predictor.run(tile_file, patch_mode=False, save_dir='path/to/output2')
>>> # list of 2 WSI files as input
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
>>> output = predictor.run(wsi_file, patch_mode=False, save_dir='path/to/output3')
References
[1] Kather, Jakob Nikolas, et al. “Predicting survival from colorectal cancer histology slides using deep learning: A retrospective multicenter study.” PLoS Medicine 16.1 (2019): e1002730.
[2] Veeling, Bastiaan S., et al. “Rotation equivariant CNNs for digital pathology.” International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.
Initialize the PatchPredictor engine.
- Parameters:
model (str | ModelABC) – A PyTorch model instance or name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights will be downloaded unless overridden via weights.
batch_size (int) – Number of image patches processed per forward pass. Default is 8.
num_workers (int) – Number of workers for data loading. Default is 0.
weights (str | Path | None) – Path to model weights. If None, default weights are used.
device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.
verbose (bool) – Whether to enable verbose logging. Default is True.
Methods
post_process_patches – Post-process raw patch predictions from model inference.
post_process_wsi – Post-process predictions from whole slide image (WSI) inference.
run – Run the PatchPredictor engine on input images.
- post_process_patches(raw_predictions, **kwargs)
Post-process raw patch predictions from model inference.
This method applies the model’s post-processing function to the raw predictions obtained from infer_patches(). The output is wrapped in a Dask array for efficient computation and memory handling.
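As a rough illustration of this step, the sketch below turns per-class probabilities into predicted class indices with an argmax and keeps the probabilities only when return_probabilities is set. It assumes a classification model whose post-processing is an argmax, uses plain lists instead of Dask arrays to stay dependency-free, and the dictionary keys "probabilities" and "predictions" are hypothetical.

```python
from __future__ import annotations


def post_process(raw_predictions: dict, return_probabilities: bool = False) -> dict:
    """Convert per-class probabilities into predicted class indices."""
    probabilities = raw_predictions["probabilities"]
    output = {
        # argmax over the class axis for each patch
        "predictions": [max(range(len(p)), key=p.__getitem__) for p in probabilities],
    }
    if return_probabilities:
        output["probabilities"] = probabilities
    return output


raw = {"probabilities": [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]}
print(post_process(raw))  # {'predictions': [1, 0]}
```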
- Parameters:
raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.
**kwargs (PredictorRunParams) –
Additional runtime parameters to configure prediction.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities in the output. If False, only predicted labels are returned.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
self (PatchPredictor)
- Returns:
Post-processed predictions as a Dask array.
- Return type:
dask.array.Array
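The scale_factor keyword above is defined as model_mpp / slide_mpp. As a worked example, assume a model that reads at 0.5 mpp from a slide whose baseline is 0.25 mpp: the factor is 0.5 / 0.25 = 2.0, so a coordinate of (100, 200) at model resolution maps to (200, 400) at baseline. The helper to_baseline below is hypothetical.

```python
from __future__ import annotations


def to_baseline(coords, scale_factor):
    """Map (x, y) coordinates at model resolution to baseline (level 0).

    scale_factor is (model_mpp / slide_mpp) for each axis.
    """
    sx, sy = scale_factor
    return [(x * sx, y * sy) for x, y in coords]


model_mpp, slide_mpp = 0.5, 0.25  # model reads at half the slide's baseline resolution
scale = (model_mpp / slide_mpp, model_mpp / slide_mpp)  # (2.0, 2.0)
print(to_baseline([(100, 200), (224, 0)], scale))
# [(200.0, 400.0), (448.0, 0.0)]
```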
- post_process_wsi(raw_predictions, save_path, **kwargs)
Post-process predictions from whole slide image (WSI) inference.
This method refines the raw patch-level predictions obtained from WSI inference. It is the hook for spatial smoothing or other contextual operations that use neighboring patch information; in PatchPredictor, it delegates to post_process_patches().
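Purely as an illustration of the kind of contextual operation mentioned above (PatchPredictor itself delegates to post_process_patches() and does not do this), a 3x3 majority-vote filter over a grid of patch labels can be sketched as follows. majority_smooth and the row-major grid layout are assumptions.

```python
from __future__ import annotations

from collections import Counter


def majority_smooth(label_grid):
    """Replace each patch label with the most common label in its 3x3 neighborhood."""
    rows, cols = len(label_grid), len(label_grid[0])
    smoothed = []
    for r in range(rows):
        new_row = []
        for c in range(cols):
            # Neighborhood is clipped at the grid borders.
            neighbors = [
                label_grid[rr][cc]
                for rr in range(max(0, r - 1), min(rows, r + 2))
                for cc in range(max(0, c - 1), min(cols, c + 2))
            ]
            new_row.append(Counter(neighbors).most_common(1)[0][0])
        smoothed.append(new_row)
    return smoothed


grid = [
    [0, 0, 0],
    [0, 1, 0],  # isolated label 1 surrounded by label 0
    [0, 0, 0],
]
print(majority_smooth(grid))  # [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
```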
- Parameters:
raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.
save_path (Path) – Path at which to save the intermediate output, which is stored as a Zarr file.
**kwargs (PredictorRunParams) –
Additional runtime parameters to configure prediction.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities in the output. If False, only predicted labels are returned.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
self (PatchPredictor)
- Returns:
Post-processed predictions as a Dask array.
- Return type:
dask.array.Array
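The memory_threshold keyword is described only as a percentage that triggers caching. One plausible reading, sketched here with hypothetical names (should_cache, BufferedResults) and an in-memory list standing in for an on-disk store, is to spill buffered results once usage reaches the threshold. This is not TIAToolbox's caching logic.

```python
from __future__ import annotations


def should_cache(used_bytes: int, total_bytes: int, memory_threshold: int) -> bool:
    """Return True when memory usage, in percent, reaches memory_threshold."""
    usage_percent = 100 * used_bytes / total_bytes
    return usage_percent >= memory_threshold


class BufferedResults:
    """Keep results in memory until the threshold is hit, then spill them."""

    def __init__(self, total_bytes: int, memory_threshold: int = 80):
        self.total_bytes = total_bytes
        self.memory_threshold = memory_threshold
        self.in_memory: list[bytes] = []
        self.cached: list[bytes] = []  # stands in for an on-disk store such as Zarr

    def add(self, chunk: bytes) -> None:
        self.in_memory.append(chunk)
        used = sum(len(c) for c in self.in_memory)
        if should_cache(used, self.total_bytes, self.memory_threshold):
            self.cached.extend(self.in_memory)
            self.in_memory.clear()


buf = BufferedResults(total_bytes=100, memory_threshold=50)
buf.add(b"x" * 30)  # 30% used: stays in memory
buf.add(b"x" * 30)  # 60% used: both chunks are spilled to the cache
```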
- run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)
Run the PatchPredictor engine on input images.
This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch and whole slide image (WSI) modes.
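The orchestration described above can be sketched as a simple loop: preprocess every patch, run the model once per batch of batch_size items, then post-process the collected outputs. The functions below (batched, run_pipeline) and the toy model are hypothetical stand-ins, not TIAToolbox internals, and the saving step is omitted.

```python
from __future__ import annotations

from typing import Callable, Iterator, Sequence


def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def run_pipeline(
    patches: Sequence,
    preprocess: Callable,
    model: Callable,
    postprocess: Callable,
    batch_size: int = 8,
):
    """Preprocess all patches, run the model once per batch, then post-process."""
    prepared = [preprocess(p) for p in patches]
    raw = []
    for batch in batched(prepared, batch_size):
        raw.extend(model(batch))  # one forward pass per batch
    return postprocess(raw)


calls = []  # records the size of each batch the toy model receives


def toy_model(batch):
    calls.append(len(batch))
    return [x + 1 for x in batch]


result = run_pipeline(list(range(10)), lambda x: x * 2, toy_model, sorted, batch_size=4)
print(result)  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
print(calls)  # [4, 4, 2]
```

With 10 patches and batch_size=4, the model runs three times, and the last batch holds the 2 leftover patches.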
- Parameters:
images (list[PathLike | WSIReader] | np.ndarray) – Input images or patches. In patch mode, the input must be a list of images, a list of image file paths, or a NumPy array of images.
masks (list[PathLike] | np.ndarray | None) – Optional masks for WSI processing, used only when patch_mode is False. Patches are generated only within the masked area. If not provided, a tissue mask is generated automatically for each whole slide image.
input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are level, power, and mpp. Keys should be “units” and “resolution”, e.g., [{“units”: “mpp”, “resolution”: 0.25}]. See WSIReader for details.
patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.
ioconfig (IOPatchPredictorConfig | None) – IO configuration for patch extraction and resolution.
patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False).
save_dir (PathLike | None) – Directory to save output files. Required for WSI mode.
overwrite (bool) – Whether to overwrite existing output files. Default is False.
output_type (str) – Desired output format: “dict”, “zarr”, “qupath”, or “annotationstore”. Default is “dict”.
**kwargs (PredictorRunParams) –
Additional runtime parameters to configure prediction.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities in the output. If False, only predicted labels are returned.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
self (PatchPredictor)
- Returns:
If patch_mode is True: returns the predictions, or the path to the saved output.
If patch_mode is False: returns a dictionary mapping each WSI to its output path.
- Return type:
AnnotationStore | Path | str | dict
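The class_dict keyword maps classification outputs to class names. A minimal sketch of that mapping, assuming integer class indices and a made-up three-class dictionary, is shown below; label_predictions is hypothetical and leaves indices without an entry unchanged.

```python
from __future__ import annotations


def label_predictions(predictions, class_dict):
    """Map integer class indices to names, keeping unknown indices as-is."""
    return [class_dict.get(p, p) for p in predictions]


class_dict = {0: "background", 1: "tumor", 2: "stroma"}  # hypothetical mapping
print(label_predictions([1, 0, 2, 1], class_dict))
# ['tumor', 'background', 'stroma', 'tumor']
```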
Examples
>>> from tiatoolbox.models.engine.patch_predictor import PatchPredictor
>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> image_patches = [patch1, patch2]  # two np.ndarray image patches
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="annotationstore")
>>> output
"/path/to/Output.db"
>>> output = predictor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="zarr")
>>> output
"/path/to/Output.zarr"
>>> output = predictor.run(wsis, patch_mode=False, save_dir="/path/to/output")
>>> list(output.keys())
['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
'/path/to/wsi1.db'