EngineABC

class EngineABC(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=False)[source]

Abstract base class for TIAToolbox deep learning engines to run CNN models.

This class provides a unified interface for running inference on image patches or whole slide images (WSIs), handling preprocessing, batching, postprocessing, and saving predictions.

Parameters:
  • model (str | ModelABC) – Model name from TIAToolbox or a PyTorch model instance. Pretrained models from the TIAToolbox model zoo can be requested by name; see the documentation for the list of available pretrained models. By default, the corresponding pretrained weights are downloaded automatically, but you can override them with your own set of weights using the weights parameter.

  • batch_size (int) – Number of patches per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) –

    Path to model weights. If None, default weights are used.

    >>> engine = EngineABC(
    ...    model="pretrained-model",
    ...    weights="/path/to/pretrained-local-weights.pth"
    ... )
    

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Please see https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details on input parameters for device. Default is “cpu”.

  • verbose (bool) – Enable verbose logging. Default is False.

images

Input images or patches. A list of image patches in NHWC format as a numpy array or a list of str/paths to WSIs.

Type:

list[str | Path] | np.ndarray

masks

Optional masks for WSIs. A list of tissue or binary masks corresponding to the processing area of the input images, given either as numpy arrays or as paths to saved mask images. Masks are only used when patch_mode is False; patches are generated only within the masked area. If not provided and auto_get_mask is True, a tissue mask is generated automatically for each whole slide image.

Type:

list[str | Path] | np.ndarray

patch_mode

Whether the input is treated as patches. TIAToolbox treats an input image as a patch if its HWC matches the HWC expected by the model. If it does not match, patch_mode must be set to False, in which case the engine treats the inputs as WSIs and extracts patches from them. Default is True.

Type:

bool
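
The HWC rule above can be sketched as follows; matches_model_hwc is a hypothetical helper (not part of TIAToolbox) that illustrates how an input's shape determines the appropriate patch_mode:

```python
import numpy as np

# Hypothetical helper: decide whether an NHWC array can be treated as a
# batch of ready-made patches for a model expecting a given HWC shape.
def matches_model_hwc(images: np.ndarray, model_hwc: tuple) -> bool:
    # Everything after the batch axis must equal the model's HWC.
    return images.ndim == 4 and tuple(images.shape[1:]) == tuple(model_hwc)

batch = np.zeros((8, 224, 224, 3), dtype=np.uint8)
print(matches_model_hwc(batch, (224, 224, 3)))  # True  -> patch_mode=True
print(matches_model_hwc(batch, (512, 512, 3)))  # False -> patch_mode=False
```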

model

Loaded PyTorch model. For a full list of pretrained models, refer to the docs. By default, the corresponding pretrained weights are also downloaded, but you can override them with your own set of weights via the weights argument. The model name argument is case-insensitive.

Type:

ModelABC

ioconfig

IO configuration (ModelIOConfigABC) for model input/output.

Type:

ModelIOConfigABC

dataloader

Torch DataLoader for inference.

Type:

DataLoader | None

return_labels

Whether to return labels with probabilities and predictions.

Type:

bool

input_resolutions

Resolution settings for input heads. Supported units are level, power and mpp. Keys should be “units” and “resolution” e.g., [{“units”: “mpp”, “resolution”: 0.25}]. Please see WSIReader for details.

Type:

list[dict[Units, Resolution]] | None
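
For illustration, a minimal input_resolutions value matching the description above (one entry per model input head):

```python
# One dict per input head, with keys "units" and "resolution".
input_resolutions = [{"units": "mpp", "resolution": 0.25}]
```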

patch_input_shape

Shape of input patches. Patches are at requested read resolution, not with respect to level 0, and must be positive.

Type:

IntPair | None

stride_shape

Stride used during WSI processing. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape=patch_input_shape.

Type:

IntPair | None
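
The default relationship between stride_shape and patch_input_shape can be illustrated with a small sketch (plain Python, not TIAToolbox code):

```python
# With the default stride equal to the patch shape, patches tile the
# image without overlap; a smaller stride would produce overlapping reads.
patch_input_shape = (224, 224)
stride_shape = patch_input_shape  # default when not provided
overlap = tuple(p - s for p, s in zip(patch_input_shape, stride_shape))
print(overlap)  # (0, 0): no overlap
```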

batch_size

Number of patches per forward pass.

Type:

int

labels

Optional labels for input images. Only a single label per image is supported.

Type:

list | None

num_workers

Number of workers for data loading.

Type:

int

verbose

Whether to enable verbose logging.

Type:

bool

drop_keys

Keys to exclude from model output.

Type:

list

output_type

Format of output (“dict”, “zarr”, “qupath”, or “annotationstore”).

Type:

Any

Example

>>> # Inherit from EngineABC and implement the abstract methods
>>> class TestEngineABC(EngineABC):
...     def __init__(self, model, weights=None, verbose=False):
...         super().__init__(model=model, weights=weights, verbose=verbose)
>>> # array of 2 image patches as input
>>> data = np.random.randint(0, 255, size=(2, 224, 224, 3), dtype=np.uint8)
>>> engine = TestEngineABC(model="resnet18-kather100k")
>>> output = engine.run(data, patch_mode=True)
>>> # list of 2 image files as input
>>> image = ['path/image1.png', 'path/image2.png']
>>> engine = TestEngineABC(model="resnet18-kather100k")
>>> output = engine.run(image, patch_mode=False)
>>> # list of 2 WSI files as input
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
>>> engine = TestEngineABC(model="resnet18-kather100k")
>>> output = engine.run(wsi_file, patch_mode=False)

Initialize the EngineABC instance.

Parameters:
  • model (str | ModelABC) – Model name from TIAToolbox or a PyTorch model instance.

  • batch_size (int) – Number of patches per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) – Path to model weights. If None, default weights are used.

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.

  • verbose (bool) – Enable verbose logging. Default is False.

Methods

get_dataloader

Pre-process images and masks and return a DataLoader for inference.

infer_patches

Run model inference on image patches and return predictions.

infer_wsi

Run model inference on a whole slide image (WSI).

post_process_patches

Post-process raw patch predictions from inference.

post_process_wsi

Post-process predictions from whole slide image (WSI) inference.

run

Run the engine on input images.

save_predictions

Save model predictions to disk or return them in memory.

save_predictions_as_zarr

Save model predictions as a zarr file.

get_dataloader(images, masks=None, labels=None, ioconfig=None, *, patch_mode=True, auto_get_mask=True)[source]

Pre-process images and masks and return a DataLoader for inference.

Parameters:
  • images (list[str | Path] | np.ndarray) – A list of image patches in NHWC format as a numpy array, or a list of file paths to WSIs. When patch_mode is False, expects file paths to WSIs.

  • masks (list[str | Path] | np.ndarray | None) – Optional list of masks used when patch_mode is False. Patches are generated only within masked areas. If not provided, tissue masks are automatically generated.

  • labels (list | None) – Optional list of labels. Only one label per image is supported.

  • ioconfig (ModelIOConfigABC | None) – IO configuration object specifying patch size, stride, and resolution.

  • patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False).

  • auto_get_mask (bool) – Whether to automatically generate a tissue mask using wsireader.tissue_mask() when patch_mode is False. If True, only tissue regions are processed. If False, all patches are processed. Default is True.

Returns:

A PyTorch DataLoader configured for inference.

Return type:

torch.utils.data.DataLoader
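
Conceptually, the returned DataLoader yields fixed-size batches of patches; a minimal pure-Python sketch of that batching behaviour (not the TIAToolbox implementation):

```python
# Illustrative only: a DataLoader-like generator that yields batches of
# at most batch_size items, with a possibly smaller final batch.
def batched(items, batch_size):
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

patches = list(range(10))
print([len(b) for b in batched(patches, 4)])  # [4, 4, 2]
```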

infer_patches(dataloader, *, return_coordinates=False)[source]

Run model inference on image patches and return predictions.

This method performs batched inference using a PyTorch DataLoader, and accumulates predictions in Dask arrays. It supports optional inclusion of coordinates and labels in the output.

Parameters:
  • dataloader (DataLoader) – PyTorch DataLoader containing image patches for inference.

  • return_coordinates (bool) – Whether to include coordinates in the output. Required when called by infer_wsi and patch_mode is False.

Returns:

Dictionary containing prediction results as Dask arrays. Keys include:

  • ”probabilities”: Model output probabilities.

  • ”labels”: Ground truth labels (if return_labels is True).

  • ”coordinates”: Patch coordinates (if return_coordinates is True).

Return type:

dict[str, dask.array.Array]
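
The accumulation described above can be sketched in plain NumPy (rather than Dask); fake_model is a stand-in for the real forward pass, not a TIAToolbox function:

```python
import numpy as np

# Run the "model" batch by batch and stack the outputs under a
# "probabilities" key, mirroring the output dict described above.
def fake_model(batch):
    # Stand-in forward pass: uniform probabilities over 3 classes.
    logits = np.ones((len(batch), 3))
    return logits / logits.sum(axis=1, keepdims=True)

batches = [np.zeros((4, 224, 224, 3)), np.zeros((2, 224, 224, 3))]
probs = np.concatenate([fake_model(b) for b in batches])
output = {"probabilities": probs}
print(output["probabilities"].shape)  # (6, 3)
```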

infer_wsi(dataloader, save_path, **kwargs)[source]

Run model inference on a whole slide image (WSI).

This method performs inference on a WSI using the provided DataLoader, and accumulates predictions in Dask arrays. Optionally includes coordinates and labels in the output.

Parameters:
  • dataloader (DataLoader) – PyTorch DataLoader configured for WSI processing.

  • save_path (Path) – Path to save the intermediate output. The intermediate output is saved in a zarr file.

  • **kwargs (EngineABCRunParams) –

    Additional runtime parameters to update engine attributes.

    Optional keys:

      • auto_get_mask (bool) – Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
      • batch_size (int) – Number of image patches per forward pass.
      • class_dict (dict) – Mapping of classification outputs to class names.
      • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details.
      • labels (list) – Optional labels for input images. Only a single label per image is supported.
      • memory_threshold (int) – Memory usage threshold (percentage) to trigger caching behavior.
      • num_workers (int) – Number of workers for DataLoader and post-processing.
      • output_file (str) – Filename for saving output (e.g., “zarr” or “annotationstore”).
      • return_labels (bool) – Whether to return labels with predictions.
      • scale_factor (tuple[float, float]) – Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline resolution.
      • stride_shape (IntPair) – Stride used during WSI processing, at requested read resolution. Must be positive. Defaults to patch_input_shape if not provided.
      • verbose (bool) – Whether to enable verbose logging.

Returns:

Dictionary containing prediction results as Dask arrays.

Return type:

dict
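
A sketch of the output structure described above, pairing predictions with the patch coordinates they came from; the (x_start, y_start, x_end, y_end) coordinate layout is an assumption for illustration only:

```python
import numpy as np

# Each prediction row corresponds to one coordinate box (layout assumed
# here as x_start, y_start, x_end, y_end for illustration).
coords = np.array([[0, 0, 224, 224], [224, 0, 448, 224]])
probs = np.array([[0.9, 0.1], [0.2, 0.8]])
output = {"probabilities": probs, "coordinates": coords}
print(len(output["coordinates"]) == len(output["probabilities"]))  # True
```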

post_process_patches(raw_predictions, **kwargs)[source]

Post-process raw patch predictions from inference.

This method applies a post-processing function (e.g., smoothing, filtering) to the raw model predictions. It supports delayed execution using Dask and returns a Dask array for efficient computation.

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • **kwargs (EngineABCRunParams) –

    Additional runtime parameters to update engine attributes.

    Optional keys:

      • auto_get_mask (bool) – Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
      • batch_size (int) – Number of image patches per forward pass.
      • class_dict (dict) – Mapping of classification outputs to class names.
      • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details.
      • labels (list) – Optional labels for input images. Only a single label per image is supported.
      • memory_threshold (int) – Memory usage threshold (percentage) to trigger caching behavior.
      • num_workers (int) – Number of workers for DataLoader and post-processing.
      • output_file (str) – Filename for saving output (e.g., “zarr” or “annotationstore”).
      • return_labels (bool) – Whether to return labels with predictions.
      • scale_factor (tuple[float, float]) – Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline resolution.
      • stride_shape (IntPair) – Stride used during WSI processing, at requested read resolution. Must be positive. Defaults to patch_input_shape if not provided.
      • verbose (bool) – Whether to enable verbose logging.

Returns:

Post-processed predictions as a dictionary of Dask arrays.

Return type:

dict[str, da.Array]
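
A minimal sketch of such a post-processing step, in plain NumPy rather than Dask, using argmax as an illustrative choice of post-processing function (not necessarily what TIAToolbox applies):

```python
import numpy as np

# Turn raw probabilities into hard class predictions while keeping the
# dict-of-arrays structure described above.
raw = {"probabilities": np.array([[0.9, 0.1], [0.2, 0.8]])}
processed = dict(raw)
processed["predictions"] = raw["probabilities"].argmax(axis=1)
print(processed["predictions"].tolist())  # [0, 1]
```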

post_process_wsi(raw_predictions, save_path, **kwargs)[source]

Post-process predictions from whole slide image (WSI) inference.

This method applies a post-processing function (e.g., smoothing, filtering) to the raw model predictions. It supports delayed execution using Dask and returns a Dask array for efficient computation.

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • save_path (Path) – Path to save the intermediate output. The intermediate output is saved in a zarr file.

  • **kwargs (EngineABCRunParams) –

    Additional runtime parameters to update engine attributes.

    Optional keys:

      • auto_get_mask (bool) – Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
      • batch_size (int) – Number of image patches per forward pass.
      • class_dict (dict) – Mapping of classification outputs to class names.
      • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details.
      • labels (list) – Optional labels for input images. Only a single label per image is supported.
      • memory_threshold (int) – Memory usage threshold (percentage) to trigger caching behavior.
      • num_workers (int) – Number of workers for DataLoader and post-processing.
      • output_file (str) – Filename for saving output (e.g., “zarr” or “annotationstore”).
      • return_labels (bool) – Whether to return labels with predictions.
      • scale_factor (tuple[float, float]) – Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline resolution.
      • stride_shape (IntPair) – Stride used during WSI processing, at requested read resolution. Must be positive. Defaults to patch_input_shape if not provided.
      • verbose (bool) – Whether to enable verbose logging.

Returns:

Post-processed predictions as a dictionary of Dask arrays.

Return type:

dict[str, da.Array]

run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)[source]

Run the engine on input images.

This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch and WSI modes.

Parameters:
  • images (list[PathLike | Path | WSIReader] | np.ndarray) – List of input images or a NumPy array of patches.

  • masks (list[PathLike | Path] | np.ndarray | None) – Optional list of masks for WSI processing. Only utilised when patch_mode is False. Patches are only generated within a masked area. If not provided, then a tissue mask will be automatically generated for whole slide images.

  • input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are level, power, and mpp. Keys should be “units” and “resolution”, e.g., [{“units”: “mpp”, “resolution”: 0.25}]. See WSIReader for details.

  • patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.

  • ioconfig (ModelIOConfigABC | None) – IO configuration for patch extraction and resolution settings.

  • patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False). Default is True.

  • save_dir (PathLike | Path | None) – Directory to save output files. Required for WSI mode.

  • overwrite (bool) – Whether to overwrite existing output files. Default is False.

  • output_type (str) – Desired output format: “dict”, “zarr”, “qupath”, or “annotationstore”.

  • **kwargs (EngineABCRunParams) –

    Additional runtime parameters to update engine attributes.

    Optional keys:

      • auto_get_mask (bool) – Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
      • batch_size (int) – Number of image patches per forward pass.
      • class_dict (dict) – Mapping of classification outputs to class names.
      • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details.
      • labels (list) – Optional labels for input images. Only a single label per image is supported.
      • memory_threshold (int) – Memory usage threshold (percentage) to trigger caching behavior.
      • num_workers (int) – Number of workers for DataLoader and post-processing.
      • output_file (str) – Filename for saving output (e.g., “zarr” or “annotationstore”).
      • return_labels (bool) – Whether to return labels with predictions.
      • scale_factor (tuple[float, float]) – Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline resolution.
      • stride_shape (IntPair) – Stride used during WSI processing, at requested read resolution. Must be positive. Defaults to patch_input_shape if not provided.
      • verbose (bool) – Whether to enable verbose logging.

Returns:

  • If patch_mode is True: returns predictions or path to saved output.

  • If patch_mode is False: returns a dictionary mapping each WSI to its output path.

Return type:

AnnotationStore | Path | str | dict

Examples

>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> class PatchPredictor(EngineABC):
...     ...  # implement all the abstract methods here
>>> predictor = PatchPredictor(model="resnet18-kather100k")
>>> output = predictor.run(image_patches, patch_mode=True)
>>> output = predictor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="zarr")
>>> output
'/path/to/output.zarr'
>>> output = predictor.run(wsis, patch_mode=False)
>>> output.keys()
dict_keys(['wsi1.svs', 'wsi2.svs'])
>>> output['wsi1.svs']
'/path/to/wsi1.db'

save_predictions(processed_predictions, output_type, save_path=None, **kwargs)[source]

Save model predictions to disk or return them in memory.

Depending on the output type, this method saves predictions as a zarr group, an AnnotationStore (SQLite database), or returns them as a dictionary.

Parameters:
  • processed_predictions (dict) – Dictionary containing processed model predictions.

  • output_type (str) – Desired output format. Supported values are “dict”, “zarr”, “qupath” and “annotationstore”.

  • save_path (Path | None) – Path to save the output file. Required for “zarr” and “annotationstore” formats.

  • **kwargs (EngineABCRunParams) –

    Additional runtime parameters to update engine attributes.

    Optional keys:

      • auto_get_mask (bool) – Automatically generate segmentation masks using wsireader.tissue_mask() during processing.
      • batch_size (int) – Number of image patches per forward pass.
      • class_dict (dict) – Mapping of classification outputs to class names.
      • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details.
      • labels (list) – Optional labels for input images. Only a single label per image is supported.
      • memory_threshold (int) – Memory usage threshold (percentage) to trigger caching behavior.
      • num_workers (int) – Number of workers for DataLoader and post-processing.
      • output_file (str) – Filename for saving output (e.g., “zarr” or “annotationstore”).
      • return_labels (bool) – Whether to return labels with predictions.
      • scale_factor (tuple[float, float]) – Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates from non-baseline to baseline resolution.
      • stride_shape (IntPair) – Stride used during WSI processing, at requested read resolution. Must be positive. Defaults to patch_input_shape if not provided.
      • verbose (bool) – Whether to enable verbose logging.

Returns:

  • If output_type is “dict”: returns predictions as a dictionary.

  • If output_type is “zarr”: returns path to saved zarr file.

  • If output_type is “qupath”: returns a QuPath JSON or path to .json file.

  • If output_type is “annotationstore”: returns an AnnotationStore or path to .db file.

Return type:

dict | AnnotationStore | Path

Raises:

TypeError – If an unsupported output_type is provided.
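
The dispatch and error behaviour described here can be sketched with a hypothetical helper (not the TIAToolbox code):

```python
# Illustrative dispatch on output_type: unsupported formats raise
# TypeError, mirroring the behaviour documented above.
def dispatch(output_type: str) -> str:
    supported = {"dict", "zarr", "qupath", "annotationstore"}
    if output_type.lower() not in supported:
        raise TypeError(f"Unsupported output_type: {output_type}")
    return output_type.lower()

print(dispatch("zarr"))  # zarr
```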

save_predictions_as_zarr(processed_predictions, save_path, keys_to_compute, task_name=None)[source]

Save model predictions as a zarr file.

This method saves the processed predictions to a zarr file at the specified path.

Parameters:
  • processed_predictions (dict) – Dictionary containing processed model predictions.

  • save_path (Path) – Path to save the zarr file.

  • keys_to_compute (list) – List of keys in processed_predictions to save.

  • task_name (str | None) – Task name for multitask outputs.

Returns:

Path to the saved zarr file.

Return type:

Path