# Source code for tiatoolbox.models.architecture.kongnet

"""KongNet Nuclei Detection Model Architecture [1].

This module defines the KongNet model for nuclei detection and classification
in digital pathology. It implements a multi-head encoder decoder architecture
with an EfficientNetV2-L encoder. The model is designed to detect and classify
nuclei in whole slide images (WSIs).

KongNet achieved 1st on track 1 and 2nd on track 2 during the MONKEY Challenge [2].
KongNet achieved 1st place in the 2025 MIDOG Challenge [3].
KongNet ranked among the top three in the PUMA Challenge [4].
KongNet achieved SOTA detection performance on PanNuke [5] and CoNIC [6] datasets.

Please cite the paper [1], if you use this model.

Pretrained Models:
-----------------
    - KongNet_MONKEY_1:
        MONKEY Challenge model.
    - KongNet_Det_MIDOG_1:
        MIDOG Challenge lightweight detection model.
    - KongNet_PUMA_T1_3:
        PUMA Challenge model for track 1.
    - KongNet_PUMA_T2_3:
        PUMA Challenge model for track 2.
    - KongNet_CoNIC_1:
        CoNIC model.
    - KongNet_PanNuke_1:
        PanNuke model.

Key Components:
---------------
- TimmEncoderFixed: Encoder module using TIMM models with fixed drop_path_rate handling.
- SubPixelUpsample: Sub-pixel upsampling module using PixelShuffle.
- DecoderBlock: U-Net style decoder block with attention mechanisms.
- KongNetDecoder: U-Net style decoder with multiple decoder blocks.
- KongNet: Multi-head segmentation model with shared encoder and multiple decoders.

Features:
---------
- Multi-head architecture for accurate nuclei detection and classification.
- Efficient inference pipeline for batch processing.

Example:
    >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
    >>> detector = NucleusDetector(model="KongNet_CoNIC_1")
    >>> results = detector.run(
    ...     ["/example_wsi.svs"],
    ...     masks=None,
    ...     auto_get_mask=False,
    ...     patch_mode=False,
    ...     save_dir=Path("/KongNet_CoNIC/"),
    ...     output_type="annotationstore",
    ... )

References:
    [1] Lv, Jiaqi et al., "KongNet: A Multi-headed Deep Learning Model for Detection
    and Classification of Nuclei in Histopathology Images.", 2025,
    arXiv preprint arXiv:2510.23559., URL: https://arxiv.org/abs/2510.23559

    [2] L. Studer, “Structured description of the monkey challenge,” Sept. 2024.

    [3] J. Ammeling, M. Aubreville, S. Banerjee, C. A. Bertram, K. Breininger,
    D. Hirling, P. Horvath, N. Stathonikos, and M. Veta, “Mitosis domain
    generalization challenge 2025,” Mar. 2025.

    [4] M. Schuiveling, H. Liu, D. Eek, G. Breimer, K. Suijkerbuijk, W. Blokx,
    and M. Veta, “A novel dataset for nuclei and tissue segmentation in
    melanoma with baseline nuclei segmentation and tissue segmentation
    benchmarks,” GigaScience, vol. 14, 01 2025.

    [5] J. Gamper, N. A. Koohbanani, K. Benes, S. Graham, M. Jahanifar,
    S. A. Khurram, A. Azam, K. Hewitt, and N. Rajpoot, “Pannuke dataset
    extension, insights and baselines,” 2020.

    [6]  S. Graham et al., “Conic challenge: Pushing the frontiers of nuclear detection,
    segmentation, classification and counting,” Medical Image Analysis,
    vol. 92, p. 103047, 2024.

"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import timm
import torch
from torch import nn
from torchvision.ops import Conv2dNormActivation

from tiatoolbox.models.architecture.utils import (
    AttentionModule,
    SegmentationHead,
    nms_on_detection_maps,
    peak_detection_map_overlap,
)
from tiatoolbox.models.models_abc import ModelABC

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Mapping

    from tiatoolbox.type_hints import IntPair


class TimmEncoderFixed(nn.Module):
    """Fixed version of TIMM encoder that handles drop_path_rate properly.

    This encoder wraps TIMM models to provide a consistent feature
    extraction interface for segmentation tasks. It extracts features at
    multiple scales from the encoder backbone and prepends the raw input
    as the first "feature" level.
    """

    def __init__(
        self,
        name: str,
        in_channels: int = 3,
        depth: int = 5,
        output_stride: int = 32,
        drop_rate: float = 0.5,
        drop_path_rate: float | None = 0.0,
        *,
        pretrained: bool = True,
    ) -> None:
        """Initialize TimmEncoderFixed.

        Args:
            name (str):
                Name of the TIMM model to use as backbone.
            in_channels (int):
                Number of input channels. Default is 3.
            depth (int):
                Number of encoder stages to extract features from.
                Default is 5.
            output_stride (int):
                Output stride of the encoder. Default is 32.
            drop_rate (float):
                Dropout rate. Default is 0.5.
            drop_path_rate (float | None):
                Drop path rate of the encoder. If None, the argument is
                not forwarded to ``timm.create_model`` so the model's own
                default applies. Default is 0.0.
            pretrained (bool):
                Whether to use pretrained weights. Default is True.

        """
        super().__init__()
        # Build the common kwargs once; only attach drop_path_rate when it
        # was explicitly given, so models that do not accept it still work.
        kwargs: dict[str, Any] = {
            "in_chans": in_channels,
            "features_only": True,
            "pretrained": pretrained,
            "out_indices": tuple(range(depth)),
            "drop_rate": drop_rate,
        }
        if drop_path_rate is not None:
            kwargs["drop_path_rate"] = drop_path_rate

        self.model = timm.create_model(name, **kwargs)
        self._in_channels = in_channels
        # Input channels first, then one entry per extracted feature level.
        self._out_channels = [in_channels, *self.model.feature_info.channels()]
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Forward pass through the encoder.

        Args:
            x (torch.Tensor):
                Input tensor of shape (B, C, H, W).

        Returns:
            list[torch.Tensor]:
                List of feature tensors at different scales, including the
                input as the first element.

        """
        features = self.model(x)
        return [x, *features]

    @property
    def out_channels(self) -> list[int]:
        """Get output channels for each feature level.

        Returns:
            list[int]:
                Number of channels at each feature level.

        """
        return self._out_channels

    @property
    def output_stride(self) -> int:
        """Get the output stride of the encoder.

        Returns:
            int:
                Output stride value.

        """
        return min(self._output_stride, 2**self._depth)
