Source code for tiatoolbox.models.engine.semantic_segmentor

"""Semantic Segmentation Engine for Whole Slide Images (WSIs) using TIAToolbox.

This module defines the `SemanticSegmentor` class, which extends the `PatchPredictor`
engine to support semantic segmentation workflows on digital pathology images.
It leverages deep learning models from TIAToolbox to perform patch-level and
WSI-level inference, and includes utilities for preprocessing, postprocessing,
and saving predictions in various formats.

Key Components:
---------------
Classes:
- SemanticSegmentorRunParams:
    Configuration parameters for controlling runtime behavior during segmentation.
- SemanticSegmentor:
    Core engine for performing semantic segmentation on image patches or WSIs.

Functions:
- concatenate_none:
    Concatenate arrays while gracefully handling None values.
- merge_horizontal:
    Incrementally merge horizontal patches and update location arrays.
- save_to_cache:
    Save intermediate canvas and count arrays to Zarr cache.
- merge_vertical_chunkwise:
    Merge vertically chunked canvas and count arrays into a probability map.
- store_probabilities:
    Store computed probability data in Zarr or Dask arrays.
- prepare_full_batch:
    Align patch-level predictions with global output locations.

Example:
>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
>>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss")
>>> wsis = ["slide1.svs", "slide2.svs"]
>>> output = segmentor.run(wsis, patch_mode=False)
>>>
>>> patches = [np.ndarray, np.ndarray]
>>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss")
>>> output = segmentor.run(patches, patch_mode=True, output_type="dict")

Notes:
------
- Supports both patch-based and WSI-based segmentation.
- Compatible with TIAToolbox pretrained models and custom PyTorch models.
- Outputs can be saved as dictionaries, Zarr arrays, or AnnotationStore databases.
- Includes memory-aware caching and efficient merging strategies for large-scale
  inference.

"""

from __future__ import annotations

import gc
import shutil
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import dask.array as da
import numpy as np
import psutil
import torch
import zarr
from tqdm.auto import tqdm
from typing_extensions import Unpack

from tiatoolbox import logger
from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset
from tiatoolbox.utils.misc import (
    dict_to_store_semantic_segmentor,
    update_tqdm_desc,
)
from tiatoolbox.wsicore.wsireader import WSIReader, is_zarr

from .patch_predictor import PatchPredictor, PredictorRunParams

if TYPE_CHECKING:  # pragma: no cover
    import os

    from torch.utils.data import DataLoader

    from tiatoolbox.annotation import AnnotationStore
    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
    from tiatoolbox.models.models_abc import ModelABC
    from tiatoolbox.type_hints import IntPair, Resolution, Units


