"""Nucleus Detection Engine for Digital Pathology (WSIs and patches).
This module implements the `NucleusDetector` class which extends
`SemanticSegmentor` to perform instance-level nucleus detection on
histology images. It supports patch-mode and whole slide image (WSI)
workflows using TIAToolbox or custom PyTorch models, and provides
utilities for parallel post-processing (centroid extraction, thresholding),
merging detections across patches, and exporting results in multiple
formats (in-memory dict, Zarr, QuPath JSON, AnnotationStore).
Classes
-------
NucleusDetectorRunParams
TypedDict specifying runtime configuration keys for detection.
NucleusDetector
Core engine for nucleus detection on image patches or WSIs.
Examples
--------
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
>>> detector = NucleusDetector(model="mapde-conic")
>>> # WSI workflow: save to AnnotationStore (.db)
>>> out = detector.run(
... images=[pathlib.Path("example_wsi.tiff")],
... patch_mode=False,
... device="cuda",
... save_dir=pathlib.Path("output_directory/"),
... overwrite=True,
... output_type="annotationstore",
... class_dict={0: "nucleus"},
... auto_get_mask=True,
... memory_threshold=80,
... )
>>> # Patch workflow: return in-memory detections
>>> patches = np.stack([patch_1, patch_2])  # NHWC array of image patches
>>> out = detector.run(patches, patch_mode=True, output_type="dict")
Notes
-----
- Outputs can be returned as Python dictionaries, saved as Zarr groups,
or written to QuPath JSON (.json) or AnnotationStore (.db) files.
- Post-processing uses tile rechunking and halo padding to facilitate
centroid extraction near chunk boundaries.
"""
from __future__ import annotations
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
import dask.array as da
import numpy as np
import zarr
from matplotlib import pyplot as plt
from shapely.geometry import Point
from tqdm.auto import tqdm
from tiatoolbox import logger
from tiatoolbox.annotation.storage import Annotation, SQLiteStore
from tiatoolbox.models.engine.semantic_segmentor import (
SemanticSegmentor,
SemanticSegmentorRunParams,
)
from tiatoolbox.utils.misc import (
save_annotations,
save_qupath_json,
tqdm_dask_progress_bar,
)
if TYPE_CHECKING: # pragma: no cover
import os
from typing import Unpack
from tiatoolbox.annotation import AnnotationStore
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.type_hints import IntPair, Resolution, Units
from tiatoolbox.wsicore import WSIReader
from .io_config import IOSegmentorConfig
class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False):
    """Runtime parameters for configuring the `NucleusDetector.run()` method.

    This class extends `SemanticSegmentorRunParams` (and transitively
    `PredictorRunParams` → `EngineABCRunParams`) with additional options
    specific to nucleus detection workflows. All keys are optional
    (``total=False``).

    Attributes:
        auto_get_mask (bool):
            Whether to automatically generate segmentation masks using
            `wsireader.tissue_mask()` during WSI processing.
        batch_size (int):
            Number of image patches to feed to the model in a forward pass.
        class_dict (dict):
            Optional dictionary mapping numeric class IDs to class names.
        device (str):
            Device to run the model on (e.g., "cpu", "cuda").
        labels (list):
            Optional labels for input images. Only a single label per image
            is supported.
        memory_threshold (int):
            Memory usage threshold (in percentage) to trigger caching behavior.
        num_workers (int):
            Number of workers used in the DataLoader.
        output_file (str):
            Output file name for saving results (e.g., ".zarr" or ".db").
        output_resolutions (Resolution):
            Resolution used for writing output predictions/coordinates.
        patch_output_shape (tuple[int, int]):
            Shape of output patches (height, width).
        min_distance (int):
            Minimum separation between nuclei (in pixels) used during
            centroid extraction/post-processing. In WSI mode it also sets
            the halo (overlap) size used between post-processing tiles.
        threshold_abs (float):
            Absolute detection threshold applied to model outputs.
        threshold_rel (float):
            Relative detection threshold forwarded to the model's
            post-processing function.
        tile_shape (tuple[int, int]):
            Tile shape (height, width) used during WSI post-processing
            (in pixels) to control rechunking behavior.
        return_labels (bool):
            Whether to return labels with predictions.
        return_probabilities (bool):
            Whether to include per-class probabilities in the output.
        scale_factor (tuple[float, float]):
            Scale factor for converting coordinates to baseline resolution.
            Typically, `model_mpp / slide_mpp`.
        stride_shape (tuple[int, int]):
            Stride used during WSI processing. Defaults to `patch_input_shape`.
        verbose (bool):
            Whether to enable verbose logging.

    """

    min_distance: int
    threshold_abs: float
    threshold_rel: float
    tile_shape: IntPair
