NucleusDetector

class NucleusDetector(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=True)

Nucleus detection engine for digital histology images.

This class extends SemanticSegmentor to support instance-level nucleus detection using either pretrained TIAToolbox models or custom models. It operates in both patch-level and whole slide image (WSI) modes and provides utilities for post-processing (e.g., centroid extraction, thresholding, tile-overlap handling), merging predictions, and saving results in multiple output formats. Supported TIAToolbox models include nucleus-detection architectures such as mapde-conic and mapde-crchisto. For the full list of pretrained models, refer to the model zoo documentation: https://tia-toolbox.readthedocs.io/en/latest/pretrained.html

The class follows the standard TIAToolbox engine interface: it inherits data loading, inference orchestration, memory-aware chunking, and output-saving conventions from SemanticSegmentor, and overrides only the nucleus-specific post-processing and export routines.

Parameters:
  • model (str or nn.Module) – A defined PyTorch model, or the name of an existing model supported by TIAToolbox for processing the data, e.g., mapde-conic or mapde-crchisto. For a full list of pretrained models, please refer to the docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>. By default, the corresponding pretrained weights are also downloaded; you can override them with your own weights via the weights argument. Model names are case insensitive.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers for torch.utils.data.DataLoader. Default is 0.

  • weights (str or pathlib.Path or None) – Optional path to pretrained weights. If None and model is a string, default pretrained weights for that model will be used. If model is an nn.Module, weights are loaded only if provided.

  • device (str) – Device on which the model will run (e.g., "cpu", "cuda"). Default is "cpu".

  • verbose (bool) – Whether to output logging information. Default is True.
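The weight-selection rule described for model and weights can be condensed into a small decision function. The sketch below is hypothetical (resolve_weights is not part of TIAToolbox) and only reports which weights would be used; it does not load or download anything.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_weights(model: object, weights: Optional[Union[str, Path]] = None) -> str:
    """Describe which weights would be used (hypothetical helper, not TIAToolbox API).

    Follows the documented rule: explicit weights always win; a model given by
    name falls back to its default pretrained weights; a custom nn.Module keeps
    whatever weights it already has when no weights are supplied.
    """
    if weights is not None:
        return f"load user weights from {Path(weights)}"
    if isinstance(model, str):
        # Model names are case insensitive, so normalise before any lookup.
        return f"download default pretrained weights for {model.lower()}"
    return "keep the weights already in the custom module"
```

For example, resolve_weights("MapDe-CoNIC") reports the default pretrained weights for mapde-conic, while passing a weights path always takes precedence.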

images

Input images supplied to the engine, either as WSI paths or NHWC-formatted patches.

Type:

list[str or Path] or np.ndarray

masks

Optional tissue masks for WSI processing. Only used when patch_mode=False.

Type:

list[str or Path] or np.ndarray

patch_mode

Whether input is treated as image patches (True) or as WSIs (False).

Type:

bool

model

Loaded PyTorch model. Can be a pretrained TIAToolbox model or a custom user-provided model.

Type:

ModelABC

ioconfig

IO configuration specifying patch extraction shape, stride, and resolution settings for inference.

Type:

ModelIOConfigABC

return_labels

Whether to include labels in the output, if provided.

Type:

bool

input_resolutions

Resolution settings for model input heads. Supported units are "level", "power", and "mpp".

Type:

list[dict]

patch_input_shape

Height and width of input patches read from slides, expressed in read resolution space.

Type:

tuple[int, int]

stride_shape

Stride used during patch extraction. Defaults to patch_input_shape.

Type:

tuple[int, int]
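To illustrate how patch_input_shape and stride_shape interact, the sketch below enumerates the top-left corners of patches tiled over an image region, with the stride defaulting to the patch shape (non-overlapping tiles) as documented. The helpers axis_starts and patch_grid are hypothetical; TIAToolbox's own patch extraction also handles resolution conversion, masks and padding, which are omitted here.

```python
def axis_starts(length, patch, stride):
    """Start offsets along one axis (hypothetical helper)."""
    starts = list(range(0, max(length - patch, 0) + 1, stride))
    # If the last patch stops short of the edge, add one more patch that
    # overhangs the edge (it would need padding), unless it would start past it.
    # A stride larger than the patch leaves gaps between patches.
    if starts[-1] + patch < length and starts[-1] + stride < length:
        starts.append(starts[-1] + stride)
    return starts


def patch_grid(image_shape, patch_input_shape, stride_shape=None):
    """Top-left (row, col) corners of patches tiled over image_shape.

    All shapes are (height, width) in read-resolution pixels. stride_shape
    defaults to patch_input_shape, which gives non-overlapping tiles.
    """
    if stride_shape is None:
        stride_shape = patch_input_shape
    if min(*patch_input_shape, *stride_shape) <= 0:
        raise ValueError("patch and stride shapes must be positive")
    rows = axis_starts(image_shape[0], patch_input_shape[0], stride_shape[0])
    cols = axis_starts(image_shape[1], patch_input_shape[1], stride_shape[1])
    return [(row, col) for row in rows for col in cols]
```

Halving the stride relative to the patch shape roughly quadruples the number of patches in 2-D, which is the usual trade-off between overlap (more context at patch borders) and inference cost.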

drop_keys

Keys to exclude from model output when saving results.

Type:

list

output_type

Output format ("dict", "zarr", "qupath", or "annotationstore").

Type:

str

Examples

>>> import pathlib
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
>>> model_name = "mapde-conic"
>>> detector = NucleusDetector(model=model_name, batch_size=16, num_workers=8)
>>> detector.run(
...     images=[pathlib.Path("example_wsi.tiff")],
...     patch_mode=False,
...     device="cuda",
...     save_dir=pathlib.Path("output_directory/"),
...     overwrite=True,
...     output_type="annotationstore",
...     class_dict={0: "nucleus"},
...     auto_get_mask=True,
...     memory_threshold=80,
... )

Initialize NucleusDetector.

This constructor follows the standard TIAToolbox engine initialization workflow. A model may be provided either as a string referring to a pretrained TIAToolbox architecture or as a custom torch.nn.Module. When model is a string, the corresponding pretrained weights are automatically downloaded unless explicitly overridden via weights.

Parameters:
  • model (str or ModelABC) – A PyTorch model instance or the name of a pretrained TIAToolbox model. If a string is provided, default pretrained weights are loaded unless weights is supplied to override them.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers used for torch.utils.data.DataLoader. Default is 0.

  • weights (str or Path or None) – Path to model weights. If None and model is a string, the default pretrained weights for that model will be used. If model is an nn.Module, weights are loaded only when specified here.

  • device (str) – Device on which the model will run (e.g., "cpu", "cuda"). Default is "cpu".

  • verbose (bool) – Whether to enable verbose logging during initialization and inference. Default is True.

Methods

post_process_patches

Post-process patch-level detection outputs.

post_process_wsi

Post-process WSI-level nucleus detection outputs.

run

Run the nucleus detection engine on input images.

save_predictions

Save nucleus detections to disk or return them in memory.

post_process_patches(raw_predictions, **kwargs)

Post-process patch-level detection outputs.

Applies the model’s post-processing function (e.g., centroid extraction and thresholding) to each patch’s probability map, yielding per-patch detection arrays suitable for saving or further merging.
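A simplified, NumPy-only sketch of centroid extraction controlled by min_distance, threshold_abs and threshold_rel follows. It assumes semantics loosely modelled on scikit-image's skimage.feature.peak_local_max: a pixel is a peak if it equals the maximum of its (2 * min_distance + 1)-sized window and is strictly above max(threshold_abs, threshold_rel * max of the map). The post-processing function of a given TIAToolbox model may differ, and extract_centroids is a hypothetical name.

```python
import numpy as np


def extract_centroids(prob_map, min_distance=2, threshold_abs=0.0, threshold_rel=None):
    """Return (x, y, probabilities) of local maxima in a 2-D probability map.

    Hypothetical sketch: plateaus (tied neighbouring values) can yield several
    adjacent peaks here; a production implementation would suppress them.
    """
    prob_map = np.asarray(prob_map, dtype=float)
    threshold = threshold_abs
    if threshold_rel is not None:
        # Relative threshold: a fraction of the strongest response in the map.
        threshold = max(threshold, threshold_rel * prob_map.max())

    r = min_distance
    # Pad with -inf so border pixels are compared only with real neighbours.
    padded = np.pad(prob_map, r, mode="constant", constant_values=-np.inf)
    height, width = prob_map.shape
    local_max = np.full(prob_map.shape, -np.inf)
    for dy in range(2 * r + 1):
        for dx in range(2 * r + 1):
            local_max = np.maximum(local_max, padded[dy:dy + height, dx:dx + width])

    is_peak = (prob_map == local_max) & (prob_map > threshold)
    ys, xs = np.nonzero(is_peak)  # rows are y, columns are x
    return xs, ys, prob_map[ys, xs]
```

Raising threshold_rel removes weaker responses relative to the brightest nucleus in the patch, whereas threshold_abs applies the same cut-off to every patch.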

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • **kwargs (NucleusDetectorRunParams) –

    Additional runtime parameters to configure detection.

    Optional Keys:
    min_distance (int):

    Minimum separation between nuclei (in pixels) used during centroid extraction/post-processing.

    threshold_abs (float):

    Absolute detection threshold applied to model outputs.

    threshold_rel (float):

    Relative detection threshold (e.g., with respect to local maxima).

Returns:

A dictionary with the following keys, each mapping to a list with one entry per patch:
  • "x" (list[dask array]):

    1-D object dask arrays of x coordinates

  • "y" (list[dask array]):

    1-D object dask arrays of y coordinates

  • "classes" (list[dask array]):

    1-D object dask arrays of class IDs

  • "probabilities" (list[dask array]):

    1-D object dask arrays of detection probabilities

Return type:

dict[str, list[da.Array]]
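Because each key holds one entry per patch, downstream code often needs to flatten the result into one row per detection. The sketch below does this with plain Python lists standing in for the Dask arrays; with real outputs, call .compute() on each array first. flatten_patch_detections is a hypothetical helper.

```python
def flatten_patch_detections(detections):
    """Flatten per-patch detections into (patch_index, x, y, class_id, probability) rows.

    Hypothetical helper. Each key maps to a list with one sequence per patch;
    within a patch, the four sequences must be aligned element by element.
    """
    keys = ("x", "y", "classes", "probabilities")
    num_patches = len(detections["x"])
    if any(len(detections[key]) != num_patches for key in keys):
        raise ValueError("every key must hold one entry per patch")
    rows = []
    for patch_index in range(num_patches):
        columns = [list(detections[key][patch_index]) for key in keys]
        if len({len(column) for column in columns}) > 1:
            raise ValueError(f"misaligned detections in patch {patch_index}")
        for x, y, class_id, prob in zip(*columns):
            rows.append((patch_index, x, y, class_id, prob))
    return rows
```

Patches without detections simply contribute no rows, so the patch index is kept to map each detection back to its source patch.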

Notes

  • If thresholds are not provided via kwargs, model defaults are used.

post_process_wsi(raw_predictions, save_path, **kwargs)

Post-process WSI-level nucleus detection outputs.

Processes the full-slide prediction map with Dask’s block-wise operations to extract nucleus centroids across the entire WSI. The prediction map is first rechunked to the model’s preferred post-processing tile shape, and dask.array.map_overlap is applied with halo padding so that centroid extraction in each tile can see neighbouring pixels across tile boundaries. The resulting centroid maps are computed and saved to Zarr storage for memory-efficient processing, then converted into detection arrays (x, y, classes, probabilities) by processing the stored blocks sequentially.
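The halo idea behind dask.array.map_overlap can be illustrated without Dask: each tile is processed together with a margin of halo pixels from its neighbours, and only peaks whose coordinates fall inside the tile's own core are kept. In this sketch a simple 3x3 local-maximum test stands in for the model's post-processing, and local_peaks and tiled_peaks are hypothetical helpers. With halo of at least 1 (the radius of the test window), the tiled result matches running the test on the whole map at once; with no halo, a pixel on a tile edge cannot see its stronger neighbour and a spurious peak appears.

```python
import numpy as np


def local_peaks(arr, threshold):
    """Pixels that are the maximum of their 3x3 window and above threshold."""
    arr = np.asarray(arr, dtype=float)
    padded = np.pad(arr, 1, mode="constant", constant_values=-np.inf)
    h, w = arr.shape
    window_max = np.max(
        [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)], axis=0
    )
    return (arr == window_max) & (arr > threshold)


def tiled_peaks(prob_map, tile, halo, threshold):
    """Detect peaks tile by tile with a halo; returns a sorted list of (y, x)."""
    h, w = prob_map.shape
    found = []
    for y0 in range(0, h, tile):
        for x0 in range(0, w, tile):
            y1, x1 = min(y0 + tile, h), min(x0 + tile, w)
            # Enlarge the tile by the halo, clipped to the image bounds.
            ya, xa = max(y0 - halo, 0), max(x0 - halo, 0)
            yb, xb = min(y1 + halo, h), min(x1 + halo, w)
            mask = local_peaks(prob_map[ya:yb, xa:xb], threshold)
            for py, px in zip(*np.nonzero(mask)):
                gy, gx = int(py) + ya, int(px) + xa  # back to global coordinates
                # Keep only peaks inside this tile's core so each is reported once.
                if y0 <= gy < y1 and x0 <= gx < x1:
                    found.append((gy, gx))
    return sorted(found)
```

Keeping only core-region peaks plays the role of trimming the halo after the block-wise computation, which is what prevents a nucleus near a boundary from being reported by two tiles.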

Parameters:
  • raw_predictions (dict[str, da.Array]) – Dictionary containing raw model predictions as Dask arrays.

  • save_path (Path) – Path at which to save the intermediate output, which is stored as a Zarr file.

  • **kwargs (NucleusDetectorRunParams) –

    Additional runtime parameters to configure detection.

    Optional Keys:
    min_distance (int):

    Minimum distance separating two nuclei (in pixels).

    threshold_abs (float):

    Absolute detection threshold applied to model outputs.

    threshold_rel (float):

    Relative detection threshold (e.g., with respect to local maxima).

    postproc_tile_shape (tuple[int, int]):

    Tile shape (height, width) for post-processing rechunking.

Returns:

A dictionary mapping detection fields to 1-D Dask arrays:
  • "x": x coordinates of detected nuclei.

  • "y": y coordinates of detected nuclei.

  • "classes": class IDs.

  • "probabilities": detection probabilities.

Return type:

dict[str, da.Array]

Notes

  • Halo padding ensures that nuclei crossing tile/chunk boundaries are not fragmented or duplicated.

  • If thresholds are not explicitly provided, model defaults are used.

  • Centroid maps are computed and saved to Zarr storage to avoid out-of-memory errors on large WSIs.

  • The Zarr-backed centroid maps are then processed block-by-block to extract detections incrementally.
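The block-by-block conversion described in these notes can be sketched as follows, assuming for illustration only that the stored centroid map is a 2-D array holding a detection probability at each centroid pixel and 0 elsewhere, with a single class. A NumPy array stands in for the Zarr-backed store (slicing a Zarr array also returns a NumPy array), and centroids_to_detections is a hypothetical helper.

```python
import numpy as np


def centroids_to_detections(centroid_map, block_shape, class_id=0):
    """Convert a sparse centroid map into 1-D detection arrays, one block at a time."""
    xs, ys, probs = [], [], []
    h, w = centroid_map.shape
    bh, bw = block_shape
    for y0 in range(0, h, bh):
        for x0 in range(0, w, bw):
            # Only this block needs to be in memory (e.g. one Zarr chunk).
            block = np.asarray(centroid_map[y0:y0 + bh, x0:x0 + bw])
            by, bx = np.nonzero(block)
            xs.append(bx + x0)  # block-local -> global coordinates
            ys.append(by + y0)
            probs.append(block[by, bx])
    x = np.concatenate(xs).astype(np.int64)
    y = np.concatenate(ys).astype(np.int64)
    p = np.concatenate(probs).astype(np.float64)
    return {"x": x, "y": y, "classes": np.full(x.shape, class_id), "probabilities": p}
```

Detections come out in block order rather than global raster order; sort by (y, x) afterwards if a deterministic ordering is needed.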

run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)

Run the nucleus detection engine on input images.

This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch-level and whole slide image (WSI) modes.

Parameters:
  • images (list[PathLike | WSIReader] | np.ndarray) – Input images or patches. Can be a list of file paths, WSIReader objects, or a NumPy array of image patches.

  • masks (list[PathLike] | np.ndarray | None) – Optional masks for WSI processing. Only used when patch_mode is False.

  • input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are "level", "power", and "mpp". Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. See WSIReader for details.

  • patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.

  • ioconfig (IOSegmentorConfig | None) – IO configuration for patch extraction and resolution.

  • patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False). Default is True.

  • save_dir (PathLike | None) – Directory to save output files. Required for WSI mode.

  • overwrite (bool) – Whether to overwrite existing output files. Default is False.

  • output_type (str) – Desired output format: "dict", "zarr", "qupath", or "annotationstore". Default is "dict".

  • **kwargs (NucleusDetectorRunParams) –

    Additional runtime parameters to configure detection.

    Optional Keys:
    auto_get_mask (bool):

    Whether to automatically generate tissue masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches to feed to the model in a forward pass.

    class_dict (dict):

    Optional dictionary mapping classification outputs to class names.

    device (str):

    Device to run the model on (e.g., "cpu", "cuda").

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (in percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers used in DataLoader.

    output_file (str):

    Output file name for saving results (e.g., .zarr or .db).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    min_distance (int):

    Minimum distance separating two nuclei (in pixels).

    threshold_abs (float):

    Absolute detection threshold applied to model outputs.

    threshold_rel (float):

    Relative detection threshold (e.g., with respect to local maxima).

    postproc_tile_shape (tuple[int, int]):

    Tile shape (height, width) for post-processing (in pixels).

    return_labels (bool):

    Whether to return labels with predictions.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for converting annotations to baseline resolution. Typically model_mpp / slide_mpp.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape.

    verbose (bool):

    Whether to output logging information.
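As documented, scale_factor is typically model_mpp / slide_mpp. A minimal sketch of applying it, assuming detections are in model-resolution pixel coordinates and that the factor is ordered (x, y): a model run at 0.5 mpp on a slide whose baseline is 0.25 mpp gives a factor of 2.0, so every coordinate doubles. to_baseline is a hypothetical helper.

```python
def to_baseline(xs, ys, model_mpp, slide_mpp):
    """Map model-resolution coordinates to slide baseline (hypothetical helper).

    model_mpp and slide_mpp are (x, y) microns-per-pixel pairs; the scale
    factor model_mpp / slide_mpp converts one model pixel into baseline pixels.
    """
    scale_factor = (model_mpp[0] / slide_mpp[0], model_mpp[1] / slide_mpp[1])
    return (
        [x * scale_factor[0] for x in xs],
        [y * scale_factor[1] for y in ys],
    )
```

For example, to_baseline([10, 25], [4, 7], (0.5, 0.5), (0.25, 0.25)) returns ([20.0, 50.0], [8.0, 14.0]).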

Returns:

  • If patch_mode is True: returns predictions or path to saved output.

  • If patch_mode is False: returns a dictionary mapping each WSI to its output path.

Return type:

AnnotationStore | Path | str | dict | list[Path]

Examples

>>> import pathlib
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
>>> detector = NucleusDetector(model="mapde-conic")
>>> # WSI workflow: save to AnnotationStore (.db)
>>> out = detector.run(
...     images=[pathlib.Path("example_wsi.tiff")],
...     patch_mode=False,
...     device="cuda",
...     save_dir=pathlib.Path("output_directory/"),
...     overwrite=True,
...     output_type="annotationstore",
...     class_dict={0: "nucleus"},
...     auto_get_mask=True,
...     memory_threshold=80,
... )
>>> # Patch workflow: return in-memory detections
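>>> import numpy as np
>>> # NHWC array of two RGB patches (size chosen for illustration only)
>>> patches = np.zeros((2, 252, 252, 3), dtype=np.uint8)
>>> out = detector.run(patches, patch_mode=True, output_type="dict")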

save_predictions(processed_predictions, output_type, save_path=None, **kwargs)

Save nucleus detections to disk or return them in memory.

Saves post-processed detection outputs in one of the supported formats. If patch_mode=True, predictions are saved per image. If patch_mode=False, detections are merged and saved as a single output.

Parameters:
  • processed_predictions (dict) –

    Dictionary containing processed detection results. Expected to include a "predictions" key with detection arrays. The internal structure follows TIAToolbox conventions and differs between patch and WSI modes: in patch mode each key holds one array per patch, while in WSI mode each key holds a single array for the whole slide.

    • Patch mode (patch_mode=True):
      • "x" (list[da.Array]):

        per-patch x coordinates.

      • "y" (list[da.Array]):

        per-patch y coordinates.

      • "classes" (list[da.Array]):

        per-patch class IDs.

      • "probabilities" (list[da.Array]):

        per-patch detection probabilities.

    • WSI mode (patch_mode=False):
      • "x" (da.Array):

        x coordinates.

      • "y" (da.Array):

        y coordinates.

      • "classes" (da.Array):

        class IDs.

      • "probabilities" (da.Array):

        detection probabilities.

  • output_type (str) – Desired output format: "dict", "zarr", "qupath" or "annotationstore".

  • save_path (Path | None) – Path at which to save the output file(s). Required for file outputs (e.g., Zarr or SQLite .db). If None and output_type="dict", results are returned in memory.

  • **kwargs (NucleusDetectorRunParams) –

    Additional runtime parameters to configure detection.

    Optional Keys:
    auto_get_mask (bool):

    Whether to automatically generate tissue masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches to feed to the model in a forward pass.

    class_dict (dict):

    Optional dictionary mapping classification outputs to class names.

    device (str):

    Device to run the model on (e.g., "cpu", "cuda").

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (in percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers used in DataLoader.

    output_file (str):

    Output file name for saving results (e.g., .zarr or .db).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    min_distance (int):

    Minimum distance separating two nuclei (in pixels).

    postproc_tile_shape (tuple[int, int]):

    Tile shape (height, width) for post-processing (in pixels).

    return_labels (bool):

    Whether to return labels with predictions.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for converting annotations to baseline resolution. Typically model_mpp / slide_mpp.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape.

    verbose (bool):

    Whether to output logging information.

Returns:

  • If output_type="dict":

    returns a Python dictionary of predictions.

  • If output_type="zarr":

    returns the path to the saved .zarr group.

  • If output_type="qupath":

    returns QuPath JSON or the path(s) to saved .json file(s). In patch mode, a list of per-image paths may be returned.

  • If output_type="annotationstore":

    returns an AnnotationStore handle or the path(s) to saved .db file(s). In patch mode, a list of per-image paths may be returned.

Return type:

dict | AnnotationStore | Path | list[Path]

Notes

  • For non-AnnotationStore outputs, this method delegates to the base engine’s saving function to preserve consistency across TIAToolbox engines.
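For readers who post-process the in-memory "dict" output themselves, the sketch below turns WSI-mode detection arrays into a GeoJSON FeatureCollection of points, using a class_dict to name the classes. This is a generic GeoJSON layout written for illustration: the to_geojson helper and the class_name and probability property names are assumptions, and this is not necessarily the format TIAToolbox writes for output_type="qupath".

```python
def to_geojson(detections, class_dict):
    """Build a GeoJSON FeatureCollection of Point features (hypothetical helper)."""
    features = []
    for x, y, cls, prob in zip(
        detections["x"], detections["y"], detections["classes"], detections["probabilities"]
    ):
        features.append({
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [float(x), float(y)]},
            "properties": {
                # Fall back to the numeric ID when the class is not in class_dict.
                "class_name": class_dict.get(int(cls), str(int(cls))),
                "probability": float(prob),
            },
        })
    return {"type": "FeatureCollection", "features": features}
```

The returned dictionary contains only built-in types, so it can be written to a .geojson file with json.dump.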