SemanticSegmentor
- class SemanticSegmentor(model, batch_size=8, num_workers=0, weights=None, *, device='cpu', verbose=True)
Semantic segmentation engine for digital histology images.
This class extends PatchPredictor to support semantic segmentation tasks using pretrained or custom models from TIAToolbox. It supports both patch-level and whole slide image (WSI) processing, and provides utilities for merging, post-processing, and saving predictions.
- Performance:
The TIAToolbox model fcn_resnet50_unet-bcss achieves the following results on the BCSS dataset:
Semantic segmentation performance on the BCSS dataset
Method            Tumour  Stroma  Inflammatory  Necrosis  Other  All
Amgad et al.      0.851   0.800   0.712         0.723     0.666  0.750
TIAToolbox        0.885   0.825   0.761         0.765     0.581  0.763
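As an arithmetic check on the table, the All column matches, to three decimals, the unweighted mean of the five per-class scores in each row. The table itself does not state how All is computed, so treat this as an observation rather than the documented definition:

```python
# Per-class scores copied from the table above (Tumour, Stroma,
# Inflammatory, Necrosis, Other), paired with the reported "All" value.
scores = {
    "Amgad et al.": ([0.851, 0.800, 0.712, 0.723, 0.666], 0.750),
    "TIAToolbox": ([0.885, 0.825, 0.761, 0.765, 0.581], 0.763),
}

for method, (per_class, reported_all) in scores.items():
    mean = sum(per_class) / len(per_class)
    print(f"{method}: mean={mean:.4f}, reported All={reported_all:.3f}")
    # Round to the three decimals used in the table before comparing.
    assert round(mean, 3) == reported_all
```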
- Parameters:
model (str | ModelABC) – A PyTorch model instance or the name of a pretrained model from TIAToolbox. Pretrained models can be requested from the toolbox model zoo; see the list of available pretrained models in the TIAToolbox documentation. By default, the corresponding pretrained weights are also downloaded; you can override them with your own weights via the weights parameter. This argument is required and has no default.
batch_size (int) – Number of image patches processed per forward pass. Default is 8.
num_workers (int) – Number of workers for data loading. Default is 0.
weights (str | Path | None) – Path to model weights. If None, default weights are used. For example:
>>> engine = SemanticSegmentor(
...     model="pretrained-model",
...     weights="/path/to/pretrained-local-weights.pth",
... )
device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.
verbose (bool) – Whether to enable verbose logging. Default is True.
- masks
Optional tissue masks for WSI processing. These are only used when patch_mode is False. If not provided, a tissue mask is generated automatically for whole slide images (when auto_get_mask is True).
- ioconfig
IO configuration for patch extraction and resolution.
- Type:
IOSegmentorConfig | None
- input_resolutions
Resolution settings for model input. Supported units are level, power and mpp. Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. See WSIReader for details.
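A minimal sketch of the expected input_resolutions structure, written as a hypothetical validate_input_resolutions helper that checks only the rules stated above: a non-empty list of dicts with exactly the keys "units" and "resolution", and units drawn from level, power and mpp. It is not a TIAToolbox function, and TIAToolbox's own validation may be stricter or different:

```python
# Hypothetical helper for illustration only; not part of TIAToolbox.
SUPPORTED_UNITS = {"level", "power", "mpp"}


def validate_input_resolutions(input_resolutions: list) -> None:
    """Raise ValueError if input_resolutions deviates from the documented shape."""
    if not input_resolutions:
        raise ValueError("input_resolutions must contain at least one entry")
    for entry in input_resolutions:
        # Each entry must have exactly the keys "units" and "resolution".
        if set(entry) != {"units", "resolution"}:
            raise ValueError(f"expected keys 'units' and 'resolution', got {sorted(entry)}")
        if entry["units"] not in SUPPORTED_UNITS:
            raise ValueError(f"unsupported units: {entry['units']!r}")
        # Resolution must be numeric (e.g. 0.25 mpp, 20 power, level 0).
        if not isinstance(entry["resolution"], (int, float)):
            raise ValueError(f"resolution must be a number, got {entry['resolution']!r}")


validate_input_resolutions([{"units": "mpp", "resolution": 0.25}])  # accepted
try:
    validate_input_resolutions([{"units": "microns", "resolution": 0.25}])
except ValueError as err:
    print(err)  # unsupported units: 'microns'
```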
- patch_input_shape
Shape of input patches (height, width). The shape is given at the requested read resolution, not with respect to level 0, and must be positive.
- stride_shape
Stride used during patch extraction. The stride is given at the requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape defaults to patch_input_shape.
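The sketch below shows how patch_input_shape and stride_shape interact when tiling a region at the read resolution, and how batch_size then determines the number of forward passes. patch_origins is a hypothetical helper; it simply skips patches that would cross the image border, which may not match how TIAToolbox handles edges:

```python
import math


def patch_origins(image_shape, patch_input_shape, stride_shape=None):
    """Return top-left (row, col) origins of patches that fit inside the image.

    Hypothetical helper. All shapes are (height, width) at the read resolution.
    stride_shape defaults to patch_input_shape, mirroring the documented default.
    Patches that would extend past the border are skipped (a simplification).
    """
    if stride_shape is None:
        stride_shape = patch_input_shape
    if min(*patch_input_shape, *stride_shape) <= 0:
        raise ValueError("patch_input_shape and stride_shape must be positive")
    height, width = image_shape
    ph, pw = patch_input_shape
    sh, sw = stride_shape
    return [
        (row, col)
        for row in range(0, height - ph + 1, sh)
        for col in range(0, width - pw + 1, sw)
    ]


# 1024x1024 region tiled with 256x256 patches.
no_overlap = patch_origins((1024, 1024), (256, 256))           # 4 x 4 = 16 patches
overlap = patch_origins((1024, 1024), (256, 256), (128, 128))  # 7 x 7 = 49 patches

# With batch_size=8, the number of forward passes is ceil(n_patches / 8).
print(len(no_overlap), math.ceil(len(no_overlap) / 8))  # 16 2
print(len(overlap), math.ceil(len(overlap) / 8))        # 49 7
```

Halving the stride roughly triples the patch count here, which is the usual trade-off between smoother merged predictions and inference cost.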
- labels
Optional labels for input images. Only a single label per image is supported.
- Type:
list | None
Examples
>>> # list of 2 WSI files as input
>>> wsis = ['path/img.svs', 'path/img.svs']
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
>>> output = segmentor.run(wsis, patch_mode=False, save_dir="path/output/")

>>> # list of 2 image patches (NumPy arrays) as input
>>> image_patches = [np.ndarray, np.ndarray]  # placeholders for two HxWx3 arrays
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
>>> output = segmentor.run(image_patches, patch_mode=True)

>>> # list of 2 image patch files as input
>>> data = ['path/img.png', 'path/img.png']
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
>>> output = segmentor.run(data, patch_mode=True)

>>> # list of 2 image tile files as input
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
>>> output = segmentor.run(tile_file, patch_mode=False, save_dir="path/output/")

>>> # list of 2 WSI files as input
>>> wsis = ['path/wsi1.svs', 'path/wsi2.svs']
>>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss")
>>> output = segmentor.run(wsis, patch_mode=False, save_dir="path/output/")
References
[1] Amgad M, Elfandy H, …, Gutman DA, Cooper LAD. Structured crowdsourcing enables convolutional segmentation of histology images. Bioinformatics 2019. doi: 10.1093/bioinformatics/btz083
Initialize SemanticSegmentor.
- Parameters:
model (str | ModelABC) – A PyTorch model instance or name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights will be downloaded unless overridden via weights.
batch_size (int) – Number of image patches processed per forward pass. Default is 8.
num_workers (int) – Number of workers for data loading. Default is 0.
weights (str | Path | None) – Path to model weights. If None, default weights are used.
device (str) – Device to run the model on (e.g., “cpu”, “cuda”). Default is “cpu”.
verbose (bool) – Whether to enable verbose logging. Default is True.
Methods
get_dataloader – Pre-process images and masks and return a DataLoader for inference.
infer_wsi – Perform model inference on a whole slide image (WSI).
run – Run the semantic segmentation engine on input images.
save_predictions – Save semantic segmentation predictions to disk or return them in memory.
- get_dataloader(images, masks=None, labels=None, ioconfig=None, *, patch_mode=True, auto_get_mask=True)
Pre-process images and masks and return a DataLoader for inference.
This method prepares the dataset and returns a PyTorch DataLoader for either patch-based or WSI-based semantic segmentation. It overrides the base method to support additional WSI-specific logic, including patch output shape and output location tracking.
- Parameters:
images (str | Path | list[str | Path] | np.ndarray) – Input images. Can be a list of file paths or a NumPy array of image patches in NHWC format.
masks (Path | None) – Optional tissue masks for WSI processing. Only used when patch_mode is False.
labels (list | None) – Optional labels for input images. Only one label per image is supported.
ioconfig (IOSegmentorConfig | None) – IO configuration for patch extraction and resolution.
patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False).
auto_get_mask (bool) – Whether to automatically generate a tissue mask using wsireader.tissue_mask() when patch_mode is False. If True, only tissue regions are processed. If False, all patches are processed. Default is True.
- Returns:
A PyTorch DataLoader configured for semantic segmentation inference.
- Return type:
DataLoader
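When auto_get_mask is True, only tissue regions are processed. As a rough illustration of that idea, the hypothetical select_tissue_patches helper below keeps only patch origins whose footprint overlaps a binary tissue mask. The min_tissue_fraction threshold is an assumption for this sketch, and in practice the mask would come from wsireader.tissue_mask() or a user-supplied mask rather than a hand-written list:

```python
def select_tissue_patches(mask, origins, patch_shape, min_tissue_fraction=0.0):
    """Keep patch origins whose footprint contains enough tissue.

    Hypothetical helper. mask is a 2D list of 0/1 values (1 = tissue) at the
    same resolution as the (row, col) patch origins.
    """
    ph, pw = patch_shape
    kept = []
    for row, col in origins:
        # Count tissue pixels inside the patch footprint.
        tissue = sum(sum(mask[r][col:col + pw]) for r in range(row, row + ph))
        if tissue / (ph * pw) > min_tissue_fraction:
            kept.append((row, col))
    return kept


# 4x4 mask with tissue only in the top-left 2x2 block.
mask = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
origins = [(0, 0), (0, 2), (2, 0), (2, 2)]
print(select_tissue_patches(mask, origins, (2, 2)))  # [(0, 0)]
```

With auto_get_mask=False no filtering of this kind takes place, so all four patches would be processed.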
- infer_wsi(dataloader, save_path, **kwargs)
Perform model inference on a whole slide image (WSI).
This method processes a WSI using the provided DataLoader, merges patch-level predictions into a full-resolution canvas, and returns the aggregated output. It supports memory-aware caching and optional inclusion of coordinates and labels.
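Merging patch-level predictions into a full-resolution canvas is commonly done by averaging wherever patches overlap. The sketch below assumes that averaging strategy and a single class channel for brevity; merge_patches is hypothetical, and TIAToolbox's actual merging and Zarr caching may work differently:

```python
def merge_patches(canvas_shape, patches):
    """Average overlapping patch predictions into a single canvas.

    Hypothetical sketch: each item in patches is ((row, col), values), where
    values is a 2D list of per-pixel scores for one class. Pixels covered by
    several patches receive the mean of their predictions; uncovered pixels
    stay 0.0.
    """
    height, width = canvas_shape
    total = [[0.0] * width for _ in range(height)]
    count = [[0] * width for _ in range(height)]
    for (row, col), values in patches:
        for dr, line in enumerate(values):
            for dc, value in enumerate(line):
                total[row + dr][col + dc] += value
                count[row + dr][col + dc] += 1
    # Divide accumulated sums by coverage counts.
    return [
        [t / c if c else 0.0 for t, c in zip(t_line, c_line)]
        for t_line, c_line in zip(total, count)
    ]


# Two 2x2 patches on a 2x3 canvas, overlapping in the middle column.
left = ((0, 0), [[1.0, 1.0], [1.0, 1.0]])
right = ((0, 1), [[0.0, 0.0], [0.0, 0.0]])
merged = merge_patches((2, 3), [left, right])
print(merged)  # [[1.0, 0.5, 0.0], [1.0, 0.5, 0.0]]
```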
- Parameters:
dataloader (DataLoader) – PyTorch DataLoader configured for WSI processing.
save_path (Path) – Path for the intermediate output, which is saved as a Zarr file.
**kwargs (SemanticSegmentorRunParams) –
Additional runtime parameters to configure segmentation.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate tissue masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- output_resolutions (Resolution):
Resolution used for writing output predictions.
- patch_output_shape (tuple[int, int]):
Shape of output patches (height, width).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
- Returns:
Dictionary containing merged prediction results:
- “probabilities”: full-resolution probability map.
- “coordinates”: patch coordinates.
- “labels”: ground-truth labels (if return_labels is True).
- Return type:
dict
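A worked example of the scale_factor option (model_mpp / slide_mpp): if the model reads at 0.5 mpp and the slide baseline is 0.25 mpp, the factor is 2.0, so a point at (100, 40) in model-resolution pixels lies at (200, 80) at baseline. The hypothetical to_baseline helper below assumes the factor is applied per axis to (x, y) points:

```python
def to_baseline(points, model_mpp, slide_mpp):
    """Scale (x, y) coordinates from model resolution to slide baseline.

    Hypothetical helper. Assumes scale_factor = model_mpp / slide_mpp per axis,
    as described for the scale_factor option, and that points are (x, y).
    """
    scale_factor = (model_mpp[0] / slide_mpp[0], model_mpp[1] / slide_mpp[1])
    return [(x * scale_factor[0], y * scale_factor[1]) for x, y in points]


# Model reads at 0.5 mpp, slide baseline is 0.25 mpp -> scale_factor (2.0, 2.0).
print(to_baseline([(100, 40), (10.5, 3)], (0.5, 0.5), (0.25, 0.25)))
# [(200.0, 80.0), (21.0, 6.0)]
```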
- run(images, *, masks=None, input_resolutions=None, patch_input_shape=None, ioconfig=None, patch_mode=True, save_dir=None, overwrite=False, output_type='dict', **kwargs)
Run the semantic segmentation engine on input images.
This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch-level and whole slide image (WSI) modes.
- Parameters:
images (list[PathLike | WSIReader] | np.ndarray) – Input images or patches. Can be a list of file paths, WSIReader objects, or a NumPy array of image patches.
masks (list[PathLike] | np.ndarray | None) – Optional masks for WSI processing. Only used when patch_mode is False.
input_resolutions (list[dict[Units, Resolution]] | None) – Resolution settings for input heads. Supported units are level, power, and mpp. Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. See WSIReader for details.
patch_input_shape (IntPair | None) – Shape of input patches (height, width), requested at read resolution. Must be positive.
ioconfig (IOSegmentorConfig | None) – IO configuration for patch extraction and resolution.
patch_mode (bool) – Whether to treat input as patches (True) or WSIs (False). Default is True.
save_dir (PathLike | None) – Directory to save output files. Required for WSI mode.
overwrite (bool) – Whether to overwrite existing output files. Default is False.
output_type (str) – Desired output format: “dict”, “zarr”, “qupath”, or “annotationstore”. Default is “dict”.
**kwargs (SemanticSegmentorRunParams) –
Additional runtime parameters to configure segmentation.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate tissue masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- output_resolutions (Resolution):
Resolution used for writing output predictions.
- patch_output_shape (tuple[int, int]):
Shape of output patches (height, width).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
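class_dict maps classification outputs to class names. The sketch below assumes the common convention of taking, for each pixel, the class with the highest probability and looking its index up in class_dict; label_map and the class names used here are hypothetical:

```python
def label_map(probabilities, class_dict):
    """Convert per-pixel class probabilities into class names.

    Hypothetical helper. probabilities[row][col] is a list of per-class scores;
    class_dict maps class index to name. Ties resolve to the lowest index.
    """
    return [
        [class_dict[max(range(len(scores)), key=scores.__getitem__)] for scores in line]
        for line in probabilities
    ]


# Illustrative class_dict; the indices and names are assumptions.
class_dict = {0: "Tumour", 1: "Stroma", 2: "Other"}
probs = [
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
    [[0.2, 0.3, 0.5], [0.4, 0.4, 0.2]],
]
print(label_map(probs, class_dict))
# [['Tumour', 'Stroma'], ['Other', 'Tumour']]
```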
- Returns:
If patch_mode is True: returns predictions or path to saved output.
If patch_mode is False: returns a dictionary mapping each WSI to its output path.
- Return type:
AnnotationStore | Path | str | dict | list[Path]
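In WSI mode the return value maps each input WSI to its output path (in the examples below, wsi1.svs maps to /path/to/wsi1.db). A sketch of such a mapping, assuming outputs are named after the slide stem and that each output_type uses the extension named in the save_predictions documentation (.zarr, .json, .db); wsi_output_paths is hypothetical and TIAToolbox's actual naming may differ:

```python
from pathlib import Path

# Assumed extension per output_type, taken from the file types named in the
# save_predictions return description.
EXTENSIONS = {"zarr": ".zarr", "qupath": ".json", "annotationstore": ".db"}


def wsi_output_paths(wsis, save_dir, output_type):
    """Map each WSI path (as a string) to an output path inside save_dir.

    Hypothetical helper: naming outputs after the slide stem is an assumption.
    """
    if output_type not in EXTENSIONS:
        raise ValueError(f"no file extension assumed for output_type={output_type!r}")
    save_dir = Path(save_dir)
    return {str(wsi): save_dir / f"{Path(wsi).stem}{EXTENSIONS[output_type]}" for wsi in wsis}


paths = wsi_output_paths(["wsi1.svs", "wsi2.svs"], "/path/to", "annotationstore")
print(paths["wsi1.svs"].as_posix())  # /path/to/wsi1.db
```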
Examples
>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> image_patches = [np.ndarray, np.ndarray]  # placeholders for two HxWx3 arrays
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
>>> output = segmentor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="annotationstore",
... )
>>> output
"/path/to/Output.db"

>>> output = segmentor.run(
...     image_patches,
...     patch_mode=True,
...     output_type="zarr",
... )
>>> output
"/path/to/Output.zarr"

>>> output = segmentor.run(
...     wsis,
...     patch_mode=False,
...     save_dir="/path/to/",
...     output_type="annotationstore",
... )
>>> output.keys()
['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
"/path/to/wsi1.db"
- save_predictions(processed_predictions, output_type, save_path=None, **kwargs)
Save semantic segmentation predictions to disk or return them in memory.
This method saves predictions in one of the supported formats:
- “dict”: returns predictions as a Python dictionary.
- “zarr”: saves predictions as a Zarr group and returns the path.
- “qupath”: converts predictions to QuPath-compatible JSON (.json file).
- “annotationstore”: converts predictions to an AnnotationStore (.db file).
If patch_mode is True, predictions are saved per image. If False, predictions are merged and saved as a single output.
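The requirement that save_path be provided for every format except "dict" can be expressed as a small validation step. check_save_args is a hypothetical helper written for illustration; the errors TIAToolbox raises in these cases may differ:

```python
from pathlib import Path

SUPPORTED_OUTPUT_TYPES = {"dict", "zarr", "qupath", "annotationstore"}


def check_save_args(output_type, save_path=None):
    """Validate output_type and save_path as described for save_predictions.

    Hypothetical helper: "dict" needs no save_path, while "zarr", "qupath"
    and "annotationstore" require one. Returns save_path as a Path, or None.
    """
    if output_type not in SUPPORTED_OUTPUT_TYPES:
        raise ValueError(f"unsupported output_type: {output_type!r}")
    if output_type != "dict" and save_path is None:
        raise ValueError(f"save_path is required for output_type={output_type!r}")
    return None if save_path is None else Path(save_path)


print(check_save_args("zarr", "out/predictions.zarr"))
try:
    check_save_args("annotationstore")
except ValueError as err:
    print(err)  # save_path is required for output_type='annotationstore'
```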
- Parameters:
processed_predictions (dict) – Dictionary containing processed model predictions.
output_type (str) – Desired output format: “dict”, “zarr”, “qupath” or “annotationstore”.
save_path (Path | None) – Path to save the output file. Required for “zarr”, “qupath” and “annotationstore”.
**kwargs (SemanticSegmentorRunParams) –
Additional runtime parameters to configure segmentation.
- Optional Keys:
- auto_get_mask (bool):
Automatically generate tissue masks using wsireader.tissue_mask() during processing.
- batch_size (int):
Number of image patches per forward pass.
- class_dict (dict):
Mapping of classification outputs to class names.
- device (str):
Device to run the model on (e.g., “cpu”, “cuda”).
- labels (list):
Optional labels for input images. Only a single label per image is supported.
- memory_threshold (int):
Memory usage threshold (percentage) to trigger caching behavior.
- num_workers (int):
Number of workers for DataLoader and post-processing.
- output_file (str):
Filename for saving output (e.g., “.zarr” or “.db”).
- output_resolutions (Resolution):
Resolution used for writing output predictions.
- patch_output_shape (tuple[int, int]):
Shape of output patches (height, width).
- return_labels (bool):
Whether to return labels with predictions.
- return_probabilities (bool):
Whether to return per-class probabilities.
- scale_factor (tuple[float, float]):
Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution.
- stride_shape (tuple[int, int]):
Stride used during WSI processing. Defaults to patch_input_shape if not provided.
- verbose (bool):
Whether to enable verbose logging.
- Returns:
If output_type is “dict”: returns predictions as a dictionary.
If output_type is “zarr”: returns the path to the saved Zarr file.
If output_type is “qupath”: returns QuPath JSON, or a path or list of paths to .json files.
If output_type is “annotationstore”: returns an AnnotationStore, or a path or list of paths to .db files.
- Return type:
dict | AnnotationStore | Path | list[Path]