class NucleusDetector(SemanticSegmentor):
    r"""Nucleus detection engine for digital histology images.

    This class extends :class:`SemanticSegmentor` to support instance-level
    nucleus detection using pretrained or custom models from TIAToolbox.
    It operates in both patch-level and whole slide image (WSI) modes and
    provides utilities for post-processing (e.g., centroid extraction,
    thresholding, tile-overlap handling), merging predictions, and saving
    results in multiple output formats. Supported TIAToolbox models include
    nucleus-detection architectures such as ``mapde-conic`` and
    ``mapde-crchisto``. For the full list of pretrained models, refer to the
    model zoo documentation:
    https://tia-toolbox.readthedocs.io/en/latest/pretrained.html

    The class integrates with the TIAToolbox engine interface, inheriting
    the data loading, inference orchestration, memory-aware chunking, and
    output-saving conventions of :class:`SemanticSegmentor`, while
    overriding only the nucleus-specific post-processing and export
    routines.

    Args:
        model (str or nn.Module):
            Defined PyTorch model or name of an existing model supported by
            tiatoolbox for processing the data, e.g., mapde-conic,
            mapde-crchisto. For a full list of pretrained models, please
            refer to the `docs
            <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
            By default, the corresponding pretrained weights will also
            be downloaded. However, you can override with your own set
            of weights via the `weights` argument. Argument is case
            insensitive.
        batch_size (int):
            Number of image patches processed per forward pass.
            Default is ``8``.
        num_workers (int):
            Number of workers for ``torch.utils.data.DataLoader``.
            Default is ``0``.
        weights (str or pathlib.Path or None):
            Optional path to pretrained weights. If ``None`` and ``model`` is
            a string, default pretrained weights for that model will be used.
            If ``model`` is an ``nn.Module``, weights are loaded only if
            provided.
        device (str):
            Device on which the model will run (e.g., ``"cpu"``, ``"cuda"``).
            Default is ``"cpu"``.
        verbose (bool):
            Whether to output logging information. Default is ``True``.

    Attributes:
        images (list[str or Path] or np.ndarray):
            Input images supplied to the engine, either as WSI paths or
            NHWC-formatted patches.
        masks (list[str or Path] or np.ndarray):
            Optional tissue masks for WSI processing. Only used when
            ``patch_mode=False``.
        patch_mode (bool):
            Whether input is treated as image patches (``True``) or as WSIs
            (``False``).
        model (ModelABC):
            Loaded PyTorch model. Can be a pretrained TIAToolbox model or a
            custom user-provided model.
        ioconfig (ModelIOConfigABC):
            IO configuration specifying patch extraction shape, stride, and
            resolution settings for inference.
        return_labels (bool):
            Whether to include labels in the output, if provided.
        input_resolutions (list[dict]):
            Resolution settings for model input heads. Supported units are
            ``"level"``, ``"power"``, and ``"mpp"``.
        patch_input_shape (tuple[int, int]):
            Height and width of input patches read from slides, expressed in
            read resolution space.
        stride_shape (tuple[int, int]):
            Stride used during patch extraction. Defaults to
            ``patch_input_shape``.
        drop_keys (list):
            Keys to exclude from model output when saving results.
        output_type (str):
            Output format (``"dict"``, ``"zarr"``, ``"qupath"``,
            or ``"annotationstore"``).

    Examples:
        >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
        >>> model_name = "mapde-conic"
        >>> detector = NucleusDetector(model=model_name, batch_size=16, num_workers=8)
        >>> detector.run(
        ...     images=[pathlib.Path("example_wsi.tiff")],
        ...     patch_mode=False,
        ...     device="cuda",
        ...     save_dir=pathlib.Path("output_directory/"),
        ...     overwrite=True,
        ...     output_type="annotationstore",
        ...     class_dict={0: "nucleus"},
        ...     auto_get_mask=True,
        ...     memory_threshold=80,
        ... )

    """

    def __init__(
        self: NucleusDetector,
        model: str | ModelABC,
        batch_size: int = 8,
        num_workers: int = 0,
        weights: str | Path | None = None,
        *,
        device: str = "cpu",
        verbose: bool = True,
    ) -> None:
        """Initialize :class:`NucleusDetector`.

        This constructor follows the standard TIAToolbox engine initialization
        workflow. A model may be provided either as a string referring to a
        pretrained TIAToolbox architecture or as a custom ``torch.nn.Module``.
        When ``model`` is a string, the corresponding pretrained weights are
        automatically downloaded unless explicitly overridden via ``weights``.

        Args:
            model (str or ModelABC):
                A PyTorch model instance or the name of a pretrained TIAToolbox
                model. If a string is provided, default pretrained weights are
                loaded unless ``weights`` is supplied to override them.
            batch_size (int):
                Number of image patches processed per forward pass.
                Default is ``8``.
            num_workers (int):
                Number of workers used for ``torch.utils.data.DataLoader``.
                Default is ``0``.
            weights (str or Path or None):
                Path to model weights. If ``None`` and ``model`` is a string,
                the default pretrained weights for that model will be used.
                If ``model`` is a ``nn.Module``, weights are loaded only when
                specified here.
            device (str):
                Device on which the model will run (e.g., ``"cpu"``, ``"cuda"``).
                Default is ``"cpu"``.
            verbose (bool):
                Whether to enable verbose logging during initialization and
                inference. Default is ``True``.

        """
        super().__init__(
            model=model,
            batch_size=batch_size,
            num_workers=num_workers,
            weights=weights,
            device=device,
            verbose=verbose,
        )

    def post_process_patches(
        self: NucleusDetector,
        raw_predictions: dict[str, da.Array],
        **kwargs: Unpack[NucleusDetectorRunParams],
    ) -> dict[str, list[da.Array]]:
        """Post-process patch-level detection outputs.

        Applies the model's post-processing function (e.g., centroid extraction
        and thresholding) to each patch's probability map, yielding per-patch
        detection arrays suitable for saving or further merging.

        Args:
            raw_predictions (dict[str, da.Array]):
                Dictionary containing raw model predictions as Dask arrays.
                The ``"probabilities"`` entry is indexed along its first axis,
                one entry per patch.
            **kwargs (NucleusDetectorRunParams):
                Additional runtime parameters to configure segmentation.

                Optional Keys:
                    min_distance (int):
                        Minimum separation between nuclei (in pixels) used
                        during centroid extraction/post-processing.
                    threshold_abs (float):
                        Absolute detection threshold applied to model outputs.
                    threshold_rel (float):
                        Relative detection threshold forwarded to the model's
                        post-processing function.

        Returns:
            dict[str, list[da.Array]]:
                A dictionary of lists (one entry per patch), with keys:
                    - ``"x"`` (list[dask array]):
                        1-D ``np.uint32`` dask arrays of x coordinates
                    - ``"y"`` (list[dask array]):
                        1-D ``np.uint32`` dask arrays of y coordinates
                    - ``"classes"`` (list[dask array]):
                        1-D ``np.uint32`` dask arrays of class IDs
                    - ``"probabilities"`` (list[dask array]):
                        1-D ``np.float32`` dask arrays of detection
                        probabilities

        Notes:
            - Thresholds missing from ``kwargs`` are forwarded as ``None``;
              the model's post-processing function is expected to substitute
              its own defaults.

        """
        logger.info("Post processing patch predictions in NucleusDetector.")
        # If these are not provided, defaults from model will be used in postproc
        min_distance = kwargs.get("min_distance")
        threshold_abs = kwargs.get("threshold_abs")
        threshold_rel = kwargs.get("threshold_rel")

        # Loop-invariant: resolve the post-processing function once.
        postproc_func = self._get_model_attr("postproc_func")

        # Lists to hold per-patch detection arrays
        xs = []
        ys = []
        classes = []
        probs = []

        for i in range(raw_predictions["probabilities"].shape[0]):
            probs_prediction_patch = raw_predictions["probabilities"][i].compute()
            centroids_map_patch = postproc_func(
                probs_prediction_patch,
                min_distance=min_distance,
                threshold_abs=threshold_abs,
                threshold_rel=threshold_rel,
            )
            centroids_map_patch = da.from_array(centroids_map_patch, chunks="auto")
            detections = self._centroid_maps_to_detection_arrays(
                centroids_map_patch, verbose=self.verbose
            )
            # Access by key: the helper's dict is ordered ("y", "x", ...), so
            # positional unpacking of ``.values()`` would swap x and y.
            xs.append(detections["x"])
            ys.append(detections["y"])
            classes.append(detections["classes"])
            probs.append(detections["probabilities"])

        return {"x": xs, "y": ys, "classes": classes, "probabilities": probs}

    def post_process_wsi(
        self: NucleusDetector,
        raw_predictions: dict[str, da.Array],
        save_path: Path,
        **kwargs: Unpack[NucleusDetectorRunParams],
    ) -> dict[str, da.Array]:
        """Post-process WSI-level nucleus detection outputs.

        Processes the full-slide prediction map using Dask's block-wise
        operations to extract nuclei centroids across the entire WSI. The
        prediction map is first re-chunked to the post-processing tile shape,
        and `dask.array.map_overlap` with halo padding is used so each tile's
        post-processing sees context beyond its borders. The resulting
        centroid maps are computed and saved to Zarr storage for
        memory-efficient processing, then converted into detection arrays
        (x, y, classes, probabilities) through sequential block processing.

        Args:
            raw_predictions (dict[str, da.Array]):
                Dictionary containing raw model predictions as Dask arrays.
                The ``"probabilities"`` entry is rechunked along three axes
                (height, width, channels).
            save_path (Path):
                Path used to derive the intermediate output location. The
                intermediate centroid maps are saved to
                ``save_path.with_suffix(".zarr")``.
            **kwargs (NucleusDetectorRunParams):
                Additional runtime parameters to configure segmentation.

                Optional Keys:
                    min_distance (int):
                        Minimum distance separating two nuclei (in pixels).
                        Also used as the halo size. Defaults to the model's
                        ``min_distance``.
                    threshold_abs (float):
                        Absolute detection threshold applied to model outputs.
                    threshold_rel (float):
                        Relative detection threshold forwarded to the model's
                        post-processing function.
                    tile_shape (tuple[int, int]):
                        Tile shape (height, width) for post-processing
                        rechunking. Defaults to the model's ``tile_shape``.

        Returns:
            dict[str, da.Array]:
                A dictionary mapping detection fields to 1-D Dask arrays:
                    - ``"x"``: x coordinates of detected nuclei.
                    - ``"y"``: y coordinates of detected nuclei.
                    - ``"classes"``: class IDs.
                    - ``"probabilities"``: detection probabilities.

        Notes:
            - Each tile is padded by ``min_distance`` pixels along height and
              width (zero-filled at the slide boundary); ``map_overlap`` trims
              the halo from each block's result.
            - Thresholds missing from ``kwargs`` are forwarded as ``None`` so
              the model's post-processing function can apply its defaults.
            - Centroid maps are computed and saved to Zarr storage to avoid
              out-of-memory errors on large WSIs.
            - The Zarr-backed centroid maps are then processed block-by-block
              to extract detections incrementally.

        """
        logger.info("Post processing WSI predictions in NucleusDetector")
        # If these are not provided, defaults from model will be used in postproc
        threshold_abs = kwargs.get("threshold_abs")
        threshold_rel = kwargs.get("threshold_rel")

        # min_distance and tile_shape cannot be None here: they size the halo
        # and the rechunking. Resolve model defaults through _get_model_attr,
        # as done for postproc_func, so wrapped models are handled uniformly.
        min_distance = kwargs.get("min_distance")
        if min_distance is None:
            min_distance = self._get_model_attr("min_distance")
        tile_shape = kwargs.get("tile_shape")
        if tile_shape is None:
            tile_shape = self._get_model_attr("tile_shape")

        # Add halo (overlap) around each block along height/width only.
        depth_h = min_distance
        depth_w = min_distance
        depth = {0: depth_h, 1: depth_w, 2: 0}

        # Re-chunk to post-processing tile shape; keep channels in one chunk.
        rechunked_probability_map = raw_predictions["probabilities"].rechunk(
            (tile_shape[0], tile_shape[1], -1)
        )

        postproc_func = self._get_model_attr("postproc_func")
        centroid_maps = da.map_overlap(
            postproc_func,
            rechunked_probability_map,
            min_distance=min_distance,
            threshold_abs=threshold_abs,
            threshold_rel=threshold_rel,
            depth=depth,
            boundary=0,
            dtype=raw_predictions["probabilities"].dtype,
            block_info=True,
            depth_h=depth_h,
            depth_w=depth_w,
        )

        # Compute and save centroid maps to zarr to avoid memory issues
        zarr_file = save_path.with_suffix(".zarr")
        logger.info(
            "Computing and caching centroid maps to zarr file at: %s",
            zarr_file,
        )
        task = centroid_maps.to_zarr(
            url=zarr_file,
            component="centroid_maps",
            compute=False,
            object_codec=None,
        )
        _ = tqdm_dask_progress_bar(
            desc="Computing Centroids",
            write_tasks=[task],
            num_workers=self.num_workers,
            scheduler="threads",
            leave=False,
            verbose=self.verbose,
        )

        # Guard against accumulating duplicates when several WSIs are processed.
        if "centroid_maps" not in self.drop_keys:
            self.drop_keys.append("centroid_maps")

        zarr_group = zarr.open(zarr_file, mode="r+")
        centroid_maps = da.from_zarr(zarr_group["centroid_maps"])
        return self._centroid_maps_to_detection_arrays(
            centroid_maps, verbose=self.verbose
        )

    def save_predictions(
        self: NucleusDetector,
        processed_predictions: dict,
        output_type: str,
        save_path: Path | None = None,
        **kwargs: Unpack[NucleusDetectorRunParams],
    ) -> dict | AnnotationStore | Path | list[Path]:
        """Save nucleus detections to disk or return them in memory.

        Saves post-processed detection outputs in one of the supported
        formats. If ``patch_mode=True``, predictions are saved per image. If
        ``patch_mode=False``, detections are merged and saved as a single
        output.

        Args:
            processed_predictions (dict):
                Dictionary containing processed detection results. The
                internal structure follows TIAToolbox conventions and differs
                between patch and WSI modes:

                - Patch mode (patch_mode=True):
                    - ``"x"`` (list[da.Array]):
                        per-patch x coordinates.
                    - ``"y"`` (list[da.Array]):
                        per-patch y coordinates.
                    - ``"classes"`` (list[da.Array]):
                        per-patch class IDs.
                    - ``"probabilities"`` (list[da.Array]):
                        per-patch detection probabilities.

                - WSI mode (patch_mode=False):
                    - ``"x"`` (da.Array):
                        x coordinates.
                    - ``"y"`` (da.Array):
                        y coordinates.
                    - ``"classes"`` (da.Array):
                        class IDs.
                    - ``"probabilities"`` (da.Array):
                        detection probabilities.

            output_type (str):
                Desired output format (case-insensitive): ``"dict"``,
                ``"zarr"``, ``"qupath"`` or ``"annotationstore"``.
            save_path (Path | None):
                Path at which to save the output file(s). Required for file
                outputs (e.g., Zarr or SQLite .db). If ``None`` and
                ``output_type="dict"``, results are returned in memory.
            **kwargs (NucleusDetectorRunParams):
                Additional runtime parameters to configure segmentation.

                Optional Keys:
                    auto_get_mask (bool):
                        Whether to automatically generate segmentation masks
                        using `wsireader.tissue_mask()` during processing.
                    batch_size (int):
                        Number of image patches to feed to the model in a
                        forward pass.
                    class_dict (dict):
                        Optional dictionary mapping classification outputs to
                        class names. Defaults to the model's ``class_dict``.
                    device (str):
                        Device to run the model on (e.g., "cpu", "cuda").
                    labels (list):
                        Optional labels for input images. Only a single label
                        per image is supported.
                    memory_threshold (int):
                        Memory usage threshold (in percentage) to
                        trigger caching behavior.
                    num_workers (int):
                        Number of workers used in DataLoader.
                    output_file (str):
                        Output file name for saving results (e.g., .zarr or
                        .db).
                    output_resolutions (Resolution):
                        Resolution used for writing output predictions.
                    patch_output_shape (tuple[int, int]):
                        Shape of output patches (height, width).
                    min_distance (int):
                        Minimum distance separating two nuclei (in pixels).
                    tile_shape (tuple[int, int]):
                        Tile shape (height, width) for post-processing (in
                        pixels).
                    return_labels (bool):
                        Whether to return labels with predictions.
                    return_probabilities (bool):
                        Whether to return per-class probabilities.
                    scale_factor (tuple[float, float]):
                        Scale factor for converting annotations to baseline
                        resolution. Typically model_mpp / slide_mpp.
                        Missing or ``None`` means ``(1.0, 1.0)``.
                    stride_shape (tuple[int, int]):
                        Stride used during WSI processing.
                        Defaults to patch_input_shape.
                    verbose (bool):
                        Whether to output logging information.

        Returns:
            dict | AnnotationStore | Path | list[Path]:
                - If ``output_type="dict"``:
                    returns a Python dictionary of predictions.
                - If ``output_type="zarr"``:
                    returns the path to the saved ``.zarr`` group.
                - If ``output_type="qupath"``:
                    returns QuPath JSON or the path(s) to saved
                    ``.json`` file(s). In patch mode, a list with one entry
                    per image is returned.
                - If ``output_type="annotationstore"``:
                    returns an AnnotationStore handle or the path(s) to saved
                    ``.db`` file(s). In patch mode, a list with one entry per
                    image is returned.

        Notes:
            - ``"dict"`` and ``"zarr"`` outputs are delegated to the base
              engine's saving function to preserve consistency across
              TIAToolbox engines.
            - For ``"qupath"`` and ``"annotationstore"`` outputs in WSI mode,
              the cached centroid-map Zarr written by ``post_process_wsi`` is
              removed after saving.

        """
        normalized_output_type = output_type.lower()
        if normalized_output_type not in ("qupath", "annotationstore"):
            return super().save_predictions(
                processed_predictions,
                output_type,
                save_path=save_path,
                **kwargs,
            )

        # scale_factor set from kwargs; treat an explicit None like a missing key
        scale_factor = kwargs.get("scale_factor")
        if scale_factor is None:
            scale_factor = (1.0, 1.0)

        # class_dict set from kwargs, falling back to the model's mapping
        class_dict = kwargs.get("class_dict")
        if class_dict is None:
            class_dict = self._get_model_attr("class_dict")

        # Pass the normalized type so the writer dispatch is case-insensitive.
        out = self._save_predictions_qupath_json_annotations_db(
            processed_predictions,
            save_path=save_path,
            scale_factor=scale_factor,
            class_dict=class_dict,
            output_type=normalized_output_type,
        )

        # Remove cached centroid maps if wsi mode
        if not self.patch_mode:
            cache_path = save_path.with_suffix(".zarr")
            shutil.rmtree(cache_path)
            logger.info(
                "Removed cached centroid maps at: %s",
                cache_path,
            )
        return out

    def _save_predictions_qupath_json_annotations_db(
        self: NucleusDetector,
        processed_predictions: dict,
        save_path: Path | None = None,
        scale_factor: tuple[float, float] = (1.0, 1.0),
        class_dict: dict | None = None,
        output_type: str = "annotationstore",
    ) -> AnnotationStore | Path | list[Path]:
        """Save nucleus detections as QuPath JSON or an AnnotationStore (.db).

        Dispatches the processed detection arrays to
        ``save_detection_arrays_to_qupath_json`` (for ``"qupath"``) or
        ``save_detection_arrays_to_store`` (otherwise), forwarding the
        coordinate scaling and class mapping. In patch mode, one output is
        produced per input image; in WSI mode, all detections are written to
        a single output.

        Args:
            processed_predictions (dict):
                Dictionary containing the computed detection outputs.
                Expected keys:
                    - For wsi mode:
                        - ``"x"`` (da.Array):
                            dask array of x coordinates
                        - ``"y"`` (da.Array):
                            dask array of y coordinates
                        - ``"classes"`` (da.Array):
                            dask array of class IDs
                        - ``"probabilities"`` (da.Array):
                            dask array of detection probabilities
                    - For patch mode:
                        - ``"x"`` (list[da.Array]):
                            list of per-patch dask arrays of x coordinates
                        - ``"y"`` (list[da.Array]):
                            list of per-patch dask arrays of y coordinates
                        - ``"classes"`` (list[da.Array]):
                            list of per-patch dask arrays of class IDs
                        - ``"probabilities"`` (list[da.Array]):
                            list of per-patch dask arrays of detection
                            probabilities
            save_path (Path or None):
                Output path. In WSI mode it is passed straight to the writer.
                In patch mode, per-image files are written into
                ``save_path.parent`` and named after the input image stem
                (for ``Path`` inputs) or the patch index. If ``None``, the
                writers are called with ``save_path=None`` (in-memory output).
            scale_factor (tuple[float, float], optional):
                Scaling factors applied to x and y coordinates prior to
                writing. Typically corresponds to ``model_mpp / slide_mpp``.
                Defaults to ``(1.0, 1.0)``.
            class_dict (dict or None):
                Optional mapping from original class IDs to class names or
                remapped IDs, forwarded to the writer.
            output_type (str):
                Desired output format (case-insensitive): ``"qupath"`` or
                ``"annotationstore"``.

        Returns:
            AnnotationStore or Path or list[Path]:
                - For WSI mode: the writer's result (an in-memory
                  representation or the path to the saved file).
                - For patch mode: a list with one writer result per input.

        """
        is_qupath = output_type.lower() == "qupath"
        if is_qupath:
            logger.info("Saving predictions as QuPath JSON.")
            writer = save_detection_arrays_to_qupath_json
            suffix = ".json"
        else:
            logger.info("Saving predictions as AnnotationStore.")
            writer = save_detection_arrays_to_store
            suffix = ".db"

        if not self.patch_mode:
            return writer(
                detection_arrays=processed_predictions,
                scale_factor=scale_factor,
                class_dict=class_dict,
                save_path=save_path,
            )

        save_paths = []
        for i in range(len(processed_predictions["x"])):
            output_path = None
            if save_path is not None:
                image = self.images[i]
                stem = image.stem if isinstance(image, Path) else str(i)
                output_path = save_path.parent / (stem + suffix)
            detection_arrays = {
                key: processed_predictions[key][i]
                for key in ("x", "y", "classes", "probabilities")
            }
            save_paths.append(
                writer(
                    detection_arrays=detection_arrays,
                    scale_factor=scale_factor,
                    class_dict=class_dict,
                    save_path=output_path,
                )
            )
        return save_paths

    @staticmethod
    def _centroid_maps_to_detection_arrays(
        detection_maps: da.Array,
        *,
        verbose: bool = True,
    ) -> dict[str, da.Array]:
        """Convert centroid maps into 1-D detection arrays.

        Extracts non-zero centroid predictions from a Dask array of centroid
        maps and flattens them into coordinate, class, and probability arrays
        suitable for saving or further processing. The maps are processed
        block by block (materializing one block at a time) and detections are
        accumulated incrementally.

        Args:
            detection_maps (da.Array):
                A Dask array of shape ``(H, W, C)`` representing centroid
                probability maps, where non-zero values correspond to nucleus
                detections. The channel index of a non-zero entry is used as
                its class ID and its value as the probability. This array is
                expected to be cheap to read (e.g., already computed/Zarr
                backed).
            verbose (bool):
                Whether to display the progress bar.

        Returns:
            dict[str, da.Array]:
                A dictionary containing four 1-D Dask arrays:
                    - ``"x"``:
                        x coordinates of detected nuclei (``np.uint32``).
                    - ``"y"``:
                        y coordinates of detected nuclei (``np.uint32``).
                    - ``"classes"``:
                        class IDs for each detection (``np.uint32``).
                    - ``"probabilities"``:
                        detection probabilities (``np.float32``).

        Notes:
            - Blocks are processed sequentially to avoid loading the entire
              centroid map into memory at once.
            - Global coordinates are computed by adding block offsets to local
              coordinates within each block.
            - This method is used by both patch-level and WSI-level
              post-processing routines to unify detection formatting.

        """
        logger.info("Extracting detections from centroid maps block by block...")

        num_blocks_h, num_blocks_w = detection_maps.numblocks[:2]
        # Start offset of every block along H/W (exclusive prefix sums),
        # computed once instead of re-summing the chunk sizes per block.
        row_offsets = np.cumsum((0, *detection_maps.chunks[0][:-1]))
        col_offsets = np.cumsum((0, *detection_maps.chunks[1][:-1]))

        ys_list = []
        xs_list = []
        classes_list = []
        probs_list = []

        tqdm_loop = tqdm(
            range(num_blocks_h),
            leave=False,
            desc="Processing detection blocks",
            disable=not verbose,
        )
        for i in tqdm_loop:
            for j in range(num_blocks_w):
                # Read this block (just I/O when backed by computed storage)
                block = np.array(detection_maps.blocks[i, j])

                ys, xs, classes = np.nonzero(block)
                if len(ys) == 0:
                    continue
                probs = block[ys, xs, classes]

                # Shift local block coordinates to global coordinates
                ys_list.append((ys + row_offsets[i]).astype(np.uint32))
                xs_list.append((xs + col_offsets[j]).astype(np.uint32))
                classes_list.append(classes.astype(np.uint32))
                probs_list.append(probs.astype(np.float32))

        if ys_list:
            ys = np.concatenate(ys_list)
            xs = np.concatenate(xs_list)
            classes = np.concatenate(classes_list)
            probs = np.concatenate(probs_list)
        else:
            ys = np.array([], dtype=np.uint32)
            xs = np.array([], dtype=np.uint32)
            classes = np.array([], dtype=np.uint32)
            probs = np.array([], dtype=np.float32)

        return {
            "y": da.from_array(ys, chunks="auto"),
            "x": da.from_array(xs, chunks="auto"),
            "classes": da.from_array(classes, chunks="auto"),
            "probabilities": da.from_array(probs, chunks="auto"),
        }

    def run(
        self: NucleusDetector,
        images: list[os.PathLike | Path | WSIReader | np.ndarray] | np.ndarray,
        *,
        masks: list[os.PathLike | Path] | np.ndarray | None = None,
        input_resolutions: list[dict[Units, Resolution]] | None = None,
        patch_input_shape: IntPair | None = None,
        ioconfig: IOSegmentorConfig | None = None,
        patch_mode: bool = True,
        save_dir: os.PathLike | Path | None = None,
        overwrite: bool = False,
        output_type: str = "dict",
        **kwargs: Unpack[NucleusDetectorRunParams],
    ) -> AnnotationStore | Path | str | dict | list[Path]:
        """Run the nucleus detection engine on input images.

        This method orchestrates the full inference pipeline, including
        preprocessing, model inference, post-processing, and saving results.
        It supports both patch-level and whole slide image (WSI) modes. All
        work is delegated to :meth:`SemanticSegmentor.run`, which calls the
        nucleus-specific post-processing and saving hooks of this class.

        Args:
            images (list[PathLike | WSIReader] | np.ndarray):
                Input images or patches. Can be a list of file paths,
                WSIReader objects, or a NumPy array of image patches.
            masks (list[PathLike] | np.ndarray | None):
                Optional masks for WSI processing. Only used when
                `patch_mode` is False.
            input_resolutions (list[dict[Units, Resolution]] | None):
                Resolution settings for input heads. Supported units are
                `level`, `power`, and `mpp`. Keys should be "units" and
                "resolution", e.g., [{"units": "mpp", "resolution": 0.25}].
                See :class:`WSIReader` for details.
            patch_input_shape (IntPair | None):
                Shape of input patches (height, width), requested at read
                resolution. Must be positive.
            ioconfig (IOSegmentorConfig | None):
                IO configuration for patch extraction and resolution.
            patch_mode (bool):
                Whether to treat input as patches (`True`) or WSIs (`False`).
                Default is True.
            save_dir (PathLike | None):
                Directory to save output files. Required for WSI mode.
            overwrite (bool):
                Whether to overwrite existing output files. Default is False.
            output_type (str):
                Desired output format: "dict", "zarr", "qupath", or
                "annotationstore". Default is "dict".
            **kwargs (NucleusDetectorRunParams):
                Additional runtime parameters to configure segmentation.

                Optional Keys:
                    auto_get_mask (bool):
                        Whether to automatically generate tissue masks using
                        `wsireader.tissue_mask()` during processing.
                    batch_size (int):
                        Number of image patches to feed to the model in a
                        forward pass.
                    class_dict (dict):
                        Optional dictionary mapping classification outputs to
                        class names.
                    device (str):
                        Device to run the model on (e.g., "cpu", "cuda").
                    labels (list):
                        Optional labels for input images. Only a single label
                        per image is supported.
                    memory_threshold (int):
                        Memory usage threshold (in percentage) to
                        trigger caching behavior.
                    num_workers (int):
                        Number of workers used in DataLoader.
                    output_file (str):
                        Output file name for saving results (e.g., .zarr or
                        .db).
                    output_resolutions (Resolution):
                        Resolution used for writing output predictions.
                    patch_output_shape (tuple[int, int]):
                        Shape of output patches (height, width).
                    min_distance (int):
                        Minimum distance separating two nuclei (in pixels).
                    threshold_abs (float):
                        Absolute detection threshold applied to model outputs.
                    threshold_rel (float):
                        Relative detection threshold forwarded to the model's
                        post-processing function.
                    tile_shape (tuple[int, int]):
                        Tile shape (height, width) for post-processing (in
                        pixels).
                    return_labels (bool):
                        Whether to return labels with predictions.
                    return_probabilities (bool):
                        Whether to return per-class probabilities.
                    scale_factor (tuple[float, float]):
                        Scale factor for converting annotations to baseline
                        resolution. Typically model_mpp / slide_mpp.
                    stride_shape (tuple[int, int]):
                        Stride used during WSI processing.
                        Defaults to patch_input_shape.
                    verbose (bool):
                        Whether to output logging information.

        Returns:
            AnnotationStore | Path | str | dict | list[Path]:
                - If `patch_mode` is True: returns predictions or path to
                  saved output.
                - If `patch_mode` is False: returns a dictionary mapping each
                  WSI to its output path.

        Examples:
            >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
            >>> detector = NucleusDetector(model="mapde-conic")
            >>> # WSI workflow: save to AnnotationStore (.db)
            >>> out = detector.run(
            ...     images=[pathlib.Path("example_wsi.tiff")],
            ...     patch_mode=False,
            ...     device="cuda",
            ...     save_dir=pathlib.Path("output_directory/"),
            ...     overwrite=True,
            ...     output_type="annotationstore",
            ...     class_dict={0: "nucleus"},
            ...     auto_get_mask=True,
            ...     memory_threshold=80,
            ... )
            >>> # Patch workflow: return in-memory detections
            >>> patches = np.stack([patch_1, patch_2])  # NHWC image patches
            >>> out = detector.run(patches, patch_mode=True, output_type="dict")

        """
        return super().run(
            images=images,
            masks=masks,
            input_resolutions=input_resolutions,
            patch_input_shape=patch_input_shape,
            ioconfig=ioconfig,
            patch_mode=patch_mode,
            save_dir=save_dir,
            overwrite=overwrite,
            output_type=output_type,
            **kwargs,
        )
