"""Define utility layers and operators for models in tiatoolbox."""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, cast
import numpy as np
import torch
from scipy import ndimage
from skimage.feature import peak_local_max
from skimage.measure import label, regionprops
from torch import nn
from tiatoolbox import logger
if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.models.models_abc import ModelABC
[docs]
def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.
Returns:
True if current GPU is compatible with torch-compile, False otherwise.
Raises:
Warning if GPU is not compatible with `torch.compile`.
"""
gpu_compatibility = True
if torch.cuda.is_available(): # pragma: no cover
device_cap = torch.cuda.get_device_capability()
if device_cap not in ((7, 0), (8, 0), (9, 0)):
logger.warning(
"GPU is not compatible with torch.compile. "
"Compatible GPUs include NVIDIA V100, A100, and H100. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False
else:
logger.warning(
"No GPU detected or cuda not installed, "
"torch.compile is only supported on selected NVIDIA GPUs. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
gpu_compatibility = False
return gpu_compatibility
[docs]
def compile_model(
model: nn.Module | ModelABC | None = None,
*,
mode: str = "default",
) -> torch.nn.Module | ModelABC:
"""A decorator to compile a model using torch-compile.
Args:
model (torch.nn.Module):
Model to be compiled.
mode (str):
Mode to be used for torch-compile. Available modes are:
- `disable` disables torch-compile
- `default` balances performance and overhead
- `reduce-overhead` reduces overhead of CUDA graphs (useful for small
batches)
- `max-autotune` leverages Triton/template based matrix multiplications
on GPUs
- `max-autotune-no-cudagraphs` similar to “max-autotune” but without
CUDA graphs
Returns:
torch.nn.Module or ModelABC:
Compiled model.
"""
if mode == "disable":
return model
# Check if GPU is compatible with torch.compile
gpu_compatibility = is_torch_compile_compatible()
if not gpu_compatibility:
return model
if sys.platform == "win32": # pragma: no cover
msg = (
"`torch.compile` is not supported on Windows. Please see "
"https://github.com/pytorch/pytorch/issues/122094."
)
logger.warning(msg=msg)
return model
if isinstance( # pragma: no cover
model,
torch._dynamo.eval_frame.OptimizedModule, # skipcq: PYL-W0212 # noqa: SLF001
):
logger.info(
("The model is already compiled. ",),
)
return model
return cast("nn.Module", torch.compile(model, mode=mode)) # pragma: no cover
[docs]
def centre_crop(
img: np.ndarray | torch.Tensor,
crop_shape: np.ndarray | torch.Tensor | tuple[int, int],
data_format: str = "NCHW",
) -> np.ndarray | torch.Tensor:
"""A function to center crop image with given crop shape.
Args:
img (:class:`numpy.ndarray`, torch.Tensor):
Input image, should be of 3 channels.
crop_shape (:class:`numpy.ndarray`, torch.Tensor):
The subtracted amount in the form of `[subtracted height,
subtracted width]`.
data_format (str):
Either `"NCHW"` or `"NHWC"`.
Returns:
(:class:`numpy.ndarray`, torch.Tensor):
Cropped image.
"""
if data_format not in ["NCHW", "NHWC"]:
msg = f"Unknown input format `{data_format}`."
raise ValueError(msg)
crop_t: int = int(crop_shape[0] // 2)
crop_b: int = int(crop_shape[0] - crop_t)
crop_l: int = int(crop_shape[1] // 2)
crop_r: int = int(crop_shape[1] - crop_l)
if data_format == "NCHW":
return img[:, :, crop_t:-crop_b, crop_l:-crop_r]
return img[:, crop_t:-crop_b, crop_l:-crop_r, :]
[docs]
def centre_crop_to_shape(
x: np.ndarray | torch.Tensor,
y: np.ndarray | torch.Tensor,
data_format: str = "NCHW",
) -> np.ndarray | torch.Tensor:
"""A function to center crop image to shape.
Centre crop `x` so that `x` has shape of `y` and `y` height and
width must be smaller than `x` height width.
Args:
x (:class:`numpy.ndarray`, torch.Tensor):
Image to be cropped.
y (:class:`numpy.ndarray`, torch.Tensor):
Reference image for getting cropping shape, should be of 3
channels.
data_format:
Either `"NCHW"` or `"NHWC"`.
Returns:
(:class:`numpy.ndarray`, torch.Tensor):
Cropped image.
"""
if data_format not in ["NCHW", "NHWC"]:
msg = f"Unknown input format `{data_format}`."
raise ValueError(msg)
if data_format == "NCHW":
_, _, h1, w1 = x.shape
_, _, h2, w2 = y.shape
else:
_, h1, w1, _ = x.shape
_, h2, w2, _ = y.shape
if h1 <= h2 or w1 <= w2:
raise ValueError(
(
"Height or width of `x` is smaller than `y` ",
f"{[h1, w1]} vs {[h2, w2]}",
),
)
x_shape = x.shape
y_shape = y.shape
if data_format == "NCHW":
crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
else:
crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])
return centre_crop(x, crop_shape, data_format)
[docs]
class UpSample2x(nn.Module):
"""A layer to scale input by a factor of 2.
This layer uses Kronecker product underneath rather than the default
pytorch interpolation.
"""
def __init__(self: UpSample2x) -> None:
"""Initialize :class:`UpSample2x`."""
super().__init__()
# correct way to create constant within module
self.unpool_mat: torch.Tensor
self.register_buffer(
"unpool_mat",
torch.from_numpy(np.ones((2, 2), dtype="float32")),
)
self.unpool_mat.unsqueeze(0)
[docs]
def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor:
"""Logic for using layers defined in init.
Args:
x (torch.Tensor):
Input images, the tensor is in the shape of NCHW.
Returns:
torch.Tensor:
Input images upsampled by a factor of 2 via nearest
neighbour interpolation. The tensor is the shape as
NCHW.
"""
input_shape = list(x.shape)
# un-squeeze is the same as expand_dims
# permute is the same as transpose
# view is the same as reshape
x = x.unsqueeze(-1) # bchwx1
mat = self.unpool_mat.unsqueeze(0) # 1xshxsw
ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw
ret = ret.permute(0, 1, 2, 4, 3, 5)
return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2))
[docs]
class SegmentationHead(nn.Sequential):
"""Segmentation head for UNet++ architecture.
This class defines the final segmentation layer for the UNet++ model.
It applies a convolution followed by optional upsampling and activation
to produce the segmentation output.
Attributes:
conv2d (nn.Conv2d):
Convolutional layer for feature transformation.
upsampling_layer (nn.Module):
Upsampling layer (bilinear interpolation or identity).
activation (nn.Module):
Activation function applied after upsampling.
Example:
>>> head = SegmentationHead(in_channels=64, out_channels=2)
>>> x = torch.randn(1, 64, 128, 128)
>>> output = head(x)
>>> output.shape
... torch.Size([1, 2, 128, 128])
"""
def __init__(
self: SegmentationHead,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
activation: nn.Module | None = None,
upsampling: int = 1,
) -> None:
"""Initialize the SegmentationHead module.
This method sets up the segmentation head by creating a convolutional layer,
an optional upsampling layer, and an activation function. It is typically
used as the final stage in UNet++ architectures for semantic segmentation.
Args:
in_channels (int):
Number of input channels to the segmentation head.
out_channels (int):
Number of output channels (usually equal to the number of classes).
kernel_size (int):
Size of the convolution kernel. Defaults to 3.
activation (nn.Module | None):
Activation function applied after convolution. Defaults to None.
upsampling (int):
Upsampling factor applied to the output. Defaults to 1.
"""
conv2d = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
)
upsampling_layer = (
nn.UpsamplingBilinear2d(scale_factor=upsampling)
if upsampling > 1
else nn.Identity()
)
if activation is None:
activation = nn.Identity()
super().__init__(conv2d, upsampling_layer, activation)
[docs]
class AttentionModule(nn.Module):
"""Attention module to apply attention mechanism on feature maps."""
def __init__(self, name: str | None, in_channels: int, reduction: int = 16) -> None:
"""Initialize the Attention module.
Args:
name (str | None):
Name of the attention mechanism.
Only "scse" is implemented. If None, identity is used.
in_channels (int):
Number of input channels.
reduction (int):
Reduction ratio for channel attention.
"""
super().__init__()
if name is None:
self.attention = nn.Identity()
elif name == "scse":
self.attention = SCSEModule(in_channels=in_channels, reduction=reduction)
else:
msg = f"Attention {name} is not implemented"
raise ValueError(msg)
[docs]
def forward(self: AttentionModule, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the Attention module.
Args:
x (torch.Tensor):
Input feature map of shape (N, C, H, W).
Returns:
torch.Tensor:
Output feature map after applying attention.
"""
return self.attention(x)
[docs]
class SCSEModule(nn.Module):
"""Spatial and Channel Squeeze & Excitation (SCSE) module."""
def __init__(self, in_channels: int, reduction: int = 16) -> None:
"""Initialize the SCSE module.
Args:
in_channels (int):
Number of input channels.
reduction (int):
Reduction ratio for channel attention.
"""
super().__init__()
self.cSE = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels // reduction, in_channels, 1),
nn.Sigmoid(),
)
self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
[docs]
def forward(self: SCSEModule, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the SCSE module.
Args:
x (torch.Tensor):
Input feature map of shape (N, C, H, W).
Returns:
torch.Tensor:
Output feature map after applying SCSE attention.
"""
return x * self.cSE(x) + x * self.sSE(x)
[docs]
def argmax_last_axis(image: np.ndarray) -> np.ndarray:
"""Define the post-processing of this class of model.
This simply applies argmax along last axis of the input.
Args:
image (np.ndarray):
The input image array.
Returns:
np.ndarray:
The post-processed image array.
"""
return image.argmax(axis=-1)
[docs]
def peak_detection_map_overlap(
block: np.ndarray,
min_distance: int,
threshold_abs: float | None = None,
threshold_rel: float | None = None,
block_info: dict | None = None,
depth_h: int = 0,
depth_w: int = 0,
*,
return_probability: bool = False,
) -> np.ndarray:
"""Post-processing function for peak detection.
Builds a processed mask per input channel. Runs peak_local_max then
writes 1.0 at peak pixels if return_probability is False, otherwise writes
the confidence scores at peak locations.
Can be called from dask.da.map_overlap on a padded NumPy block
(h_pad, w_pad, C) to process large prediction maps in chunks with overlap.
Keeps only centroids whose (row, col) lie in the interior window:
rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
Returns same spatial shape as the input block
Args:
block:
NumPy array (H, W, C).
min_distance:
Minimum number of pixels separating peaks.
threshold_abs:
Minimum intensity of peaks. By default, None.
threshold_rel:
Minimum relative intensity of peaks. By default, None.
block_info:
Dask block info dict.
Only used when called from dask.array.map_overlap.
depth_h:
Halo size in pixels for height (rows).
Only used when called from dask.array.map_overlap.
depth_w:
Halo size in pixels for width (cols).
Only used when it's called from dask.array.map_overlap.
return_probability:
If True, returns the confidence scores at peak
locations instead of binary peak map.
Returns:
out:
NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere
if return_probability is False, otherwise with confidence
scores at peak locations.
"""
block_height, block_width, block_channels = block.shape
# --- derive core (pre-overlap) size for THIS block ---
if block_info is None:
core_h = block_height - 2 * depth_h
core_w = block_width - 2 * depth_w
else:
info = block_info[0]
locs = info["array-location"] # a list of (start, stop) coordinates per axis
core_h = int(locs[0][1] - locs[0][0]) # r1 - r0
core_w = int(locs[1][1] - locs[1][0])
rmin, rmax = depth_h, depth_h + core_h
cmin, cmax = depth_w, depth_w + core_w
out = np.zeros((block_height, block_width, block_channels), dtype=np.float32)
if return_probability:
out_probs = np.zeros(
(block_height, block_width, block_channels), dtype=np.float32
)
for ch in range(block_channels):
probs_map = np.asarray(block[..., ch]) # NumPy 2D view
coords = peak_local_max(
probs_map,
min_distance=min_distance,
threshold_abs=threshold_abs,
threshold_rel=threshold_rel,
exclude_border=False,
)
for r, c in coords:
if (rmin <= r < rmax) and (cmin <= c < cmax):
out[r, c, ch] = 1.0
if return_probability:
labeled_peaks = label(out[..., ch])
peak_stats = regionprops(labeled_peaks, intensity_image=probs_map)
for peak in peak_stats:
centroid = peak["centroid"]
r, c, confidence = (
centroid[0],
centroid[1],
peak["mean_intensity"],
)
out_probs[int(r), int(c), ch] = confidence
return out if not return_probability else out_probs
[docs]
def nms_on_detection_maps(
detection_maps: np.ndarray,
min_distance: int,
) -> np.ndarray:
"""Apply NMS to pre-processed peak maps to handle cross-channel conflicts.
Args:
detection_maps (np.ndarray):
(H, W, C) where pixels are already local peaks.
min_distance (int):
Minimum distance required between ANY detections.
Returns:
np.ndarray:
The filtered maps with cross-channel suppression applied.
"""
# 1. Collapse channels to find the "Global Best" at every spatial location
# Contains the highest probability found across all classes at each pixel.
max_across_channels = np.max(detection_maps, axis=2)
# 2. Handle Spatial Conflicts Across Channels (Global NMS)
filter_size = 2 * min_distance + 1
dilated_global_max = ndimage.maximum_filter(
max_across_channels, size=filter_size, mode="constant", cval=0.0
)
# 3. Create the Keep Mask
# A pixel is kept IF:
# A) It is the max value across its own channels
# B) It is the max value in its spatial neighborhood
# C) It is non-zero
keep_mask = (detection_maps == dilated_global_max[..., None]) & (detection_maps > 0)
# Apply mask
return np.where(keep_mask, detection_maps, 0)