multi_task_segmentor
Multi-task segmentation engine for computational pathology.
This module implements the MultiTaskSegmentor and supporting utilities to
run multi-head segmentation models (e.g., HoVerNet/HoVerNetplus-style architectures)
on histology images in both patch and whole slide image (WSI) workflows.
It provides consistent orchestration for data loading, model invocation, tiled
stitching, memory-aware caching, post-processing per task, and saving outputs
to in-memory dictionaries, Zarr, or AnnotationStore (.db).
- Overview
- Patch mode: infer_patches runs a model on batches of image patches, producing one probability/logit tensor per model head. post_process_patches converts these into task-specific outputs (e.g., semantic maps, instances).
- WSI mode: infer_wsi iterates over all WSI patches, assembles head outputs via horizontal row-merge and vertical normalization, and returns WSI-scale probability maps per head. post_process_wsi consumes these maps in either full-WSI or tile mode (for Zarr-backed arrays), deriving task-centric outputs and merging instances across tile boundaries.
- Memory awareness: intermediate accumulators spill to Zarr automatically once usage exceeds a configurable memory_threshold, enabling processing of very large slides on limited RAM.
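The spill-on-threshold idea can be sketched as follows. This is an illustrative toy only, assuming a per-array size check and a NumPy memmap as the disk backend; the engine itself spills to Zarr and monitors overall memory usage, not a single array's size.

```python
import os

import numpy as np


def maybe_spill(canvas: np.ndarray, memory_threshold_mb: float, cache_dir: str) -> np.ndarray:
    """Return the accumulator unchanged while it fits in RAM; otherwise
    move it to a disk-backed array. Hypothetical helper, illustrating the
    spill-on-threshold idea only (the engine spills to Zarr instead)."""
    if canvas.nbytes / 1e6 <= memory_threshold_mb:
        return canvas  # still small enough to keep in memory
    path = os.path.join(cache_dir, "canvas_cache.npy")
    spilled = np.lib.format.open_memmap(
        path, mode="w+", dtype=canvas.dtype, shape=canvas.shape
    )
    spilled[:] = canvas  # one-time copy to disk; RAM copy can now be freed
    spilled.flush()
    return spilled
```

Downstream code can then keep accumulating into the returned array either way, since a memmap supports the same slicing and arithmetic as an in-memory array.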
- Key Classes
- MultiTaskSegmentor
Core engine for multi-head segmentation. Extends SemanticSegmentor to run models with multiple output heads and to produce task-centric predictions after post-processing. Supports patch and WSI workflows, dict/Zarr/AnnotationStore outputs, and device/batch/stride configuration.
- MultiTaskSegmentorRunParams
TypedDict of runtime parameters used across the engine. Extends SemanticSegmentorRunParams with an additional multitask option: return_predictions.
- Important Functions
- infer_patches(dataloader, *, return_coordinates=False) -> dict
Run model on a collection of patches; returns per-head probabilities as Dask arrays and optionally patch coordinates.
- infer_wsi(dataloader, save_path, **kwargs) -> dict
Run model on a WSI via patch extraction and incremental stitching, with optional Zarr caching when memory pressure is high.
- post_process_patches(raw_predictions, **kwargs) -> dict
Apply the model’s post-processing per patch and reorganize results into a task-centric dictionary (e.g., “semantic”, “instance”).
- post_process_wsi(raw_predictions, save_path, **kwargs) -> dict
Convert WSI-scale head maps into task-specific outputs, either in memory (full-WSI) or via tile-mode with instance de-duplication across tile boundaries.
- save_predictions(processed_predictions, output_type, save_path=None, **kwargs)
Persist results as dict, zarr, or annotationstore. Probability maps are saved to Zarr; vector outputs are written to AnnotationStore.
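The cross-tile instance merging mentioned for post_process_wsi can be illustrated with a simple bounding-box de-duplication. This is a hypothetical sketch using an IoU threshold on axis-aligned boxes; the engine's actual boundary-merging logic is more involved (it works on polygon instances).

```python
def deduplicate_instances(instances_a, instances_b, iou_threshold=0.5):
    """Keep all instances from tile A and only those from tile B whose
    bounding box does not overlap an A instance above the IoU threshold.
    Boxes are (y0, x0, y1, x1) tuples. Hypothetical helper for illustration."""

    def iou(a, b):
        ay0, ax0, ay1, ax1 = a
        by0, bx0, by1, bx1 = b
        iy0, ix0 = max(ay0, by0), max(ax0, bx0)
        iy1, ix1 = min(ay1, by1), min(ax1, bx1)
        inter = max(0, iy1 - iy0) * max(0, ix1 - ix0)
        union = (ay1 - ay0) * (ax1 - ax0) + (by1 - by0) * (bx1 - bx0) - inter
        return inter / union if union else 0.0

    kept = list(instances_a)
    for box in instances_b:
        if all(iou(box, a) < iou_threshold for a in instances_a):
            kept.append(box)
    return kept
```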
- Helper utilities
- build_post_process_raw_predictions(…)
Group per-image outputs by task and normalize array/dict payloads.
- prepare_multitask_full_batch(…)
Align a batch’s predictions to global output indices and pad the tail.
- merge_multitask_horizontal(…)
Row-wise stitching of patch predictions for each head.
- save_multitask_to_cache(…)
Spill accumulated row blocks (canvas/count) to Zarr.
- merge_multitask_vertical_chunkwise(…)
Normalize and merge rows vertically into final WSI probability maps.
- dict_to_store(…)
Convert polygonal task predictions to AnnotationStore records.
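The accumulate-and-normalize idea behind the merge helpers can be sketched in a few lines. This toy version stitches whole patches into one 2-D canvas at once, assuming (y, x) patch coordinates; the engine applies the same idea row-wise (horizontal merge per row, then a chunked vertical merge that divides accumulated rows by their coverage counts).

```python
import numpy as np


def stitch_and_normalize(patches, coords, out_shape):
    """Sum patch probabilities into a canvas, count how many patches cover
    each pixel, and divide at the end so overlaps are averaged.
    Illustrative sketch of the engine's accumulate/normalize scheme."""
    canvas = np.zeros(out_shape, dtype=np.float32)
    count = np.zeros(out_shape, dtype=np.float32)
    for patch, (y, x) in zip(patches, coords):
        h, w = patch.shape
        canvas[y:y + h, x:x + w] += patch
        count[y:y + h, x:x + w] += 1.0
    return canvas / np.maximum(count, 1.0)  # uncovered pixels stay zero
```

Splitting the division off until the very end is what lets the intermediate canvas and count arrays be cached to Zarr and merged incrementally.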
- Inputs and Outputs
Inputs: lists of file paths or WSIReader instances (WSI mode), or np.ndarray patches (NHWC) in patch mode. Optional masks and IO configs control extraction resolution, patch/tile shapes, and stride.
Raw outputs: per-head probability maps/logits as Dask arrays (patch- or WSI-scale).
Post-processed outputs: task-centric dictionaries (e.g., instance tables, semantic predictions), optionally including full-resolution prediction arrays if requested via return_predictions.
- Saved outputs:
dict: in-memory Python structures
zarr: hierarchical arrays (optionally with probability maps)
annotationstore: SQLite-backed vector annotations (.db)
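The task-centric grouping that precedes saving can be sketched as below, in the spirit of build_post_process_raw_predictions. The dict shape and string payloads here are purely illustrative; real payloads are arrays or instance records.

```python
def group_by_task(per_image_outputs):
    """Collect each image's per-task outputs into one list per task,
    e.g. [{"semantic": ..., "instance": ...}, ...] becomes
    {"semantic": [...], "instance": [...]}. Hypothetical sketch."""
    grouped = {}
    for image_output in per_image_outputs:
        for task, payload in image_output.items():
            grouped.setdefault(task, []).append(payload)
    return grouped
```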
Examples
- Patch-mode prediction:
>>> patches = [np.ndarray, np.ndarray]  # NHWC
>>> mt = MultiTaskSegmentor(model="hovernetplus-oed", device="cuda")
>>> out = mt.run(patches, patch_mode=True, output_type="dict")
- WSI-mode prediction with Zarr caching and AnnotationStore output:
>>> wsis = [Path("slide1.svs"), Path("slide2.svs")]
>>> mt = MultiTaskSegmentor(model="hovernet_fast-pannuke", device="cuda")
>>> out = mt.run(
...     wsis,
...     patch_mode=False,
...     save_dir=Path("outputs/"),
...     output_type="annotationstore",
...     memory_threshold=80,
...     auto_get_mask=True,
...     overwrite=True,
... )
Notes
The engine infers the number of model heads from the first infer_batch call and maintains per-head arrays throughout merging.
Probability normalization is performed during the final vertical merge (row accumulation divided by row counts).
Probability maps are not written to AnnotationStore; use Zarr to persist them and convert to OME-TIFF separately if needed for visualization.
Functions
- Write polygonal multitask predictions into a QuPath JSON or AnnotationStore.
- Horizontally merge a run of patch outputs into per-head row blocks.
- Merge horizontally stitched row blocks into final WSI probability maps.
- Align patch predictions to the global output index and pad to cover gaps.
- Helper to retrieve selected instance UIDs.
- Write accumulated horizontal row blocks to a Zarr cache on disk.
Classes
- Compute and write TIAToolbox annotations using batched Dask Delayed tasks.
- MultiTask segmentation engine to run models like hovernet and hovernetplus.
- Runtime parameters for configuring the MultiTaskSegmentor.run() method.