[docs]
def save_detection_arrays_to_qupath_json(
    detection_arrays: dict[str, da.Array],
    scale_factor: tuple[float, float] = (1.0, 1.0),
    class_dict: dict | None = None,
    save_path: Path | None = None,
) -> dict | Path:
    """Write nucleus detection arrays to QuPath JSON.

    Produces a ``FeatureCollection`` where each detection is represented as a
    ``Point`` geometry with classification metadata and probability score.

    Args:
        detection_arrays (dict[str, da.Array]):
            A dictionary containing the detection fields:
                - ``"x"``: dask array of x coordinates (``np.uint32``).
                - ``"y"``: dask array of y coordinates (``np.uint32``).
                - ``"classes"``: dask array of class IDs (``np.uint32``).
                - ``"probabilities"``: dask array of detection scores
                  (``np.float32``).
        scale_factor (tuple[float, float], optional):
            Multiplicative factors applied to the x and y coordinates before
            saving. The scaled coordinates are written as floats; they are
            NOT rounded. Defaults to ``(1.0, 1.0)``.
        class_dict (dict or None):
            Optional mapping of class IDs to class names or remapped IDs.
            If ``None``, an identity mapping is used based on the detected
            class IDs. Detected IDs missing from the mapping keep their
            integer ID as the label.
        save_path (Path or None):
            Destination path for saving the QuPath-compatible ``.json`` file.
            If ``None``, an in-memory JSON-compatible representation of all
            detections is returned instead of writing to disk.

    Returns:
        dict or Path:
            - If ``save_path`` is provided: the result of
              ``save_qupath_json`` (the path to the saved ``.json`` file).
            - If ``save_path`` is ``None``: an in-memory dict representing
              QuPath JSON containing all detections.

    Raises:
        ValueError:
            If the detection fields have different lengths.
    """
    xs, ys, classes, probs = _validate_detections_for_saving_to_json(
        detection_arrays=detection_arrays,
    )

    unique_classes = [int(c) for c in np.unique(classes)]
    if class_dict is None:
        # Identity mapping over the classes actually present.
        class_dict = {c: c for c in unique_classes}

    # Colour every class that can appear in the output. The ``class_dict`` keys
    # come first so the lookup-table size (and hence the sampled colours) is
    # unchanged when every detected class is mapped. Detected classes missing
    # from the mapping are appended; previously they caused a KeyError below.
    color_ids = list(class_dict)
    color_ids += [c for c in unique_classes if c not in class_dict]

    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap`` with a
    # ``lut`` size performs the same resampling. ``max(..., 1)`` avoids asking
    # for a zero-entry lookup table when there is nothing to colour.
    cmap = plt.get_cmap("tab20", max(len(color_ids), 1))
    class_colors = {
        class_idx: [int(channel * 255) for channel in cmap(class_idx)[:3]]
        for class_idx in color_ids
    }

    features: list[dict] = []
    for i, (x_raw, y_raw, class_raw, prob_raw) in enumerate(
        zip(xs, ys, classes, probs, strict=True),
    ):
        class_id = int(class_raw)
        class_label = class_dict.get(class_id, class_id)
        features.append(
            {
                "type": "Feature",
                "id": f"detection_{i}",
                # QuPath point geometry in (scaled, unrounded) coordinates.
                "geometry": {
                    "type": "Point",
                    "coordinates": [
                        float(x_raw) * scale_factor[0],
                        float(y_raw) * scale_factor[1],
                    ],
                },
                "properties": {
                    "classification": {
                        "name": class_label,
                        "color": class_colors[class_id],
                    },
                    "probability": float(prob_raw),
                },
                "objectType": "detection",
                "name": class_label,
                "class_value": class_id,
            },
        )

    qupath_json = {"type": "FeatureCollection", "features": features}
    if save_path:
        return save_qupath_json(save_path=save_path, qupath_json=qupath_json)
    return qupath_json
