multi_task_segmentor

Multi-task segmentation engine for computational pathology.

This module implements the MultiTaskSegmentor and supporting utilities to run multi-head segmentation models (e.g., HoVerNet/HoVerNetplus-style architectures) on histology images in both patch and whole slide image (WSI) workflows. It provides consistent orchestration for data loading, model invocation, tiled stitching, memory-aware caching, post-processing per task, and saving outputs to in-memory dictionaries, Zarr, or AnnotationStore (.db).

Overview
  • Patch mode: infer_patches runs a model on batches of image patches, producing one probability/logit tensor per model head. post_process_patches converts these into task-specific outputs (e.g., semantic maps, instances).

  • WSI mode: infer_wsi iterates over all WSI patches, assembles head outputs via horizontal row-merge and vertical normalization, and returns WSI-scale probability maps per head. post_process_wsi consumes these maps in either full-WSI or tile mode (for Zarr-backed arrays), deriving task-centric outputs and merging instances across tile boundaries.

  • Memory awareness: intermediate accumulators spill to Zarr automatically once usage exceeds a configurable memory_threshold, enabling processing of very large slides on limited RAM.
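The spill-to-disk behaviour can be pictured with a small sketch. This is illustrative only: the engine itself spills to Zarr under a percentage-based memory_threshold, whereas this toy version uses a fixed byte budget and a NumPy memmap as the disk-backed stand-in.

```python
import tempfile
from pathlib import Path

import numpy as np


def accumulate_rows(rows, byte_budget):
    """Stack float32 row blocks of shape (h, w); once the in-RAM
    accumulator would exceed ``byte_budget`` bytes, spill it to a
    disk-backed memmap and keep appending there (a stand-in for the
    engine's Zarr cache)."""
    h, w = rows[0].shape
    canvas = np.empty((len(rows) * h, w), dtype=np.float32)
    filled = 0
    spilled = False
    for row in rows:
        if not spilled and (filled + h) * w * 4 > byte_budget:
            # Spill point: replace the in-RAM canvas with a memmap on disk.
            path = Path(tempfile.mkdtemp()) / "canvas.dat"
            disk = np.memmap(path, dtype=np.float32, mode="w+",
                             shape=canvas.shape)
            disk[:filled] = canvas[:filled]
            canvas, spilled = disk, True
        canvas[filled:filled + h] = row
        filled += h
    return canvas, spilled
```

The key design point carried over from the engine: accumulation code is agnostic to whether the canvas lives in RAM or on disk, so the spill is a one-time swap rather than a separate code path.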

Key Classes
MultiTaskSegmentor

Core engine for multi-head segmentation. Extends SemanticSegmentor to run models with multiple output heads and to produce task-centric predictions after post-processing. Supports patch and WSI workflows, dict/Zarr/AnnotationStore outputs, and device/batch/stride configuration.

MultiTaskSegmentorRunParams

TypedDict of runtime parameters used across the engine. Extends SemanticSegmentorRunParams with an additional multitask option: return_predictions.

Important Functions
infer_patches(dataloader, *, return_coordinates=False) -> dict

Run model on a collection of patches; returns per-head probabilities as Dask arrays and optionally patch coordinates.

infer_wsi(dataloader, save_path, **kwargs) -> dict

Run model on a WSI via patch extraction and incremental stitching, with optional Zarr caching when memory pressure is high.

post_process_patches(raw_predictions, **kwargs) -> dict

Apply the model’s post-processing per patch and reorganize results into a task-centric dictionary (e.g., “semantic”, “instance”).
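A rough illustration of what "task-centric" means here. The head names and dict layout below are assumptions for the sketch, not the toolbox's exact schema: per-head outputs are regrouped so downstream code indexes results by task rather than by model head.

```python
import numpy as np


def regroup_by_task(per_head, head_to_task):
    """Reorganize a per-head prediction dict into a per-task dict.

    per_head: mapping head name -> list of per-patch arrays.
    head_to_task: mapping head name -> task name ("semantic", "instance", ...).
    """
    per_task = {}
    for head, patches in per_head.items():
        task = head_to_task[head]
        per_task.setdefault(task, []).extend(patches)
    return per_task


# Two heads feeding two tasks (hypothetical head names).
per_head = {
    "np": [np.zeros((4, 4))],  # e.g. a nuclei-probability head
    "tp": [np.ones((4, 4))],   # e.g. a type-prediction head
}
per_task = regroup_by_task(per_head, {"np": "instance", "tp": "semantic"})
```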

post_process_wsi(raw_predictions, save_path, **kwargs) -> dict

Convert WSI-scale head maps into task-specific outputs, either in memory (full-WSI) or via tile-mode with instance de-duplication across tile boundaries.
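Cross-boundary merging can be pictured as discarding duplicate detections that two neighbouring tiles both report. Below is a minimal sketch using bounding-box IoU; the real engine works on instance masks and UIDs, and the box format and threshold here are assumptions for illustration.

```python
def iou(a, b):
    """Intersection-over-union of two (x0, y0, x1, y1) boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix1 - ix0) * max(0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union else 0.0


def merge_tile_instances(kept, incoming, thresh=0.5):
    """Append an incoming tile's instances, skipping near-duplicates of
    instances already kept from previously processed tiles."""
    merged = list(kept)
    for box in incoming:
        if all(iou(box, k) < thresh for k in merged):
            merged.append(box)
    return merged
```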

save_predictions(processed_predictions, output_type, save_path=None, **kwargs)

Persist results as dict, zarr, or annotationstore. Probability maps are saved to Zarr; vector outputs are written to AnnotationStore.

Helper utilities
  • build_post_process_raw_predictions(…)

    Group per-image outputs by task and normalize array/dict payloads.

  • prepare_multitask_full_batch(…)

    Align a batch’s predictions to global output indices and pad the tail.

  • merge_multitask_horizontal(…)

    Row-wise stitching of patch predictions for each head.

  • save_multitask_to_cache(…)

    Spill accumulated row blocks (canvas/count) to Zarr.

  • merge_multitask_vertical_chunkwise(…)

    Normalize and merge rows vertically into final WSI probability maps.

  • dict_to_store(…)

    Convert polygonal task predictions to AnnotationStore records.
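Taken together, the horizontal/vertical merge pair implements a standard accumulate-then-normalize stitch: overlapping patch predictions are summed into a canvas alongside a per-pixel hit count, and the final map is canvas divided by count. A one-head, pure-NumPy sketch of that idea (shapes, coordinates, and stride are illustrative):

```python
import numpy as np


def stitch(patches, coords, out_shape):
    """Sum overlapping patch predictions into a canvas, then normalize
    by per-pixel hit counts (the canvas/count scheme described above)."""
    canvas = np.zeros(out_shape, dtype=np.float32)
    count = np.zeros(out_shape, dtype=np.float32)
    for patch, (y, x) in zip(patches, coords):
        h, w = patch.shape
        canvas[y:y + h, x:x + w] += patch
        count[y:y + h, x:x + w] += 1.0
    return canvas / np.maximum(count, 1.0)  # avoid div-by-zero off-tissue
```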

Inputs and Outputs
  • Inputs: lists of file paths or WSIReader instances (WSI mode), or np.ndarray patches (NHWC) in patch mode. Optional masks and IO configs control extraction resolution, patch/tile shapes, and stride.

  • Raw outputs: per-head probability maps/logits as Dask arrays (patch- or WSI-scale).

  • Post-processed outputs: task-centric dictionaries (e.g., instance tables, semantic predictions), optionally including full-resolution prediction arrays if requested via return_predictions.

  • Saved outputs:
    • dict: in-memory Python structures

    • zarr: hierarchical arrays (optionally with probability maps)

    • annotationstore: SQLite-backed vector annotations (.db)

Examples

Patch-mode prediction:
>>> patches = np.zeros((2, 256, 256, 3), dtype=np.uint8)  # NHWC
>>> mt = MultiTaskSegmentor(model="hovernetplus-oed", device="cuda")
>>> out = mt.run(patches, patch_mode=True, output_type="dict")

WSI-mode prediction with Zarr caching and AnnotationStore output:
>>> wsis = [Path("slide1.svs"), Path("slide2.svs")]
>>> mt = MultiTaskSegmentor(model="hovernet_fast-pannuke", device="cuda")
>>> out = mt.run(
...     wsis,
...     patch_mode=False,
...     save_dir=Path("outputs/"),
...     output_type="annotationstore",
...     memory_threshold=80,
...     auto_get_mask=True,
...     overwrite=True,
... )

Notes

  • The engine infers the number of model heads from the first infer_batch call and maintains per-head arrays throughout merging.

  • Probability normalization is performed during the final vertical merge (row accumulation divided by row counts).

  • Probability maps are not written to AnnotationStore; use Zarr to persist them and convert to OME-TIFF separately if needed for visualization.

Functions

dict_to_json_store

Write polygonal multitask predictions into a QuPath JSON file or an AnnotationStore.

merge_multitask_horizontal

Horizontally merge a run of patch outputs into per-head row blocks.

merge_multitask_vertical_chunkwise

Merge horizontally stitched row blocks into final WSI probability maps.

prepare_multitask_full_batch

Align patch predictions to the global output index and pad to cover gaps.

retrieve_sel_uids

Helper to retrieve selected instance UIDs.

save_multitask_to_cache

Write accumulated horizontal row blocks to a Zarr cache on disk.

Classes

DaskDelayedJSONStore

Compute and write TIAToolbox annotations using batched Dask Delayed tasks.

MultiTaskSegmentor

MultiTask segmentation engine to run models like hovernet and hovernetplus.

MultiTaskSegmentorRunParams

Runtime parameters for configuring the MultiTaskSegmentor.run() method.