MultiTaskSegmentor

class MultiTaskSegmentor(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=True)[source]

Multi-task segmentation engine to run models such as HoVer-Net and HoVer-Net+.

MultiTaskSegmentor performs segmentation across multiple model heads (e.g., semantic, instance, edge). It abstracts model invocation, preprocessing, and output postprocessing for multi-head segmentation.

Parameters:
  • model (str | ModelABC) – A PyTorch model instance or the name of a pretrained model from the TIAToolbox model zoo. When a name is given, the corresponding pretrained weights are downloaded by default; you can override them with your own set of weights using the weights parameter.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) –

    Path to model weights. If None, default weights are used.

    >>> engine = MultiTaskSegmentor(
    ...    model="pretrained-model",
    ...    weights="/path/to/pretrained-local-weights.pth"
    ... )
    

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.

  • verbose (bool) – Whether to enable verbose logging. Default is True.

images

Input image patches or WSI paths.

Type:

list[str | Path] | np.ndarray

masks

Optional tissue masks for WSI processing. These are only utilized when patch_mode is False. If not provided, then a tissue mask will be automatically generated for whole slide images.

Type:

list[str | Path] | np.ndarray

patch_mode

Whether input is treated as patches (True) or WSIs (False).

Type:

bool

model

Loaded PyTorch model.

Type:

ModelABC

ioconfig

IO configuration for patch extraction and resolution.

Type:

ModelIOConfigABC

return_labels

Whether to include labels in the output.

Type:

bool

return_predictions_dict

This dictionary helps keep track of which tasks require predictions in the output.

Type:

dict

input_resolutions

Resolution settings for model input. Supported units are level, power, and mpp. Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. Please see WSIReader for details.
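As a concrete illustration of this format (a standalone snippet showing only the data structure, not library code):

```python
# Hypothetical illustration of the input_resolutions format described above;
# no TIAToolbox import is needed for the structure itself.
input_resolutions = [{"units": "mpp", "resolution": 0.25}]

# Each entry pairs a unit with a resolution value.
for entry in input_resolutions:
    assert set(entry) == {"units", "resolution"}
```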

Type:

list[dict]

patch_input_shape

Shape of input patches (height, width). Patches are at requested read resolution, not with respect to level 0, and must be positive.

Type:

tuple[int, int]

stride_shape

Stride used during patch extraction. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape=patch_input_shape.

Type:

tuple[int, int]

labels

Optional labels for input images. Only a single label per image is supported.

Type:

list | None

drop_keys

Keys to exclude from model output.

Type:

list

output_type

Format of output (“dict”, “zarr”, “annotationstore”).

Type:

str

output_locations

Coordinates of output patches used during WSI processing.

Type:

list | None

Examples

>>> # list of 2 WSI files as input
>>> wsis = ['path/img.svs', 'path/img.svs']
>>> mtsegmentor = MultiTaskSegmentor(model="hovernetplus-oed")
>>> output = mtsegmentor.run(wsis, patch_mode=False)

>>> # list of 2 image patch arrays as input
>>> image_patches = [np.ndarray, np.ndarray]
>>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke")
>>> output = mtsegmentor.run(image_patches, patch_mode=True)
>>> # list of 2 image patch files as input
>>> data = ['path/img.png', 'path/img.png']
>>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke")
>>> output = mtsegmentor.run(data, patch_mode=True)
>>> # list of 2 image tile files as input
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke")
>>> output = mtsegmentor.run(tile_file, patch_mode=False)
>>> # list of 2 wsi files as input
>>> wsis = ['path/wsi1.svs', 'path/wsi2.svs']
>>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke")
>>> output = mtsegmentor.run(wsis, patch_mode=False)

Initialize MultiTaskSegmentor.

Parameters:
  • model (str | ModelABC) – A PyTorch model instance or name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights will be downloaded unless overridden via weights.

  • batch_size (int) – Number of image patches processed per forward pass. Default is 8.

  • num_workers (int) – Number of workers for data loading. Default is 0.

  • weights (str | Path | None) – Path to model weights. If None, default weights are used.

  • device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.

  • verbose (bool) – Whether to enable verbose logging. Default is True.

Methods

build_post_process_raw_predictions

Merge per-image, per-task outputs into a task-organized prediction structure.

infer_patches

Run inference on a batch of image patches using the multitask model.

infer_wsi

Perform model inference on a whole slide image (WSI).

post_process_patches

Post-process raw patch-level predictions for multitask segmentation.

post_process_wsi

Post-process whole slide image (WSI) predictions for multitask segmentation.

run

Run the MultiTaskSegmentor engine on input images.

save_predictions

Save model predictions to disk or return them in memory.

build_post_process_raw_predictions(post_process_predictions, raw_predictions, return_predictions)[source]

Merge per-image, per-task outputs into a task-organized prediction structure.

This function takes a list of outputs where each element corresponds to one image and contains one or more task dictionaries returned by the model’s post-processing step (e.g., semantic, instance). Each task dictionary must include a "task_type" key along with any number of task-specific fields (for example, "predictions", "info_dict", or additional metadata). The function reorganizes this data into raw_predictions by grouping entries under their respective task types and aligning values across images.

The merging logic is as follows:
  1. For each task (identified by "task_type"), values for keys other than "task_type" are temporarily collected into lists, one entry per image.

  2. After all images are processed, list entries are normalized:

    • If all entries for a key are array-like (np.ndarray or dask.array.Array), they are stacked along a new leading dimension (image axis).

    • If all entries for a key are dictionaries, their subkeys are expanded into separate lists aligned across images (the original composite key is removed).

  3. Existing content in raw_predictions is preserved and extended as needed.
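The merging steps above can be sketched in plain Python (a simplified sketch with NumPy standing in for Dask, and the dictionary-expansion step omitted for brevity; this is not the library implementation):

```python
import numpy as np

def merge_task_outputs(post_process_predictions, raw_predictions):
    """Simplified sketch of the per-task merging described above."""
    # Step 1: collect values per task, one list entry per image.
    for image_tasks in post_process_predictions:
        for task in image_tasks:
            bucket = raw_predictions.setdefault(task["task_type"], {})
            for key, value in task.items():
                if key == "task_type":
                    continue
                bucket.setdefault(key, []).append(value)
    # Step 2: stack array-like entries along a new leading image axis.
    for bucket in raw_predictions.values():
        if not isinstance(bucket, dict):
            continue
        for key, values in bucket.items():
            if isinstance(values, list) and values and all(
                isinstance(v, np.ndarray) for v in values
            ):
                bucket[key] = np.stack(values)
    return raw_predictions

# Two images, each producing one "semantic" task dict.
preds = [
    [{"task_type": "semantic", "predictions": np.zeros((4, 4))}],
    [{"task_type": "semantic", "predictions": np.ones((4, 4))}],
]
out = merge_task_outputs(preds, {})
```

After merging, `out["semantic"]["predictions"]` holds one stacked array with the image axis leading.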

Parameters:
  • post_process_predictions (list[tuple]) –

    A list where each element represents a single image. Each element is an iterable of task dictionaries. Every task dictionary must contain:

    • "task_type" (str): Name/type of the task (e.g., "semantic", "instance", "edge").

    and may contain any number of additional fields, such as:
    • "predictions": array-like output for that task

    • "info_dict": dictionary of task-specific metadata

    • Any other task-dependent keys

  • raw_predictions (dict) – Dictionary that will be updated in-place. It may already contain task entries or unrelated keys (e.g., "probabilities", "coordinates"). New tasks and fields are added as they appear.

  • return_predictions (tuple[bool, ...]) – Whether to return array predictions for individual tasks.

  • self (MultiTaskSegmentor)

Returns:

The updated raw_predictions dictionary containing one entry per task type. Under each task name, keys hold per-image arrays (stacked as Dask/NumPy where applicable) or lists/dicts aligned across images. Example structure:

{
    "semantic": {
        "predictions": da.Array | np.ndarray,  # stacked over images
        "info_dict": [dict, dict, ...],        # or expanded subkeys
    },
    "instance": {
        "info_dict": [...],    # per-image metadata
        "contours": [...],
        "classes": [...],      # task-dependent keys
    },
    "coordinates": da.Array,   # if previously present
}

Return type:

dict

Notes

  • Array stacking occurs only when all per-image entries for a key are array-like; mixed types remain as lists.

  • Dictionary expansion occurs only when all per-image entries for a key are dictionaries; subkeys are promoted to top-level keys under the task and aligned across images.

  • The set self.tasks is updated to include all encountered task types.

infer_patches(dataloader, *, return_coordinates=False)[source]

Run inference on a batch of image patches using the multitask model.

This method processes patches provided by a PyTorch DataLoader and runs them through the model’s infer_batch method. Models with multiple heads (e.g., semantic, instance, edge) may return multiple outputs per patch. Outputs are collected as Dask arrays for efficient large-scale aggregation.

Parameters:
  • dataloader (DataLoader) – A PyTorch dataloader that yields dicts containing "image" tensors and optionally other metadata (e.g., coordinates).

  • return_coordinates (bool) – Whether to return the spatial coordinates associated with each patch (when available from the dataset). Default is False.

  • self (MultiTaskSegmentor)

Returns:

A dictionary containing the model outputs for all patches.

Keys:
probabilities (list[da.Array]):

A list of Dask arrays containing model outputs for each head. Each array has shape (N, C, H, W) depending on the model.

coordinates (da.Array):

Returned only when return_coordinates=True. A Dask array of shape (N, 2) or (N, 4) depending on how patch coordinates are stored in the dataset.

Return type:

dict[str, list[da.Array] | da.Array]

Notes

  • The number of model outputs (heads) is inferred dynamically from the first forward pass.

  • Outputs are stacked via dask.array.concatenate for scalability.

  • This method does not perform postprocessing; raw logits/probabilities are returned exactly as produced by the model.
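The per-head accumulation described above can be sketched as follows (a simplified sketch in which NumPy concatenation stands in for dask.array.concatenate):

```python
import numpy as np

def collect_head_outputs(batches):
    """Accumulate per-head outputs across batches, as described above."""
    per_head = None
    for batch_outputs in batches:
        if per_head is None:
            # Head count is inferred from the first forward pass.
            per_head = [[] for _ in batch_outputs]
        for head_idx, out in enumerate(batch_outputs):
            per_head[head_idx].append(out)
    # Stack each head's batches along the patch axis.
    return [np.concatenate(chunks, axis=0) for chunks in per_head]

# Two batches from a hypothetical two-head model (2 patches per batch).
batches = [
    (np.zeros((2, 3, 8, 8)), np.zeros((2, 5, 8, 8))),
    (np.ones((2, 3, 8, 8)), np.ones((2, 5, 8, 8))),
]
heads = collect_head_outputs(batches)
```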

infer_wsi(dataloader, save_path, **kwargs)[source]

Perform model inference on a whole slide image (WSI).

This method iterates over WSI patches produced by a DataLoader, runs each patch through the model’s infer_batch callback, and incrementally assembles full-resolution model outputs for each model head (e.g., semantic, instance, edge). Patch-level outputs are merged row-by-row using horizontal stitching, optionally spilling intermediate results to disk when memory usage exceeds a threshold. After all rows are processed, vertical merging is performed to generate the final probability maps for each multitask head.

Raw probabilities and patch coordinates are returned as Dask arrays. This method does not perform any post-processing; downstream calls to post_process_wsi are required to convert model logits into task-specific outputs (e.g., instances, contours, or label maps).

Parameters:
  • dataloader (DataLoader) – A PyTorch dataloader yielding dictionaries with keys such as "image" and "output_locs" that correspond to extracted WSI patches and their placement metadata.

  • save_path (Path) – A filesystem path used to store temporary Zarr cache data when memory spilling is triggered. The directory is created if needed.

  • **kwargs (MultiTaskSegmentorRunParams) –

    Additional runtime parameters to configure segmentation.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    return_labels (bool):

    Whether to return labels with predictions.

    return_predictions (tuple[bool, ...]):

    Whether to return array predictions for individual tasks.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.

  • self (MultiTaskSegmentor)

Returns:

A dictionary containing the raw multitask model outputs.

Keys:
probabilities (list[da.Array]):

One Dask array per model head, each representing the final WSI-sized probability map for that task. Each array has shape (H, W, C) depending on the head’s channel count.

coordinates (da.Array):

A Dask array of shape (N, 2) or (N, 4), containing accumulated patch coordinate metadata produced during the WSI dataloader iteration.

Return type:

dict[str, list[da.Array] | da.Array]

Notes

  • The number of model heads is inferred from the first infer_batch call.

  • Patch predictions are merged horizontally when the x-coordinate changes row, and vertically after all rows are processed.

  • Large WSIs may trigger spilling intermediate canvas data to disk when memory exceeds memory_threshold.

  • This function returns raw probabilities only. For task-specific segmentation or instance extraction, call post_process_wsi.
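The patch-placement idea behind the stitching can be illustrated with a toy sketch (the engine's real merging is row-wise with optional disk spilling; this simplified version just places each patch at its (x, y) coordinate on a canvas):

```python
import numpy as np

def stitch_patches(patches, coords, canvas_shape):
    """Toy sketch: place patch outputs onto a WSI-sized canvas."""
    canvas = np.zeros(canvas_shape, dtype=patches[0].dtype)
    for patch, (x, y) in zip(patches, coords):
        h, w = patch.shape[:2]
        canvas[y:y + h, x:x + w] = patch
    return canvas

# Four 4x4 patches tiling an 8x8 canvas (coordinates are (x, y)).
patches = [np.full((4, 4), i, dtype=np.uint8) for i in range(4)]
coords = [(0, 0), (4, 0), (0, 4), (4, 4)]
canvas = stitch_patches(patches, coords, (8, 8))
```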

post_process_patches(raw_predictions, **kwargs)[source]

Post-process raw patch-level predictions for multitask segmentation.

This method applies the model’s postproc_func to per-patch probability maps produced by infer_patches. For multitask models (multiple heads), it zips the per-head probability arrays across patches and invokes postproc_func to obtain one or more task dictionaries per patch (e.g., semantic labels, instance info, edges). The per-patch outputs are then reorganized into a task-centric structure using build_post_process_raw_predictions for downstream saving.

Parameters:
  • raw_predictions (dict) –

    Dictionary containing raw model outputs from infer_patches. Expected keys:

    • "probabilities" (list[da.Array]): One Dask array per model head. Each array typically has shape (N, H, W, C) for N patches, with head-specific channels. These are raw logits/probabilities and are not normalized beyond what the model provides.

  • **kwargs (MultiTaskSegmentorRunParams) –

    Additional runtime parameters to configure segmentation.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    return_labels (bool):

    Whether to return labels with predictions.

    return_predictions (tuple[bool, ...]):

    Whether to return array predictions for individual tasks.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.

  • self (MultiTaskSegmentor)

Returns:

A task-organized dictionary suitable for saving, where each entry corresponds to a task produced by postproc_func. For each task (e.g., "semantic", "instance"), keys and value types depend on the model’s post-processing output. Typical patterns include:

  • "predictions": list[da.Array] with per-patch outputs, if the model returns patch-level prediction arrays.

  • "info_dict": list[dict] with per-patch metadata dictionaries (e.g., instance tables, properties). Lists are aligned to the number of input patches.

Any pre-existing keys in raw_predictions (e.g., "coordinates") are preserved as returned by build_post_process_raw_predictions.

Return type:

dict

Notes

  • This method is patch-level post-processing only; it does not perform WSI-scale tiling or stitching. For WSI outputs, use post_process_wsi.

  • Inputs are typically Dask arrays; computation remains lazy until an explicit save step or dask.compute is invoked downstream.

  • The exact set of task keys and payload shapes are determined by the model’s postproc_func for each head.
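The zipping described above can be sketched as follows (a minimal sketch with a hypothetical stand-in callback; the real postproc_func and its outputs are model-specific):

```python
import numpy as np

def post_process_per_patch(head_arrays, postproc_func):
    """Zip per-head arrays patch-by-patch and apply the callback once
    per patch, as described above."""
    return [postproc_func(*heads) for heads in zip(*head_arrays)]

# Hypothetical two-head output for 3 patches.
np_map = np.zeros((3, 8, 8, 2))  # e.g. nuclei probabilities
hv_map = np.zeros((3, 8, 8, 2))  # e.g. horizontal/vertical maps

def toy_postproc(np_patch, hv_patch):
    # Stand-in: returns one task dict per patch.
    return {"task_type": "instance", "predictions": np_patch.argmax(-1)}

results = post_process_per_patch([np_map, hv_map], toy_postproc)
```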

post_process_wsi(raw_predictions, save_path, **kwargs)[source]

Post-process whole slide image (WSI) predictions for multitask segmentation.

This method converts raw WSI-scale probability maps (produced by infer_wsi) into task-specific outputs using the model’s postproc_func. If the probability maps are fully in memory, the method processes the entire WSI at once. If they are Zarr-backed (spilled during inference) or too large, it switches to tile mode: it iterates over WSI tiles, applies postproc_func per tile, merges instance predictions across tile boundaries, and optionally writes intermediate arrays to Zarr under save_path.with_suffix(".zarr") for memory efficiency.

The result is organized into a task-centric dictionary (e.g., semantic, instance) with arrays and/or metadata suitable for saving or further use.

Parameters:
  • raw_predictions (dict) –

    Dictionary containing WSI-scale model outputs from infer_wsi. Expected key:

    • "probabilities" (tuple[da.Array]): One Dask array per model head. Each array is either memory-backed (Dask→NumPy) or Zarr-backed depending on memory spilling during inference.

  • save_path (Path) – Base path for writing intermediate Zarr arrays in tile mode and for allocating per-task outputs when disk-backed arrays are needed.

  • **kwargs (MultiTaskSegmentorRunParams) –

    Additional runtime parameters to configure segmentation.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    return_labels (bool):

    Whether to return labels with predictions.

    return_predictions (tuple[bool, ...]):

    Whether to return array predictions for individual tasks.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.

  • self (MultiTaskSegmentor)

Returns:

A task-organized dictionary of WSI-scale outputs. For each task (e.g., "semantic", "instance"), typical entries include:

  • "predictions" (da.Array or np.ndarray, optional): Full-resolution task prediction map, present only where enabled by return_predictions.

  • Additional task-specific keys (e.g., "info_dict", per-instance dictionaries, contours, classes, probabilities).

The set of keys and their exact shapes/types are determined by the model’s postproc_func.

Return type:

dict

Notes

  • Full-WSI mode is selected when probability maps are not Zarr-backed; otherwise tile mode is used.

  • Tile mode uses model-specific merging of instances across tile boundaries and may write intermediate arrays under a .zarr group next to save_path.

  • Probability maps themselves are not modified here; this method produces task-centric outputs from them. Use save_predictions to persist results as dict, zarr, or annotationstore.
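The tile iteration underlying tile mode can be illustrated with a hypothetical helper (the library's actual tiling and cross-boundary instance merging are model-specific):

```python
def iter_tiles(shape, tile_shape):
    """Yield (y0, y1, x0, x1) bounds covering a WSI-sized map,
    clipping the last row/column of tiles at the map edge."""
    h, w = shape
    th, tw = tile_shape
    for y in range(0, h, th):
        for x in range(0, w, tw):
            yield y, min(y + th, h), x, min(x + tw, w)

tiles = list(iter_tiles((100, 100), (50, 50)))
```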

run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)[source]

Run the MultiTaskSegmentor engine on input images.

This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch-level and whole slide image (WSI) modes.

Parameters:
  • images (list[PathLike | WSIReader] | np.ndarray) – Input images or patches. Can be a list of file paths, WSIReader objects, or a NumPy array of image patches.

  • masks (list[PathLike] | np.ndarray | None) – Optional masks for WSI processing. Only used when patch_mode is False.

  • input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are level, power, and mpp. Keys should be “units” and “resolution”, e.g., [{“units”: “mpp”, “resolution”: 0.25}]. See WSIReader for details.

  • patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.

  • ioconfig (IOSegmentorConfig | None) – IO configuration for patch extraction and resolution.

  • patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False). Default is True.

  • save_dir (PathLike | None) – Directory to save output files. Required for WSI mode.

  • overwrite (bool) – Whether to overwrite existing output files. Default is False.

  • output_type (str) – Desired output format: “dict”, “zarr”, or “annotationstore”. Default is “dict”.

  • **kwargs (MultiTaskSegmentorRunParams) –

    Additional runtime parameters to configure segmentation.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    return_labels (bool):

    Whether to return labels with predictions.

    return_predictions (tuple[bool, ...]):

    Whether to return array predictions for individual tasks.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.

  • self (MultiTaskSegmentor)

Returns:

  • If patch_mode is True: returns predictions or path to saved output.

  • If patch_mode is False: returns a dictionary mapping each WSI to its output path.

Return type:

AnnotationStore | Path | str | dict | list[Path]

Examples

>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> image_patches = [np.ndarray, np.ndarray]
>>> mtsegmentor = MultiTaskSegmentor(model="hovernet_fast-pannuke")
>>> output = mtsegmentor.run(image_patches, patch_mode=True)
>>> output
"/path/to/Output.db"
>>> output = mtsegmentor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="zarr"
... )
>>> output
"/path/to/Output.zarr"
>>> output = mtsegmentor.run(wsis, patch_mode=False)
>>> output.keys()
['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
"/path/to/wsi1.db"
save_predictions(processed_predictions, output_type, save_path=None, **kwargs)[source]

Save model predictions to disk or return them in memory.

Depending on output_type, this method either:
  • returns a Python dictionary ("dict"),

  • writes a Zarr group to disk and returns the path ("zarr"), or

  • writes one or more SQLite-backed AnnotationStore .db files and returns the resulting path(s) ("annotationstore").

For multitask outputs, this function also:
  • Preserves task separation when saving to Zarr (one group per task).

  • Optionally saves raw probability maps if return_probabilities=True (as Zarr only; probabilities cannot be written to AnnotationStore).

  • Merges per-task keys for saving to AnnotationStore, including optional coordinates to establish slide origin.

Parameters:
  • processed_predictions (dict) –

    Task-organized dictionary produced by post-processing (e.g. from post_process_patches or post_process_wsi). For multitask models this typically includes:

    • "probabilities" (optional): list[da.Array] of WSI maps, present if preserved for saving.

    • Per-task sub-dicts (e.g., "semantic", "instance"), each containing task-specific arrays/metadata such as "predictions", "info_dict", etc.

    • "coordinates" (optional): Dask/NumPy array used to set spatial origin when saving vector outputs.

  • output_type (str) – Desired output format. Supported values are: "dict", "zarr", or "annotationstore" (case-sensitive).

  • save_path (Path | None) – Base filesystem path for file outputs. Required for "zarr" and "annotationstore". For Zarr, a save_path.with_suffix(".zarr") group is used. For AnnotationStore, .db files are written (one per image in patch mode, one per WSI in WSI mode). Ignored when output_type="dict".

  • **kwargs (MultiTaskSegmentorRunParams) –

    Additional runtime parameters to configure segmentation.

    Optional Keys:
    auto_get_mask (bool):

    Automatically generate segmentation masks using wsireader.tissue_mask() during processing.

    batch_size (int):

    Number of image patches per forward pass.

    class_dict (dict):

    Mapping of classification outputs to class names.

    device (str):

    Device to run the model on (e.g., “cpu”, “cuda”).

    labels (list):

    Optional labels for input images. Only a single label per image is supported.

    memory_threshold (int):

    Memory usage threshold (percentage) to trigger caching behavior.

    num_workers (int):

    Number of workers for DataLoader and post-processing.

    output_file (str):

    Filename for saving output (e.g., “.zarr” or “.db”).

    output_resolutions (Resolution):

    Resolution used for writing output predictions.

    patch_output_shape (tuple[int, int]):

    Shape of output patches (height, width).

    return_labels (bool):

    Whether to return labels with predictions.

    return_predictions (tuple[bool, ...]):

    Whether to return array predictions for individual tasks.

    return_probabilities (bool):

    Whether to return per-class probabilities.

    scale_factor (tuple[float, float]):

    Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.

    stride_shape (tuple[int, int]):

    Stride used during WSI processing. Defaults to patch_input_shape if not provided.

    verbose (bool):

    Whether to enable verbose logging.

  • self (MultiTaskSegmentor)

Returns:

  • If output_type == "dict":

    Returns the (possibly simplified) prediction dictionary. For a single task, the task level is flattened.

  • If output_type == "zarr":

    Returns the Path to the saved .zarr group.

  • If output_type == "annotationstore":

    Returns a list of paths to saved .db files (patch mode), or a single path / store handle for WSI mode. If probability maps were requested for saving, the Zarr path holding those maps may also be included.

Return type:

dict | AnnotationStore | Path | list[Path]

Raises:

TypeError – If an unsupported output_type is provided.

Notes

  • For "dict" and "zarr", saving is delegated to _save_predictions_as_dict_zarr to keep behavior aligned across engines.

  • When output_type == "annotationstore", arrays are first computed (via a Zarr/dict pass) to obtain concrete NumPy payloads suitable for vector export, after which per-task stores are written using _save_predictions_as_annotationstore.

  • If return_probabilities=True, probability maps are written only to Zarr, never to AnnotationStore. A guidance message is logged describing how to visualize heatmaps (e.g., converting to OME-TIFF).
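The output_type dispatch described above can be sketched as follows (a simplified sketch; the engine delegates real saving to internal helpers such as _save_predictions_as_dict_zarr):

```python
from pathlib import Path

def dispatch_save(predictions, output_type, save_path=None):
    """Simplified sketch of the output_type dispatch described above."""
    if output_type == "dict":
        return predictions
    if output_type in ("zarr", "annotationstore"):
        if save_path is None:
            raise ValueError(f"save_path is required for {output_type!r}")
        suffix = ".zarr" if output_type == "zarr" else ".db"
        return save_path.with_suffix(suffix)
    # Mirrors the documented TypeError for unsupported formats.
    raise TypeError(f"Unsupported output_type: {output_type!r}")

result = dispatch_save({"semantic": {}}, "zarr", Path("out"))
```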