[docs]
def save_detection_arrays_to_store(
    detection_arrays: dict[str, da.Array],
    scale_factor: tuple[float, float] = (1.0, 1.0),
    class_dict: dict | None = None,
    save_path: Path | None = None,
    batch_size: int = 5000,
) -> Path | SQLiteStore:
    """Persist nucleus detection arrays into an SQLite-backed AnnotationStore.

    The detection fields are first validated and converted to NumPy arrays.
    They are then written, in batches, into a freshly created in-memory
    ``SQLiteStore``. When ``save_path`` is given, the populated store is
    handed to ``save_annotations`` for writing to disk.

    Args:
        detection_arrays (dict[str, da.Array]):
            Mapping holding the detection fields:
                - ``"x"``: x coordinates (``np.uint32``).
                - ``"y"``: y coordinates (``np.uint32``).
                - ``"classes"``: class IDs (``np.uint32``).
                - ``"probabilities"``: detection scores (``np.float32``).
        scale_factor (tuple[float, float], optional):
            Multipliers for the x and y coordinates; the scaled values are
            rounded to integer pixel locations. Defaults to ``(1.0, 1.0)``.
        class_dict (dict or None):
            Optional mapping from class IDs to names or remapped IDs. With
            ``None``, each detected class ID is used as its own label.
        save_path (Path or None):
            Where to save the store. With ``None`` (or any falsy value) the
            in-memory store is returned. Otherwise the store is passed to
            ``save_annotations`` (presumably writing a ``.db`` file; confirm
            in that helper).
        batch_size (int):
            Number of detections appended per batch. Defaults to ``5000``.

    Returns:
        Path or SQLiteStore:
            - ``save_path`` given: whatever ``save_annotations`` returns
              (documented as the saved ``.db`` path).
            - ``save_path`` is ``None``: the in-memory ``SQLiteStore``.

    Raises:
        ValueError:
            If the detection fields have different lengths.

    Notes:
        - Coordinate scaling, class mapping and batched writing are done by
          :func:`_write_detection_arrays_to_store`.
    """
    validated_arrays = _validate_detections_for_saving_to_json(
        detection_arrays=detection_arrays,
    )
    annotation_store = SQLiteStore()
    n_written = _write_detection_arrays_to_store(
        detection_arrays=validated_arrays,
        store=annotation_store,
        scale_factor=scale_factor,
        class_dict=class_dict,
        batch_size=batch_size,
    )
    logger.info("Total detections written to store: %s", n_written)

    if not save_path:
        return annotation_store
    return save_annotations(save_path=save_path, store=annotation_store)