class SubPixelUpsample(nn.Module):
    """Sub-pixel upsampling module using PixelShuffle.

    This module performs upsampling using sub-pixel convolution
    (PixelShuffle) which is more efficient than transposed convolution and
    produces better results.

    Args:
        in_channels (int):
            Number of input channels.
        out_channels (int):
            Number of output channels.
        upscale_factor (int):
            Factor to increase spatial resolution. Default: 2.

    """

    def __init__(
        self, in_channels: int, out_channels: int, upscale_factor: int = 2
    ) -> None:
        """Initialize SubPixelUpsample.

        Args:
            in_channels (int):
                Number of input channels.
            out_channels (int):
                Number of output channels.
            upscale_factor (int):
                Factor to increase spatial resolution. Default is 2.

        """
        super().__init__()
        # 1x1 conv expands channels so PixelShuffle can trade them for
        # spatial resolution (out_channels * r^2 -> out_channels, H*r, W*r).
        self.conv1 = Conv2dNormActivation(
            in_channels,
            out_channels * upscale_factor**2,
            kernel_size=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.SiLU,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        # 3x3 conv refines the upsampled feature map.
        self.conv2 = Conv2dNormActivation(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.SiLU,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through sub-pixel upsampling.

        Args:
            x (torch.Tensor):
                Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor:
                Upsampled tensor of shape
                (B, out_channels, H*upscale_factor, W*upscale_factor).

        """
        x = self.conv1(x)
        x = self.pixel_shuffle(x)
        return self.conv2(x)
class DecoderBlock(nn.Module):
    """Decoder block with upsampling, skip connection, and attention.

    This block performs upsampling of the input features, concatenates with
    skip connections from the encoder, applies attention mechanisms, and
    processes through convolutions.

    Args:
        in_channels (int):
            Number of input channels.
        skip_channels (int):
            Number of channels from skip connection.
        out_channels (int):
            Number of output channels.
        attention_type (str):
            Type of attention mechanism. Default: 'scse'.

    """

    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        attention_type: str = "scse",
    ) -> None:
        """Initialize DecoderBlock.

        Args:
            in_channels (int):
                Number of input channels.
            skip_channels (int):
                Number of channels from skip connection.
            out_channels (int):
                Number of output channels.
            attention_type (str):
                Type of attention mechanism. Default: 'scse'.

        """
        super().__init__()
        self.up = SubPixelUpsample(in_channels, in_channels, upscale_factor=2)
        self.conv1 = Conv2dNormActivation(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.SiLU,
        )
        # attention1 operates on the concatenated (upsampled + skip) map.
        self.attention1 = AttentionModule(
            name=attention_type, in_channels=in_channels + skip_channels
        )
        self.conv2 = Conv2dNormActivation(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.SiLU,
        )
        self.attention2 = AttentionModule(name=attention_type, in_channels=out_channels)

    def forward(
        self, x: torch.Tensor, skip: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Forward pass through decoder block.

        Args:
            x (torch.Tensor):
                Input tensor to be upsampled.
            skip (Optional[torch.Tensor]):
                Skip connection tensor from encoder. Default: None.

        Returns:
            torch.Tensor:
                Processed output tensor.

        """
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.attention2(x)
class CenterBlock(nn.Module):
    """Center block that applies attention mechanism at the bottleneck.

    This block is placed at the center of the U-Net architecture (deepest
    level) to enhance feature representation using attention mechanisms.

    Args:
        in_channels (int):
            Number of input channels.

    """

    def __init__(self, in_channels: int) -> None:
        """Initialize CenterBlock with attention.

        Args:
            in_channels (int):
                Number of input channels.

        """
        super().__init__()
        self.attention = AttentionModule(name="scse", in_channels=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through center block.

        Args:
            x (torch.Tensor):
                Input tensor.

        Returns:
            torch.Tensor:
                Output tensor with attention applied.

        """
        return self.attention(x)
class KongNetDecoder(nn.Module):
    """Decoder module for KongNet architecture.

    This decoder implements a U-Net style decoder with multiple decoder
    blocks, attention mechanisms, and optional center block at the
    bottleneck.

    Args:
        encoder_channels (list[int]):
            Number of channels at each encoder level.
        decoder_channels (Tuple[int, ...]):
            Number of channels at each decoder level.
        n_blocks (int):
            Number of decoder blocks. Default: 5.
        attention_type (str):
            Type of attention mechanism. Default: 'scse'.
        center (bool):
            Whether to use center block at bottleneck. Default: True.

    Raises:
        ValueError:
            If n_blocks doesn't match length of decoder_channels.

    """

    def __init__(
        self,
        encoder_channels: list[int],
        decoder_channels: tuple[int, ...],
        n_blocks: int = 5,
        attention_type: str = "scse",
        *,
        center: bool = True,
    ) -> None:
        """Initialize KongNetDecoder.

        Args:
            encoder_channels (list[int]):
                Number of channels at each encoder level.
            decoder_channels (Tuple[int, ...]):
                Number of channels at each decoder level.
            n_blocks (int):
                Number of decoder blocks. Default is 5.
            attention_type (str):
                Type of attention mechanism to use. Default is 'scse'.
            center (bool):
                Whether to include a center block at the bottleneck.
                Default is True.

        """
        super().__init__()

        if n_blocks != len(decoder_channels):
            msg = (
                f"The number of blocks {n_blocks} must match the"
                f" length of decoder_channels {len(decoder_channels)}."
            )
            raise ValueError(msg)

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels, *list(decoder_channels[:-1])]
        # last block has no skip connection, hence the trailing 0
        skip_channels = [*list(encoder_channels[1:]), 0]
        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels)
        else:
            self.center = nn.Identity()

        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, attention_type=attention_type)
            for in_ch, skip_ch, out_ch in zip(
                in_channels, skip_channels, out_channels, strict=True
            )
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        """Forward pass through the decoder.

        Args:
            *features:
                Feature tensors from encoder at different scales.

        Returns:
            torch.Tensor:
                Decoded output tensor.

        """
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x
class KongNet(ModelABC):
    """KongNet: Multi-head nuclei detection model.

    This module defines the KongNet model for nuclei detection and
    classification in digital pathology. It implements a multi-head encoder
    decoder architecture with an EfficientNetV2-L encoder. The model is
    designed to detect and classify nuclei in whole slide images (WSIs).

    Please cite the paper [1], if you use this model.

    Reported performance of the pretrained models:
        - KongNet_MONKEY_1 (MONKEY Challenge final leaderboard, FROC) [2]:
          overall inflammatory 0.3930, lymphocytes 0.4624, monocytes 0.2392.
        - KongNet_Det_MIDOG_1 (MIDOG 2025 final leaderboard, F1) [3]:
          mitotic figures 0.7400.
        - KongNet_PUMA_T1_3 (PUMA Challenge track 1, F1) [4]:
          tumour cells 0.7948, lymphocytes 0.6746, other 0.4704.
        - KongNet_PUMA_T2_3 (PUMA Challenge track 2, F1) [4]:
          tumour 0.7952, stroma 0.2927, apoptotic 0.1170, epithelium 0.0707,
          histiocytes 0.2154, lymphocytes 0.6642, neutrophils 0.0361,
          endothelial 0.2123, melanophages 0.1931, plasma 0.0595.
        - KongNet_PanNuke_1 (PanNuke dataset, F1) [5]:
          overall 0.84, neoplastic 0.71, inflammatory 0.72, epithelial 0.65,
          connective 0.70, dead 0.59.
        - KongNet_CoNIC_1 (CoNIC dataset, F1) [6]:
          neutrophils 0.510, epithelial 0.818, lymphocytes 0.707,
          plasma 0.596, eosinophils 0.591, connective 0.695.

    Attributes:
        encoder:
            Encoder module (e.g., TimmEncoderFixed).
        decoders:
            List of decoder modules (KongNetDecoder).
        heads:
            List of segmentation head modules (SegmentationHead).
        min_distance:
            Minimum distance between peaks in post-processing.
        threshold_abs:
            Absolute threshold for peak detection in post-processing.
        target_channels:
            List of target channel indices for post-processing.
        class_dict:
            Optional dictionary mapping class names to indices.
        tile_shape:
            Tile shape for post-processing with dask.

    Example:
        >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
        >>> detector = NucleusDetector(model="KongNet_CoNIC_1")
        >>> results = detector.run(
        ...     ["/example_wsi.svs"],
        ...     masks=None,
        ...     auto_get_mask=False,
        ...     patch_mode=False,
        ...     save_dir=Path("/KongNet_CoNIC/"),
        ...     output_type="annotationstore",
        ... )

    References:
        [1] Lv, Jiaqi et al., "KongNet: A Multi-headed Deep Learning Model
        for Detection and Classification of Nuclei in Histopathology
        Images.", 2025, arXiv preprint arXiv:2510.23559.,
        URL: https://arxiv.org/abs/2510.23559

        [2] L. Studer, "Structured description of the monkey challenge,"
        Sept. 2024.

        [3] J. Ammeling, M. Aubreville, S. Banerjee, C. A. Bertram,
        K. Breininger, D. Hirling, P. Horvath, N. Stathonikos, and M. Veta,
        "Mitosis domain generalization challenge 2025," Mar. 2025.

        [4] M. Schuiveling, H. Liu, D. Eek, G. Breimer, K. Suijkerbuijk,
        W. Blokx, and M. Veta, "A novel dataset for nuclei and tissue
        segmentation in melanoma with baseline nuclei segmentation and
        tissue segmentation benchmarks," GigaScience, vol. 14, 01 2025.

        [5] J. Gamper, N. A. Koohbanani, K. Benes, S. Graham, M. Jahanifar,
        S. A. Khurram, A. Azam, K. Hewitt, and N. Rajpoot, "Pannuke dataset
        extension, insights and baselines," 2020.

        [6] S. Graham et al., "Conic challenge: Pushing the frontiers of
        nuclear detection, segmentation, classification and counting,"
        Medical Image Analysis, vol. 92, p. 103047, 2024.

    """

    def __init__(
        self: KongNet,
        num_heads: int,
        num_channels_per_head: list[int],
        target_channels: list[int],
        min_distance: int,
        threshold_abs: float,
        tile_shape: IntPair = (2048, 2048),
        *,
        wide_decoder: bool = False,
        class_dict: dict | None = None,
    ) -> None:
        """Initialize KongNet model.

        Args:
            num_heads (int):
                Number of decoder heads.
            num_channels_per_head (list[int]):
                List specifying number of output channels for each head.
            target_channels (list[int]):
                List of target channel indices for post-processing.
            min_distance (int):
                Minimum distance between peaks in post-processing.
            threshold_abs (float):
                Absolute threshold for peak detection in post-processing.
            tile_shape (IntPair):
                Tile shape for post-processing with dask.
                Defaults to (2048, 2048).
            wide_decoder (bool):
                Whether to use a wider decoder architecture.
                Defaults to False.
            class_dict (dict | None):
                Optional dictionary mapping class names to indices.
                Defaults to None.

        Raises:
            ValueError:
                If len(num_channels_per_head) does not match num_heads.

        """
        super().__init__()
        if len(num_channels_per_head) != num_heads:
            msg = (
                f"Number of decoders {len(num_channels_per_head)}"
                f" must match number of heads {num_heads}."
            )
            raise ValueError(msg)

        # Shared encoder; weights come from a checkpoint, so no pretraining.
        self.encoder = TimmEncoderFixed(
            name="tf_efficientnetv2_l.in21k_ft_in1k",
            in_channels=3,
            depth=5,
            output_stride=32,
            drop_rate=0.5,
            drop_path_rate=0.25,
            pretrained=False,
        )

        decoder_channels = (256, 128, 64, 32, 16)
        if wide_decoder:
            decoder_channels = (512, 256, 128, 64, 32)

        # One independent decoder per head, all fed by the shared encoder.
        decoders = [
            KongNetDecoder(
                encoder_channels=self.encoder.out_channels,
                decoder_channels=decoder_channels,
                n_blocks=len(decoder_channels),
                center=True,
                attention_type="scse",
            )
            for _ in range(num_heads)
        ]
        heads = [
            SegmentationHead(
                # conv2[0] is the conv layer of the last decoder block's
                # Conv2dNormActivation; its out_channels feeds the head.
                in_channels=decoders[i].blocks[-1].conv2[0].out_channels,
                out_channels=num_channels_per_head[i],  # instance channels
                activation=None,
                kernel_size=1,
            )
            for i in range(num_heads)
        ]
        self.decoders = nn.ModuleList(decoders)
        self.heads = nn.ModuleList(heads)
        self.min_distance = min_distance
        self.threshold_abs = threshold_abs
        self.target_channels = target_channels
        self.class_dict = class_dict
        self.tile_shape = tile_shape

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Preprocess input image for inference.

        Applies ImageNet normalization to the input image.

        Args:
            image (np.ndarray):
                Input image as a NumPy array of shape (H, W, C) in uint8
                format.

        Returns:
            np.ndarray:
                Preprocessed image normalized to ImageNet statistics.

        Example:
            >>> img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
            >>> processed = KongNet.preproc(img)
            >>> processed.shape
            ... (256, 256, 3)

        """
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        return (image / 255.0 - mean) / std

    def forward(  # skipcq: PYL-W0613
        self: KongNet,
        x: torch.Tensor,
        *args: tuple[Any, ...],  # noqa: ARG002
        **kwargs: dict,  # noqa: ARG002
    ) -> torch.Tensor:
        """Forward pass through the model.

        Args:
            x (torch.Tensor):
                Input tensor of shape (B, C, H, W).
            *args (tuple):
                Additional positional arguments (unused).
            **kwargs (dict):
                Additional keyword arguments (unused).

        Returns:
            torch.Tensor:
                Concatenated output from all heads of shape
                (B, sum(num_channels_per_head), H, W).

        """
        features = self.encoder(x)
        decoder_outputs = [decoder(*features) for decoder in self.decoders]
        segmentation_head_outputs = [
            head(decoder_output)
            for head, decoder_output in zip(self.heads, decoder_outputs, strict=True)
        ]
        return torch.cat(segmentation_head_outputs, 1)

    @staticmethod
    def infer_batch(
        model: KongNet,
        batch_data: torch.Tensor,
        *,
        device: str,
    ) -> np.ndarray:
        """Run inference on a batch of images.

        Transfers the model and input batch to the specified device,
        performs forward pass, and returns probability maps.

        Args:
            model (torch.nn.Module):
                PyTorch model instance.
            batch_data (torch.Tensor):
                Batch of input images in NHWC format.
            device (str):
                Device for inference (e.g., "cpu" or "cuda").

        Returns:
            np.ndarray:
                Inference results as a NumPy array of shape (N, H, W, C).

        Example:
            >>> batch = torch.randn(4, 256, 256, 3)
            >>> probs = KongNet.infer_batch(model, batch, device="cpu")
            >>> probs.shape
            (4, 256, 256, len(model.target_channels))

        """
        model = model.to(device)
        model.eval()

        imgs = batch_data
        imgs = imgs.to(device).type(torch.float32)
        imgs = imgs.permute(0, 3, 1, 2)  # to NCHW

        with torch.inference_mode():
            logits = model(imgs)
            # Keep only the channels used by post-processing.
            target_logits = logits[:, model.target_channels, :, :]
            probs = torch.nn.functional.sigmoid(target_logits)
            probs = probs.permute(0, 2, 3, 1)  # to NHWC

        return probs.cpu().numpy()

    # skipcq: PYL-W0221  # noqa: ERA001
    def postproc(
        self: KongNet,
        block: np.ndarray,
        min_distance: int | None = None,
        threshold_abs: float | None = None,
        threshold_rel: float | None = None,
        block_info: dict | None = None,
        depth_h: int = 0,
        depth_w: int = 0,
    ) -> np.ndarray:
        """KongNet post-processing function.

        Builds a processed mask per input channel, runs peak_local_max then
        writes 1.0 at peak pixels. Returns same spatial shape as the input
        block.

        Args:
            block (np.ndarray):
                shape (H, W, C).
            min_distance (int | None):
                The minimal allowed distance separating peaks. Falls back
                to ``self.min_distance`` when None.
            threshold_abs (float | None):
                Minimum intensity of peaks. Falls back to
                ``self.threshold_abs`` when None.
            threshold_rel (float | None):
                Minimum relative intensity of peaks.
            block_info (dict | None):
                Dask block info dict. Only used when called from
                dask.array.map_overlap.
            depth_h (int):
                Halo size in pixels for height (rows). Only used when it's
                called from dask.array.map_overlap.
            depth_w (int):
                Halo size in pixels for width (cols). Only used when it's
                called from dask.array.map_overlap.

        Returns:
            out:
                NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.

        """
        min_distance_to_use = (
            self.min_distance if min_distance is None else min_distance
        )
        threshold_abs_to_use = (
            self.threshold_abs if threshold_abs is None else threshold_abs
        )
        peak_map = peak_detection_map_overlap(
            block,
            min_distance=min_distance_to_use,
            threshold_abs=threshold_abs_to_use,
            threshold_rel=threshold_rel,
            block_info=block_info,
            depth_h=depth_h,
            depth_w=depth_w,
            return_probability=True,
        )
        # Suppress duplicate detections across overlapping peaks.
        return nms_on_detection_maps(
            peak_map,
            min_distance=min_distance_to_use,
        )

    def load_state_dict(
        self: KongNet,
        state_dict: Mapping[str, Any],
        *,
        strict: bool = True,
        assign: bool = False,
    ) -> nn.Module:
        """Load state dict with support for wrapped models.

        Checkpoints store the weights under a "model" key; unwrap it before
        delegating to ``nn.Module.load_state_dict``.
        """
        return super().load_state_dict(
            state_dict["model"], strict=strict, assign=assign
        )