[docs] class SemanticSegmentorRunParams(PredictorRunParams, total=False): """Runtime parameters for configuring the `SemanticSegmentor.run()` method. This class extends `PredictorRunParams`, which itself extends `EngineABCRunParams`, and adds parameters specific to semantic segmentation workflows. Attributes: auto_get_mask (bool): Whether to automatically generate segmentation masks using `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches to feed to the model in a forward pass. class_dict (dict): Optional dictionary mapping classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). labels (list): Optional labels for input images. Only a single label per image is supported. memory_threshold (int): Memory usage threshold (in percentage) to trigger caching behavior. num_workers (int): Number of workers used in DataLoader. output_file (str): Output file name for saving results (e.g., .zarr or .db). output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for converting annotations to baseline resolution. Typically model_mpp / slide_mpp. stride_shape (tuple[int, int]): Stride used during WSI processing. Defaults to patch_input_shape. verbose (bool): Whether to output logging information. """ patch_output_shape: tuple[int, int] output_resolutions: list[dict[Units, Resolution]]
[docs] class SemanticSegmentor(PatchPredictor): r"""Semantic segmentation engine for digital histology images. This class extends `PatchPredictor` to support semantic segmentation tasks using pretrained or custom models from TIAToolbox. It supports both patch-level and whole slide image (WSI) processing, and provides utilities for merging, post-processing, and saving predictions. Performance: The TIAToolbox model `fcn_resnet50_unet-bcss` achieves the following results on the BCSS dataset: .. list-table:: Semantic segmentation performance on the BCSS dataset :widths: 15 15 15 15 15 15 15 :header-rows: 1 * - - Tumour - Stroma - Inflammatory - Necrosis - Other - All * - Amgad et al. - 0.851 - 0.800 - 0.712 - 0.723 - 0.666 - 0.750 * - TIAToolbox - 0.885 - 0.825 - 0.761 - 0.765 - 0.581 - 0.763 Args: model (str | ModelABC): A PyTorch model instance or name of a pretrained model from TIAToolbox. The user can request pretrained models from the toolbox model zoo using the list of pretrained models available at this `link <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set of weights using the `weights` parameter. Default is `None`. batch_size (int): Number of image patches processed per forward pass. Default is 8. num_workers (int): Number of workers for data loading. Default is 0. weights (str | Path | None): Path to model weights. If None, default weights are used. >>> engine = SemanticSegmentor( ... model="pretrained-model", ... weights="/path/to/pretrained-local-weights.pth" ... ) device (str): Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". verbose (bool): Whether to enable verbose logging. Default is True. Attributes: images (list[str | Path] | np.ndarray): Input image patches or WSI paths. masks (list[str | Path] | np.ndarray): Optional tissue masks for WSI processing. These are only utilized when patch_mode is False. If not provided, then a tissue mask will be automatically generated for whole slide images. patch_mode (bool): Whether input is treated as patches (`True`) or WSIs (`False`). model (ModelABC): Loaded PyTorch model. ioconfig (ModelIOConfigABC): IO configuration for patch extraction and resolution. return_labels (bool): Whether to include labels in the output. input_resolutions (list[dict]): Resolution settings for model input. Supported units are `level`, `power` and `mpp`. Keys should be "units" and "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see :class:`WSIReader` for details. patch_input_shape (tuple[int, int]): Shape of input patches (height, width). Patches are at requested read resolution, not with respect to level 0, and must be positive. stride_shape (tuple[int, int]): Stride used during patch extraction. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. labels (list | None): Optional labels for input images. Only a single label per image is supported. drop_keys (list): Keys to exclude from model output. output_type (str): Format of output ("dict", "zarr", "qupath", "annotationstore"). output_locations (list | None): Coordinates of output patches used during WSI processing. Examples: >>> # list of 2 image patches as input >>> wsis = ['path/img.svs', 'path/img.svs'] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") >>> output = segmentor.run(wsis, patch_mode=False) >>> # array of list of 2 image patches as input >>> image_patches = [np.ndarray, np.ndarray] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") >>> output = segmentor.run(image_patches, patch_mode=True) >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") >>> output = segmentor.run(data, patch_mode=False) >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") >>> output = segmentor.run(tile_file, patch_mode=False) >>> # list of 2 wsi files as input >>> wsis = ['path/wsi1.svs', 'path/wsi2.svs'] >>> segmentor = SemanticSegmentor(model="resnet18-kather100k") >>> output = segmentor.run(wsis, patch_mode=False) References: [1] Amgad M, Elfandy H, ..., Gutman DA, Cooper LAD. Structured crowdsourcing enables convolutional segmentation of histology images. Bioinformatics 2019. doi: 10.1093/bioinformatics/btz083 """ def __init__( self: SemanticSegmentor, model: str | ModelABC, batch_size: int = 8, num_workers: int = 0, weights: str | Path | None = None, *, device: str = "cpu", verbose: bool = True, ) -> None: """Initialize :class:`SemanticSegmentor`. Args: model (str | ModelABC): A PyTorch model instance or name of a pretrained model from TIAToolbox. If a string is provided, the corresponding pretrained weights will be downloaded unless overridden via `weights`. batch_size (int): Number of image patches processed per forward pass. Default is 8. num_workers (int): Number of workers for data loading. Default is 0. weights (str | Path | None): Path to model weights. If None, default weights are used. device (str): Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". verbose (bool): Whether to enable verbose logging. Default is True. """ super().__init__( model=model, batch_size=batch_size, num_workers=num_workers, weights=weights, device=device, verbose=verbose, ) self.output_locations: list | None = None
[docs] def get_dataloader( self: SemanticSegmentor, images: str | Path | list[str | Path] | np.ndarray, masks: Path | None = None, labels: list | None = None, ioconfig: IOSegmentorConfig | None = None, *, patch_mode: bool = True, auto_get_mask: bool = True, ) -> torch.utils.data.DataLoader: """Pre-process images and masks and return a DataLoader for inference. This method prepares the dataset and returns a PyTorch DataLoader for either patch-based or WSI-based semantic segmentation. It overrides the base method to support additional WSI-specific logic, including patch output shape and output location tracking. Args: images (str | Path | list[str | Path] | np.ndarray): Input images. Can be a list of file paths or a NumPy array of image patches in NHWC format. masks (Path | None): Optional tissue masks for WSI processing. Only used when `patch_mode` is False. labels (list | None): Optional labels for input images. Only one label per image is supported. ioconfig (IOSegmentorConfig | None): IO configuration for patch extraction and resolution. patch_mode (bool): Whether to treat input as patches (`True`) or WSIs (`False`). auto_get_mask (bool): Whether to automatically generate a tissue mask using `wsireader.tissue_mask()` when `patch_mode` is False. If `True`, only tissue regions are processed. If `False`, all patches are processed. Default is `True`. Returns: torch.utils.data.DataLoader: A PyTorch DataLoader configured for semantic segmentation inference. """ # Overwrite when patch_mode is False. if not patch_mode: dataset = WSIPatchDataset( input_img=images, mask_path=masks, patch_input_shape=ioconfig.patch_input_shape, patch_output_shape=ioconfig.patch_output_shape, stride_shape=ioconfig.stride_shape, resolution=ioconfig.input_resolutions[0]["resolution"], units=ioconfig.input_resolutions[0]["units"], auto_get_mask=auto_get_mask, ) dataset.preproc_func = self._get_model_attr("preproc_func") self.output_locations = dataset.outputs # preprocessing must be defined with the dataset return torch.utils.data.DataLoader( dataset, num_workers=self.num_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, ) return super().get_dataloader( images=images, masks=masks, labels=labels, ioconfig=ioconfig, patch_mode=patch_mode, )
[docs] def infer_wsi( self: SemanticSegmentor, dataloader: DataLoader, save_path: Path, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> dict[str, da.Array]: """Perform model inference on a whole slide image (WSI). This method processes a WSI using the provided DataLoader, merges patch-level predictions into a full-resolution canvas, and returns the aggregated output. It supports memory-aware caching and optional inclusion of coordinates and labels. Args: dataloader (DataLoader): PyTorch DataLoader configured for WSI processing. save_path (Path): Path to save the intermediate output. The intermediate output is saved in a Zarr file. **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. Optional Keys: auto_get_mask (bool): Automatically generate segmentation masks using `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches per forward pass. class_dict (dict): Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). labels (list): Optional labels for input images. Only a single label per image is supported. memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): Number of workers for DataLoader and post-processing. output_file (str): Filename for saving output (e.g., ".zarr" or ".db"). output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution. stride_shape (tuple[int, int]): Stride used during WSI processing. Defaults to `patch_input_shape` if not provided. verbose (bool): Whether to enable verbose logging. Returns: dict[str, dask.array.Array]: Dictionary containing merged prediction results: - "probabilities": Full-resolution probability map. - "coordinates": Patch coordinates. - "labels": Ground truth labels (if `return_labels` is True). """ # Default Memory threshold percentage is 80. memory_threshold = kwargs.get("memory_threshold", 80) keys = ["probabilities", "coordinates"] coordinates = [] # Main output dictionary raw_predictions = dict( zip(keys, [da.empty(shape=(0, 0))] * len(keys), strict=False) ) # Inference loop tqdm_loop = tqdm( dataloader, leave=False, desc="Inferring patches", disable=not self.verbose, ) canvas_np, output_locs_y_ = None, None canvas, count, output_locs = None, None, None canvas_zarr, count_zarr = None, None full_output_locs = ( dataloader.dataset.full_outputs if hasattr(dataloader.dataset, "full_outputs") else dataloader.dataset.outputs ) infer_batch = self._get_model_attr("infer_batch") for batch_idx, batch_data in enumerate(tqdm_loop): batch_output = infer_batch( self.model, batch_data["image"], device=self.device, ) batch_locs = batch_data["output_locs"].numpy() # Interpolate outputs for masked regions full_batch_output, full_output_locs, output_locs = prepare_full_batch( batch_output, batch_locs, full_output_locs, output_locs, canvas_np=canvas_np, save_path=save_path.with_name("full_batch_tmp"), memory_threshold=memory_threshold, is_last=(batch_idx == (len(dataloader) - 1)), ) canvas_np = concatenate_none(old_arr=canvas_np, new_arr=full_batch_output) # Determine if dataloader is moved to next row of patches change_indices = np.where(np.diff(output_locs[:, 1]) != 0)[0] + 1 # If a row of patches has been processed. if change_indices.size > 0: canvas, count, canvas_np, output_locs, output_locs_y_ = ( merge_horizontal( canvas, count, output_locs_y_, canvas_np, output_locs, change_indices, ) ) vm = psutil.virtual_memory() used_percent = vm.percent # Use currently available memory (not the initial snapshot) to # decide when to spill intermediate results. canvas_used_percent = (canvas.nbytes / max(vm.available, 1)) * 100 if ( used_percent > memory_threshold or canvas_used_percent > memory_threshold ): used_percent = ( canvas_used_percent if (canvas_used_percent > memory_threshold) else used_percent ) msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) # Flush data in Memory and clear dask graph canvas_zarr, count_zarr = save_to_cache( canvas, count, canvas_zarr, count_zarr, save_path=save_path, verbose=self.verbose, ) canvas, count = None, None gc.collect() update_tqdm_desc(tqdm_loop=tqdm_loop, desc="Inferring patches") coordinates.append( da.from_array( self._get_coordinates(batch_data), ) ) canvas, count, _, _, output_locs_y_ = merge_horizontal( canvas, count, output_locs_y_, canvas_np, output_locs, change_indices=[len(output_locs)], ) zarr_group = None if canvas_zarr is not None: canvas_zarr, count_zarr = save_to_cache( canvas, count, canvas_zarr, count_zarr, verbose=self.verbose, ) # Wrap zarr in dask array canvas = da.from_zarr(canvas_zarr, chunks=canvas_zarr.chunks) count = da.from_zarr(count_zarr, chunks=count_zarr.chunks) zarr_group = zarr.open(canvas_zarr.store.path, mode="a") output_shape = get_wsi_output_shape(dataloader.dataset) # Final vertical merge raw_predictions["probabilities"] = merge_vertical_chunkwise( canvas, count, output_locs_y_, zarr_group, save_path, memory_threshold, output_shape=output_shape, verbose=self.verbose, ) raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) if save_path.with_name("full_batch_tmp").exists(): shutil.rmtree(save_path.with_name("full_batch_tmp")) return raw_predictions
[docs] def save_predictions( self: SemanticSegmentor, processed_predictions: dict, output_type: str, save_path: Path | None = None, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> dict | AnnotationStore | Path | list[Path]: """Save semantic segmentation predictions to disk or return them in memory. This method saves predictions in one of the supported formats: - "dict": returns predictions as a Python dictionary. - "zarr": saves predictions as a Zarr group and returns the path. - "annotationstore": converts predictions to an AnnotationStore (.db file). If `patch_mode` is True, predictions are saved per image. If False, predictions are merged and saved as a single output. Args: processed_predictions (dict): Dictionary containing processed model predictions. output_type (str): Desired output format: "dict", "zarr", "qupath" or "annotationstore". save_path (Path | None): Path to save the output file. Required for "zarr", "qupath" and "annotationstore". **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. Optional Keys: auto_get_mask (bool): Automatically generate segmentation masks using `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches per forward pass. class_dict (dict): Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). labels (list): Optional labels for input images. Only a single label per image is supported. memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): Number of workers for DataLoader and post-processing. output_file (str): Filename for saving output (e.g., ".zarr" or ".db"). output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution. stride_shape (tuple[int, int]): Stride used during WSI processing. Defaults to `patch_input_shape` if not provided. verbose (bool): Whether to enable verbose logging. Returns: dict | AnnotationStore | Path | list[Path]: - If output_type is "dict": returns predictions as a dictionary. - If output_type is "zarr": returns path to saved Zarr file. - If output_type is "qupath": returns QuPath JSON or path or list of paths to .json file. - If output_type is "annotationstore": returns AnnotationStore or path or list of paths to .db file. """ # Conversion to annotationstore uses a different function for SemanticSegmentor if output_type.lower() not in ["qupath", "annotationstore"]: return super().save_predictions( processed_predictions, output_type, save_path=save_path, **kwargs ) return_probabilities = kwargs.get("return_probabilities", False) output_type_ = ( "zarr" if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities else "dict" ) processed_predictions = super().save_predictions( processed_predictions, output_type=output_type_, save_path=save_path.with_suffix(".zarr"), **kwargs, ) if isinstance(processed_predictions, Path): processed_predictions = zarr.open(str(processed_predictions), mode="r") # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs class_dict = kwargs.get("class_dict", self._get_model_attr("class_dict")) # Need to add support for zarr conversion. save_paths = [] suffix = ".json" if output_type.lower() == "qupath" else ".db" msg = f"Saving predictions as {output_type} in {suffix} format." logger.info(msg) if self.patch_mode: for i, predictions in enumerate(processed_predictions["predictions"]): if isinstance(self.images[i], Path): output_path = save_path.parent / (self.images[i].stem + suffix) else: output_path = save_path.parent / (str(i) + suffix) out_file = dict_to_store_semantic_segmentor( patch_output={"predictions": predictions}, scale_factor=scale_factor, output_type=output_type, class_dict=class_dict, save_path=output_path, verbose=self.verbose, ) save_paths.append(out_file) else: out_file = dict_to_store_semantic_segmentor( patch_output=processed_predictions, scale_factor=scale_factor, output_type=output_type, class_dict=class_dict, save_path=save_path.with_suffix(suffix), verbose=self.verbose, ) save_paths = out_file if return_probabilities: msg = ( f"Probability maps cannot be saved as AnnotationStore or JSON. " f"To visualise heatmaps in TIAToolbox Visualization tool," f"convert heatmaps in {save_path} to ome.tiff using" f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." ) logger.info(msg) elif save_path.with_suffix(".zarr").exists(): shutil.rmtree(save_path.with_suffix(".zarr")) return save_paths
def _update_run_params( self: SemanticSegmentor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, masks: list[os.PathLike | Path] | np.ndarray | None = None, input_resolutions: list[dict[Units, Resolution]] | None = None, patch_input_shape: tuple[int, int] | None = None, save_dir: os.PathLike | Path | None = None, ioconfig: IOSegmentorConfig | None = None, output_type: str = "dict", *, overwrite: bool = False, patch_mode: bool, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> Path | None: """Update runtime parameters for the SemanticSegmentor engine. This method sets internal attributes such as caching, batch size, IO configuration, and output format based on user input and keyword arguments. It also configures whether to include probabilities in the output. Args: images (list[PathLike | WSIReader] | np.ndarray): Input images or patches. masks (list[PathLike] | np.ndarray | None): Optional masks for WSI processing. input_resolutions (list[dict[Units, Resolution]] | None): Resolution settings for input heads. Supported units are `level`, `power`, and `mpp`. Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for details. patch_input_shape (IntPair | None): Shape of input patches (height, width), requested at read resolution. Must be positive. save_dir (PathLike | None): Directory to save output files. Required for WSI mode. ioconfig (ModelIOConfigABC | None): IO configuration for patch extraction and resolution. output_type (str): Desired output format: "dict", "zarr", "qupath", or "annotationstore". overwrite (bool): Whether to overwrite existing output files. Default is False. patch_mode (bool): Whether to treat input as patches (`True`) or WSIs (`False`). **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. Optional Keys: auto_get_mask (bool): Automatically generate segmentation masks using `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches per forward pass. class_dict (dict): Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). labels (list): Optional labels for input images. Only a single label per image is supported. memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): Number of workers for DataLoader and post-processing. output_file (str): Filename for saving output (e.g., ".zarr" or ".db"). output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution. stride_shape (tuple[int, int]): Stride used during WSI processing. Defaults to `patch_input_shape` if not provided. verbose (bool): Whether to enable verbose logging. Returns: Path | None: Path to the save directory if applicable, otherwise None. Raises: ValueError: If `labels` are requested for WSI processing. """ return_labels = kwargs.get("return_labels") if return_labels and not patch_mode: msg = "`return_labels` is not supported when `patch_mode` is False." raise ValueError(msg) return super()._update_run_params( images=images, masks=masks, input_resolutions=input_resolutions, patch_input_shape=patch_input_shape, save_dir=save_dir, ioconfig=ioconfig, overwrite=overwrite, patch_mode=patch_mode, output_type=output_type, **kwargs, )
[docs] def run( self: SemanticSegmentor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, *, masks: list[os.PathLike | Path] | np.ndarray | None = None, input_resolutions: list[dict[Units, Resolution]] | None = None, patch_input_shape: IntPair | None = None, ioconfig: IOSegmentorConfig | None = None, patch_mode: bool = True, save_dir: os.PathLike | Path | None = None, overwrite: bool = False, output_type: str = "dict", **kwargs: Unpack[SemanticSegmentorRunParams], ) -> AnnotationStore | Path | str | dict | list[Path]: """Run the semantic segmentation engine on input images. This method orchestrates the full inference pipeline, including preprocessing, model inference, post-processing, and saving results. It supports both patch-level and whole slide image (WSI) modes. Args: images (list[PathLike | WSIReader] | np.ndarray): Input images or patches. Can be a list of file paths, WSIReader objects, or a NumPy array of image patches. masks (list[PathLike] | np.ndarray | None): Optional masks for WSI processing. Only used when `patch_mode` is False. input_resolutions (list[dict[Units, Resolution]] | None): Resolution settings for input heads. Supported units are `level`, `power`, and `mpp`. Keys should be "units" and "resolution", e.g., [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for details. patch_input_shape (IntPair | None): Shape of input patches (height, width), requested at read resolution. Must be positive. ioconfig (IOSegmentorConfig | None): IO configuration for patch extraction and resolution. patch_mode (bool): Whether to treat input as patches (`True`) or WSIs (`False`). Default is True. save_dir (PathLike | None): Directory to save output files. Required for WSI mode. overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): Desired output format: "dict", "zarr", "qupath", or "annotationstore". Default is "dict". **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to configure segmentation. Optional Keys: auto_get_mask (bool): Automatically generate segmentation masks using `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches per forward pass. class_dict (dict): Mapping of classification outputs to class names. device (str): Device to run the model on (e.g., "cpu", "cuda"). labels (list): Optional labels for input images. Only a single label per image is supported. memory_threshold (int): Memory usage threshold (percentage) to trigger caching behavior. num_workers (int): Number of workers for DataLoader and post-processing. output_file (str): Filename for saving output (e.g., ".zarr" or ".db"). output_resolutions (Resolution): Resolution used for writing output predictions. patch_output_shape (tuple[int, int]): Shape of output patches (height, width). return_labels (bool): Whether to return labels with predictions. return_probabilities (bool): Whether to return per-class probabilities. scale_factor (tuple[float, float]): Scale factor for annotations (model_mpp / slide_mpp). Used to convert coordinates to baseline resolution. stride_shape (tuple[int, int]): Stride used during WSI processing. Defaults to `patch_input_shape` if not provided. verbose (bool): Whether to enable verbose logging. Returns: AnnotationStore | Path | str | dict | list[Path]: - If `patch_mode` is True: returns predictions or path to saved output. - If `patch_mode` is False: returns a dictionary mapping each WSI to its output path. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] >>> image_patches = [np.ndarray, np.ndarray] >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") >>> output = segmentor.run(image_patches, patch_mode=True) >>> output ... "/path/to/Output.db" >>> output = segmentor.run( ... image_patches, ... patch_mode=True, ... output_type="zarr" ... ) >>> output ... "/path/to/Output.zarr" >>> output = segmentor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] ... "/path/to/wsi1.db" """ return super().run( images=images, masks=masks, input_resolutions=input_resolutions, patch_input_shape=patch_input_shape, ioconfig=ioconfig, patch_mode=patch_mode, save_dir=save_dir, overwrite=overwrite, output_type=output_type, **kwargs, )
[docs] def concatenate_none( old_arr: np.ndarray | da.Array | None, new_arr: np.ndarray | da.Array, ) -> np.ndarray | da.Array: """Concatenate arrays, handling None values gracefully. This utility function concatenates `new_arr` to `old_arr` along the first axis. If `old_arr` is None, it returns `new_arr` directly. Supports both NumPy and Dask arrays. Args: old_arr (np.ndarray | da.Array | None): Existing array to append to. Can be None. new_arr (np.ndarray | da.Array): New array to append. Returns: np.ndarray | da.Array: Concatenated array of the same type as `new_arr`. """ if isinstance(new_arr, np.ndarray): return ( new_arr if old_arr is None else np.concatenate((old_arr, new_arr), axis=0) ) return new_arr if old_arr is None else da.concatenate([old_arr, new_arr], axis=0)
[docs] def merge_batch_to_canvas( blocks: np.ndarray, output_locations: np.ndarray, merged_shape: tuple[int, int, int], ) -> tuple[np.ndarray, np.ndarray]: """Merge patch-level predictions into a single canvas. This function aggregates overlapping patch predictions into a unified output canvas and maintains a count map to normalize overlapping regions. Args: blocks (np.ndarray): Array of predicted blocks with shape (N, H, W, C), where N is the number of patches. output_locations (np.ndarray): Array of coordinates for each block in the format [start_x, start_y, end_x, end_y] with shape (N, 4). merged_shape (tuple[int, int, int]): Shape of the final merged canvas (H, W, C). Returns: tuple[np.ndarray, np.ndarray]: - canvas: Merged prediction map of shape (H, W, C). - count: Count map indicating how many times each pixel was updated, shape (H, W). """ # Ensure we operate on NumPy to avoid Dask out-parameter issues when merging. if not isinstance(blocks, np.ndarray): blocks = np.asarray(blocks) if not isinstance(output_locations, np.ndarray): output_locations = np.asarray(output_locations) canvas = np.zeros(merged_shape, dtype=blocks.dtype) count = np.zeros((*merged_shape[:2], 1), dtype=np.uint8) for i, block in enumerate(blocks): xs, ys, xe, ye = output_locations[i] if not np.any(block): continue # To deal with edge cases canvas[0 : ye - ys, xs:xe, :] += block[0 : ye - ys, 0 : xe - xs, :] count[0 : ye - ys, xs:xe, 0] += 1 return canvas, count
[docs] def merge_horizontal( canvas: None | da.Array, count: None | da.Array, output_locs_y: np.ndarray, canvas_np: np.ndarray, output_locs: np.ndarray, change_indices: np.ndarray | list[int], ) -> tuple[da.Array, da.Array, np.ndarray, np.ndarray, np.ndarray]: """Merge horizontal patches incrementally for each row of patches. This function processes segments of NumPy patch arrays (`canvas_np`, `count_np`, `output_locs`) based on `change_indices`, merging them horizontally and appending the results to Dask arrays. It also updates the vertical output locations (`output_locs_y_`) for downstream vertical merging. Args: canvas (None | da.Array): Existing Dask array for canvas data, or None if uninitialized. count (None | da.Array): Existing Dask array for count data, or None if uninitialized. output_locs_y (np.ndarray): Array tracking vertical output locations for merged patches. canvas_np (np.ndarray): NumPy array of canvas patches to be merged. output_locs (np.ndarray): Array of output locations for each patch. change_indices (np.ndarray | list[np.ndarray]): Indices indicating where to flush and merge patches. Returns: tuple: Updated canvas and count Dask arrays, along with remaining canvas_np, count_np, output_locs, and output_locs_y_ arrays after processing. """ start_idx = 0 for c_idx in change_indices: output_locs_ = output_locs[: c_idx - start_idx] canvas_np_ = canvas_np[: c_idx - start_idx] # Compute span only for the current row to avoid allocating a canvas # covering the entire slide width. batch_xs = np.min(output_locs_[:, 0], axis=0) batch_xe = np.max(output_locs_[:, 2], axis=0) merged_shape = (canvas_np_.shape[1], batch_xe - batch_xs, canvas_np.shape[3]) canvas_merge, count_merge = merge_batch_to_canvas( blocks=canvas_np_, output_locations=output_locs_, merged_shape=merged_shape, ) canvas_merge = da.from_array(canvas_merge, chunks=canvas_merge.shape) count_merge = da.from_array(count_merge, chunks=count_merge.shape) canvas = concatenate_none(old_arr=canvas, new_arr=canvas_merge) count = concatenate_none(old_arr=count, new_arr=count_merge) output_locs_y = concatenate_none( old_arr=output_locs_y, new_arr=output_locs_[:, (1, 3)] ) canvas_np = canvas_np[c_idx - start_idx :] output_locs = output_locs[c_idx - start_idx :] start_idx = c_idx return canvas, count, canvas_np, output_locs, output_locs_y
[docs] def save_to_cache( canvas: da.Array, count: da.Array, canvas_zarr: zarr.Array, count_zarr: zarr.Array, save_path: str | Path = "temp.zarr", zarr_dataset_name: tuple[str, str] = ("canvas", "count"), *, verbose: bool = True, ) -> tuple[zarr.Array, zarr.Array]: """Incrementally save computed canvas and count arrays to Zarr cache. This function computes the given Dask arrays (`canvas` and `count`) row-chunks one at a time to avoid materializing the full dask arrays in memory. If the datasets do not exist, they are created using the chunk shapes from the first block. Args: canvas (da.Array): Dask array representing image or feature data. count (da.Array): Dask array representing count or normalization data. canvas_zarr (zarr.Array): Existing Zarr dataset for canvas data. If None, a new one is created. count_zarr (zarr.Array): Existing Zarr dataset for count data. If None, a new one is created. save_path (str | Path): Path to the Zarr group for saving datasets. Defaults to "temp.zarr". zarr_dataset_name (tuple[str, str]): Tuple of name for zarr dataset to save canvas and count. Defaults to ("canvas", "count"). verbose (bool): Whether to display progress bar. Returns: tuple[zarr.Array, zarr.Array]: Updated Zarr datasets for canvas and count arrays. """ chunk0 = canvas.chunks[0][0] if canvas_zarr is None: zarr_group = zarr.open(str(save_path), mode="a") # Peek first block shapes to initialise datasets without computing all rows. # Blocks are 3D: (row_chunk, col_chunk, channel_chunk). Grab the first. first_canvas_block = canvas.blocks[0, 0, 0].compute() first_count_block = count.blocks[0, 0, 0].compute() canvas_zarr = zarr_group.create_dataset( name=zarr_dataset_name[0], # Append along axis 0 (height); keep width/channels fixed. shape=(0, *first_canvas_block.shape[1:]), chunks=(chunk0, *first_canvas_block.shape[1:]), dtype=first_canvas_block.dtype, overwrite=True, ) count_zarr = zarr_group.create_dataset( name=zarr_dataset_name[1], shape=(0, *first_count_block.shape[1:]), dtype=first_count_block.dtype, chunks=(chunk0, *first_count_block.shape[1:]), overwrite=True, ) # We already computed the first block; store it and start from the next. canvas_zarr.resize((first_canvas_block.shape[0], *canvas_zarr.shape[1:])) canvas_zarr[-first_canvas_block.shape[0] :] = first_canvas_block count_zarr.resize((first_count_block.shape[0], *count_zarr.shape[1:])) count_zarr[-first_count_block.shape[0] :] = first_count_block start_idx = 1 else: start_idx = 0 # Append remaining blocks one-at-a-time to limit peak memory. num_blocks = canvas.numblocks[0] tqdm_loop = tqdm( range(start_idx, num_blocks), leave=False, desc="Memory Overload, Spilling to disk", disable=not verbose, ) for block_idx in tqdm_loop: canvas_block = canvas.blocks[block_idx, 0, 0].compute() count_block = count.blocks[block_idx, 0, 0].compute() canvas_zarr.resize( (canvas_zarr.shape[0] + canvas_block.shape[0], *canvas_zarr.shape[1:]) ) canvas_zarr[-canvas_block.shape[0] :] = canvas_block count_zarr.resize( (count_zarr.shape[0] + count_block.shape[0], *count_zarr.shape[1:]) ) count_zarr[-count_block.shape[0] :] = count_block return canvas_zarr, count_zarr
[docs] def get_wsi_output_shape(dataset: object) -> tuple[int, int] | None: """Return WSI output shape as (height, width) for the dataset if available.""" wsi_shape = getattr(dataset, "wsi_shape", None) if wsi_shape is None: has_meta = all( hasattr(dataset, attr) for attr in ("img_path", "resolution", "units") ) if has_meta: try: reader = getattr(dataset, "reader", None) if reader is None: reader = WSIReader.open(dataset.img_path) wsi_shape = reader.slide_dimensions( resolution=dataset.resolution, units=dataset.units ) except (AttributeError, OSError, TypeError, ValueError): msg = "WSI output shape is not recognizable. Please verify outputs." logger.info(msg) return None else: msg = "No metadata found in dataset. Please verify outputs." logger.warning(msg) return None return int(wsi_shape[1]), int(wsi_shape[0])
[docs] def merge_vertical_chunkwise( canvas: da.Array, count: da.Array, output_locs_y_: np.ndarray, zarr_group: zarr.Group, save_path: Path, memory_threshold: int = 80, output_shape: tuple[int, int] | None = None, *, verbose: bool = True, ) -> da.Array: """Merge vertically chunked canvas and count arrays into a single probability map. This function processes vertically stacked image blocks (`canvas`) and their associated count arrays to compute normalized probabilities. It handles overlapping regions between chunks by applying seam folding and trimming halos to ensure smooth transitions. If a Zarr group is provided, the result is stored incrementally. Args: canvas (da.Array): Dask array containing image data split into vertical chunks. count (da.Array): Dask array containing count data corresponding to the canvas. output_locs_y_ (np.ndarray): Array of shape (N, 2) specifying vertical output locations for each chunk, used to compute overlaps. zarr_group (zarr.Group): Zarr group to store the merged probability dataset. save_path (Path): Path to save the intermediate output. The intermediate output is saved in a Zarr file. memory_threshold (int): Memory usage threshold (in percentage) to trigger caching behavior. output_shape (tuple[int, int] | None): Optional target output shape as (height, width). If provided, merged probabilities are clipped to this shape before being accumulated or written to Zarr. verbose (bool): Whether to display progress bar. Returns: da.Array: A merged Dask array of normalized probabilities, either loaded from Zarr or constructed in memory. """ y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1]) overlaps = np.append(y1s[:-1] - y0s[1:], 0) num_chunks = canvas.numblocks[0] probabilities_zarr, probabilities_da = None, None chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) written_height = 0 tqdm_loop = tqdm( overlaps, leave=False, desc="Merging rows", disable=not verbose, ) used_percent = 0 curr_chunk = canvas.blocks[0, 0].compute() curr_count = count.blocks[0, 0].compute() next_chunk = canvas.blocks[1, 0].compute() if num_chunks > 1 else None next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None probabilities = np.empty(0) for i, overlap in enumerate(tqdm_loop): if next_chunk is not None and overlap > 0: curr_chunk[-overlap:] += next_chunk[:overlap] curr_count[-overlap:] += next_count[:overlap] # Normalize curr_count = np.where(curr_count == 0, 1, curr_count) probabilities = curr_chunk / curr_count.astype(np.float32) probabilities, written_height, should_stop = clip_probabilities_to_shape( probabilities=probabilities, output_shape=output_shape, written_height=written_height, ) if should_stop: break probabilities_zarr, probabilities_da = store_probabilities( probabilities=probabilities, chunk_shape=chunk_shape, probabilities_zarr=probabilities_zarr, probabilities_da=probabilities_da, zarr_group=zarr_group, ) if probabilities_da is not None: vm = psutil.virtual_memory() used_percent = (probabilities_da.nbytes / vm.free) * 100 if probabilities_zarr is None and used_percent > memory_threshold: desc = tqdm_loop.desc if hasattr(tqdm_loop, "desc") else "" msg = ( f"Current Memory usage: {used_percent} % " f"exceeds specified threshold: {memory_threshold}. " f"Saving intermediate results to disk." ) update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg) zarr_group = zarr.open(str(save_path), mode="a") probabilities_zarr = zarr_group.create_dataset( name="probabilities", shape=probabilities_da.shape, chunks=(chunk_shape[0], *probabilities.shape[1:]), dtype=probabilities.dtype, overwrite=True, ) probabilities_zarr[:] = probabilities_da.compute() probabilities_da = None update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc) if next_chunk is not None: curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] if i + 2 < num_chunks: next_chunk = canvas.blocks[i + 2, 0].compute() next_count = count.blocks[i + 2, 0].compute() else: next_chunk, next_count = None, None if probabilities_zarr: return _get_probabilities_da_from_zarr( zarr_group=zarr_group, probabilities_zarr=probabilities_zarr, chunk_shape=chunk_shape, probabilities=probabilities, ) return probabilities_da
[docs] def clip_probabilities_to_shape( probabilities: np.ndarray, output_shape: tuple[int, int] | None, written_height: int, ) -> tuple[np.ndarray, int, bool]: """Clip probability chunk to target output shape and track written height.""" if output_shape is None: return probabilities, written_height, False target_height, target_width = map(int, output_shape) remaining_height = target_height - written_height if remaining_height <= 0: return probabilities[:0], written_height, True clipped = probabilities[:remaining_height, :target_width, ...] if clipped.shape[0] == 0: return clipped, written_height, True return clipped, written_height + clipped.shape[0], False
def _get_probabilities_da_from_zarr( zarr_group: zarr.Group, probabilities_zarr: zarr.Array, chunk_shape: tuple, probabilities: zarr.Array | np.ndarray, ) -> da.Array: """Helper function to return dask array after probabilities have been merged.""" if "canvas" in zarr_group: del zarr_group["canvas"] if "count" in zarr_group: del zarr_group["count"] return da.from_zarr( probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:]) )
[docs] def store_probabilities( probabilities: np.ndarray, chunk_shape: tuple[int, ...], probabilities_zarr: zarr.Array | None, probabilities_da: da.Array | None, zarr_group: zarr.Group | None, name: str = "probabilities", ) -> tuple[zarr.Array | None, da.Array | None]: """Store computed probability data into a Zarr dataset or accumulate in memory. If a Zarr group is provided, the function appends the given probability array to the 'probabilities' dataset, resizing as needed. Otherwise, it concatenates the array into an existing Dask array for in-memory accumulation. Args: probabilities (np.ndarray): Computed probability array to store. chunk_shape (tuple[int, ...]): Chunk shape used for Zarr dataset creation. probabilities_zarr (zarr.Array | None): Existing Zarr dataset, or None to initialize. probabilities_da (da.Array | None): Existing Dask array for in-memory accumulation. zarr_group (zarr.Group | None): Zarr group used to create or access the dataset. name (str): Name to create Zarr dataset. Returns: tuple[zarr.Array | None, da.Array | None]: Updated Zarr dataset and/or Dask array. """ if zarr_group is not None: if probabilities_zarr is None: probabilities_zarr = zarr_group.create_dataset( name=name, shape=(0, *probabilities.shape[1:]), chunks=(chunk_shape[0], *probabilities.shape[1:]), dtype=probabilities.dtype, ) probabilities_zarr.resize( ( probabilities_zarr.shape[0] + probabilities.shape[0], *probabilities_zarr.shape[1:], ) ) probabilities_zarr[-probabilities.shape[0] :] = probabilities else: probabilities_da = concatenate_none( old_arr=probabilities_da, new_arr=da.from_array( probabilities, chunks=(chunk_shape[0], *probabilities.shape[1:]) ), ) return probabilities_zarr, probabilities_da
[docs] def prepare_full_batch( batch_output: np.ndarray, batch_locs: np.ndarray, full_output_locs: np.ndarray, output_locs: np.ndarray, canvas_np: np.ndarray | zarr.Array | None = None, save_path: Path | str = "temp_fullbatch", memory_threshold: int = 80, *, is_last: bool, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Prepare full-sized output and count arrays for a batch of patch predictions. This function aligns patch-level predictions with global output locations when a mask (e.g., auto_get_mask) is applied. It initializes full-sized arrays and fills them using matched indices. If the batch is the last in the sequence, it pads the arrays to cover remaining locations. Args: batch_output (np.ndarray): Patch-level model predictions of shape (N, H, W, C). batch_locs (np.ndarray): Output locations corresponding to `batch_output`. full_output_locs (np.ndarray): Remaining global output locations to be matched. output_locs (np.ndarray): Accumulated output location array across batches. canvas_np (np.ndarray | zarr.Array | None): Accumulated canvas array from previous batches. Used to check total memory footprint when deciding numpy vs zarr. save_path (Path | str): Path to a directory; a unique temp subfolder will be created within it to store the temporary full-batch zarr for this batch. memory_threshold (int): Memory usage threshold (in percentage) to trigger caching behavior. is_last (bool): Flag indicating whether this is the final batch. Returns: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - full_batch_output: Full-sized output array with predictions placed. - full_output_locs: Updated remaining global output locations. - output_locs: Updated accumulated output locations. """ # Map batch locations back to indices in the full output grid. # Use a dict to avoid allocating a huge dense array when locations are sparse. full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} matches = np.array([full_output_dict[tuple(row)] for row in batch_locs]) total_size = int(np.max(matches).astype(np.uint32)) + 1 sample_shape = batch_output.shape[1:] # Calculate final size including potential padding final_size = total_size if is_last and len(full_output_locs): final_size += len(full_output_locs) # Check if array will fit in available memory # Consider BOTH: new array size AND accumulated canvas_np size array_bytes = final_size * np.prod(sample_shape) * batch_output.dtype.itemsize canvas_bytes = canvas_np.nbytes if canvas_np is not None else 0 total_bytes = array_bytes + canvas_bytes vm = psutil.virtual_memory() # During concatenation, we temporarily need: # - existing canvas_np (canvas_bytes) # - new full_batch_output (array_bytes) # - concatenated result (canvas_bytes + array_bytes) # Total peak = 2 * (canvas_bytes + array_bytes) peak_bytes = 2 * total_bytes memory_available = vm.available * (memory_threshold / 100) use_numpy = peak_bytes < memory_available if use_numpy: # Array fits safely in RAM, use numpy for better performance full_batch_output = np.zeros( shape=(final_size, *sample_shape), dtype=batch_output.dtype, ) else: save_path_dir = Path(save_path) save_path_dir.mkdir(parents=True, exist_ok=True) temp_dir = Path( tempfile.mkdtemp(prefix="full_batch_tmp_", dir=str(save_path_dir)) ) store = zarr.DirectoryStore(str(temp_dir)) full_batch_output = zarr.zeros( shape=(total_size, *sample_shape), chunks=(len(batch_output), *sample_shape), dtype=batch_output.dtype, store=store, overwrite=True, ) # Place matching outputs using matching indices full_batch_output[matches] = batch_output output_locs = concatenate_none( old_arr=output_locs, new_arr=full_output_locs[:total_size] ) full_output_locs = full_output_locs[total_size:] if is_last and len(full_output_locs): pad_len = len(full_output_locs) if not use_numpy: # Resize zarr array to accommodate padding full_batch_output.resize(total_size + pad_len, *sample_shape) # For numpy, array is already pre-allocated to final_size full_batch_output[-pad_len:] = 0 output_locs = concatenate_none(old_arr=output_locs, new_arr=full_output_locs) full_output_locs = np.empty( (0, batch_locs.shape[1]), dtype=full_output_locs.dtype ) return full_batch_output, full_output_locs, output_locs