def _validate_detections_for_saving_to_json(
detection_arrays: dict[str, da.Array],
) -> tuple:
"""Validates x, y, classes and probs for writing to QuPath or AnnotationStore."""
xs = np.atleast_1d(np.asarray(detection_arrays["x"]))
ys = np.atleast_1d(np.asarray(detection_arrays["y"]))
classes = np.atleast_1d(np.asarray(detection_arrays["classes"]))
probs = np.atleast_1d(np.asarray(detection_arrays["probabilities"]))
if not len(xs) == len(ys) == len(classes) == len(probs):
msg = "Detection record lengths are misaligned."
raise ValueError(msg)
return xs, ys, classes, probs
def _write_detection_arrays_to_store(
detection_arrays: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
store: SQLiteStore,
scale_factor: tuple[float, float],
class_dict: dict[int, str | int] | None,
batch_size: int = 5000,
*,
verbose: bool = True,
) -> int:
"""Write detection arrays to an AnnotationStore in batches.
Converts coordinate, class, and probability arrays into `Annotation`
objects and appends them to an SQLite-backed store in configurable
batch sizes. Coordinates are scaled to baseline slide resolution using
the provided `scale_factor`, and optional class-ID remapping is applied
via `class_dict`.
Args:
detection_arrays (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]):
Tuple of arrays in the order:
`(x_coords, y_coords, class_ids, probabilities)`.
Each element must be a 1-D NumPy array of equal length.
store (SQLiteStore):
Target `AnnotationStore` instance to receive the detections.
scale_factor (tuple[float, float]):
Factors applied to `(x, y)` coordinates prior to writing,
typically `(model_mpp / slide_mpp)`. The scaled coordinates are
rounded to `np.uint32`.
class_dict (dict[int, str | int] | None):
Optional mapping from original class IDs to names or remapped IDs.
If `None`, an identity mapping is used for the set of present classes.
batch_size (int):
Number of records to write per batch. Default is `5000`.
verbose (bool):
Whether to display logs and progress bar.
Returns:
int:
Total number of detection records written to the store.
Notes:
- Coordinates are scaled and rounded to integers to ensure consistent
geometry creation for `Annotation` points.
- Class mapping is applied per-record; unmapped IDs fall back to their
original values.
- Writing in batches reduces memory pressure and improves throughput
on large number of detections.
"""
xs, ys, classes, probs = detection_arrays
n = len(xs)
if n == 0:
return 0 # nothing to write
# scale coordinates
xs = np.rint(xs * scale_factor[0]).astype(np.uint32, copy=False)
ys = np.rint(ys * scale_factor[1]).astype(np.uint32, copy=False)
# class mapping
if class_dict is None:
# identity over actually-present types
uniq = np.unique(classes)
class_dict = {int(k): int(k) for k in uniq}
labels = np.array([class_dict.get(int(k), int(k)) for k in classes], dtype=object)
def make_points(xs_batch: np.ndarray, ys_batch: np.ndarray) -> list[Point]:
"""Create Shapely Point geometries from coordinate arrays in batches."""
return [
Point(int(xx), int(yy)) for xx, yy in zip(xs_batch, ys_batch, strict=True)
]
tqdm_loop = tqdm(
range(0, n, batch_size),
leave=False,
desc="Writing detections to store",
disable=not verbose,
)
written = 0
for i in tqdm_loop:
j = min(i + batch_size, n)
pts = make_points(xs[i:j], ys[i:j])
anns = [
Annotation(geometry=pt, properties={"type": lbl, "probability": float(pp)})
for pt, lbl, pp in zip(pts, labels[i:j], probs[i:j], strict=True)
]
store.append_many(anns)
written += j - i
return written