KongNet

class KongNet(num_heads, num_channels_per_head, target_channels, min_distance, threshold_abs, tile_shape=(2048, 2048), *, wide_decoder=False, class_dict=None)[source]

KongNet: Multi-head nuclei detection model.

This module defines the KongNet model for nuclei detection and classification in digital pathology. It implements a multi-head encoder-decoder architecture with an EfficientNetV2-L encoder, designed to detect and classify nuclei in whole slide images (WSIs). Please cite the paper [1] if you use this model.

KongNet detection performance (FROC) on the MONKEY Challenge Final Leaderboard [2]

================  ====================  ===========  =========
Model name        Overall Inflammatory  Lymphocytes  Monocytes
================  ====================  ===========  =========
KongNet_MONKEY_1  0.3930                0.4624       0.2392
================  ====================  ===========  =========

KongNet detection performance (F1) on the MIDOG 2025 Challenge Final Leaderboard [3]

===================  ===============
Model name           Mitotic Figures
===================  ===============
KongNet_Det_MIDOG_1  0.7400
===================  ===============

KongNet detection performance (F1) on the PUMA Challenge Final Leaderboard Track 1 [4]

=================  ============  ===========  ======
Model name         Tumour Cells  Lymphocytes  Other
=================  ============  ===========  ======
KongNet_PUMA_T1_3  0.7948        0.6746       0.4704
=================  ============  ===========  ======

KongNet detection performance (F1) on the PUMA Challenge Final Leaderboard Track 2 [4], model KongNet_PUMA_T1_3

=================  ======
Cell type          F1
=================  ======
Tumour Cells       0.7952
Stroma Cells       0.2927
Apoptotic Cells    0.1170
Epithelium Cells   0.0707
Histiocytes        0.2154
Lymphocytes        0.6642
Neutrophils        0.0361
Endothelial Cells  0.2123
Melanophages       0.1931
Plasma Cells       0.0595
=================  ======

KongNet detection performance (F1) on the PanNuke Dataset [5]

===============  =======  ================  ==================  ================  ================  ==========
Model name       Overall  Neoplastic Cells  Inflammatory Cells  Epithelial Cells  Connective Cells  Dead Cells
===============  =======  ================  ==================  ================  ================  ==========
KongNet_CoNIC_1  0.84     0.71              0.72                0.65              0.70              0.59
===============  =======  ================  ==================  ================  ================  ==========

KongNet detection performance (F1) on the CoNIC Dataset [6]

===============  ===========  ================  ===========  ============  ===========  ================
Model name       Neutrophils  Epithelial Cells  Lymphocytes  Plasma Cells  Eosinophils  Connective Cells
===============  ===========  ================  ===========  ============  ===========  ================
KongNet_CoNIC_1  0.510        0.818             0.707        0.596         0.591        0.695
===============  ===========  ================  ===========  ============  ===========  ================

Attributes:
  • encoder – Encoder module (e.g., TimmEncoderFixed).

  • decoders – List of decoder modules (KongNetDecoder).

  • heads – List of segmentation head modules (SegmentationHead).

  • min_distance – Minimum distance between peaks in post-processing.

  • threshold_abs – Absolute threshold for peak detection in post-processing.

  • target_channels – List of target channel indices for post-processing.

  • class_dict – Optional dictionary mapping class names to indices.

  • tile_shape – Tile shape for post-processing with dask.

Example

>>> from pathlib import Path
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
>>> detector = NucleusDetector(model="KongNet_CoNIC_1")
>>> results = detector.run(
...     ["/example_wsi.svs"],
...     masks=None,
...     auto_get_mask=False,
...     patch_mode=False,
...     save_dir=Path("/KongNet_CoNIC/"),
...     output_type="annotationstore",
... )

References

[1] J. Lv et al., “KongNet: A multi-headed deep learning model for detection and classification of nuclei in histopathology images,” arXiv preprint arXiv:2510.23559, 2025. URL: https://arxiv.org/abs/2510.23559

[2] L. Studer, “Structured description of the monkey challenge,” Sept. 2024.

[3] J. Ammeling, M. Aubreville, S. Banerjee, C. A. Bertram, K. Breininger, D. Hirling, P. Horvath, N. Stathonikos, and M. Veta, “Mitosis domain generalization challenge 2025,” Mar. 2025.

[4] M. Schuiveling, H. Liu, D. Eek, G. Breimer, K. Suijkerbuijk, W. Blokx, and M. Veta, “A novel dataset for nuclei and tissue segmentation in melanoma with baseline nuclei segmentation and tissue segmentation benchmarks,” GigaScience, vol. 14, 01 2025.

[5] J. Gamper, N. A. Koohbanani, K. Benes, S. Graham, M. Jahanifar, S. A. Khurram, A. Azam, K. Hewitt, and N. Rajpoot, “Pannuke dataset extension, insights and baselines,” 2020.

[6] S. Graham et al., “Conic challenge: Pushing the frontiers of nuclear detection, segmentation, classification and counting,” Medical Image Analysis, vol. 92, p. 103047, 2024.

Initialize KongNet model.

Parameters:
  • num_heads (int) – Number of decoder heads.

  • num_channels_per_head (list[int]) – List specifying number of output channels for each head.

  • target_channels (list[int]) – List of target channel indices for post-processing.

  • min_distance (int) – Minimum distance between peaks in post-processing.

  • threshold_abs (float) – Absolute threshold for peak detection in post-processing.

  • tile_shape (IntPair) – Tile shape for post-processing with dask. Defaults to (2048, 2048).

  • wide_decoder (bool) – Whether to use a wider decoder architecture. Defaults to False.

  • class_dict (dict | None) – Optional dictionary mapping class names to indices. Defaults to None.

Methods

forward

Forward pass through the model.

infer_batch

Run inference on a batch of images.

load_state_dict

Load state dict with support for wrapped models.

postproc

KongNet post-processing function.

preproc

Preprocess input image for inference.

Attributes

training

forward(x, *args, **kwargs)[source]

Forward pass through the model.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, C, H, W)

  • *args (tuple) – Additional positional arguments (unused).

  • **kwargs (dict) – Additional keyword arguments (unused).

Returns:

Concatenated output from all heads, of shape (B, sum(num_channels_per_head), H, W).

Return type:

torch.Tensor
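The shape contract above can be sketched in NumPy: each head produces a (B, C_i, H, W) map, and the heads are concatenated along the channel axis. The per-head channel counts below are made-up illustration values, not the model's actual configuration.

```python
import numpy as np

# Hypothetical channel counts for two decoder heads (illustration only).
num_channels_per_head = [3, 2]
B, H, W = 4, 64, 64

# Each head emits a (B, C_i, H, W) map; the model concatenates them
# along the channel axis.
head_outputs = [
    np.zeros((B, c, H, W), dtype=np.float32) for c in num_channels_per_head
]
out = np.concatenate(head_outputs, axis=1)

print(out.shape)  # (4, 5, 64, 64) == (B, sum(num_channels_per_head), H, W)
```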

static infer_batch(model, batch_data, *, device)[source]

Run inference on a batch of images.

Transfers the model and input batch to the specified device, performs forward pass, and returns probability maps.

Parameters:
  • model (torch.nn.Module) – PyTorch model instance.

  • batch_data (torch.Tensor) – Batch of input images in NHWC format.

  • device (str) – Device for inference (e.g., “cpu” or “cuda”).

Returns:

Inference results as a NumPy array of shape (N, H, W, C).

Return type:

np.ndarray

Example

>>> batch = torch.randn(4, 256, 256, 3)
>>> probs = KongNet.infer_batch(model, batch, device="cpu")
>>> probs.shape
(4, 256, 256, len(model.target_channels))
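The layout contract above (NHWC in, NHWC out) can be sketched in NumPy. PyTorch models consume NCHW, so a transpose is implied on the way in and back out; the two-channel "model" here is a stand-in slice, not the real network, and whether infer_batch transposes internally is an implementation detail.

```python
import numpy as np

batch_nhwc = np.random.rand(4, 256, 256, 3).astype(np.float32)

# NHWC -> NCHW for the model, run a stand-in "model" (keep two channels),
# then transpose back to NHWC for the returned probability maps.
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)
probs_nchw = batch_nchw[:, :2, :, :]
probs_nhwc = probs_nchw.transpose(0, 2, 3, 1)

print(probs_nhwc.shape)  # (4, 256, 256, 2)
```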

load_state_dict(state_dict, *, strict=True, assign=False)[source]

Load state dict with support for wrapped models.

Parameters:
  • state_dict (dict) – Model state dictionary to load.

  • strict (bool) – Whether to strictly enforce that the keys in state_dict match the keys of this model. Defaults to True.

  • assign (bool) – Whether to assign items in state_dict to the model instead of copying them in place. Defaults to False.

Return type:

nn.Module

postproc(block, min_distance=None, threshold_abs=None, threshold_rel=None, block_info=None, depth_h=0, depth_w=0)[source]

KongNet post-processing function.

Builds a processed mask per input channel: runs peak_local_max, then writes 1.0 at peak pixels. The output has the same spatial shape as the input block.

Parameters:
  • block (np.ndarray) – Input block of shape (H, W, C).

  • min_distance (int | None) – The minimal allowed distance separating peaks.

  • threshold_abs (float | None) – Minimum intensity of peaks.

  • threshold_rel (float | None) – Minimum intensity of peaks, relative to the maximum intensity of the block (max(image) * threshold_rel).

  • block_info (dict | None) – Dask block info dict. Only used when called from dask.array.map_overlap.

  • depth_h (int) – Halo size in pixels for height (rows). Only used when called from dask.array.map_overlap.

  • depth_w (int) – Halo size in pixels for width (cols). Only used when called from dask.array.map_overlap.

Returns:

NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.

Return type:

np.ndarray
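The per-channel peak step can be illustrated with a toy, pure-NumPy stand-in. The real implementation runs skimage's peak_local_max over dask tiles; mark_peaks below is a hypothetical helper for illustration only.

```python
import numpy as np

def mark_peaks(channel, threshold_abs, min_distance=1):
    """Toy stand-in for the peak step: write 1.0 at local maxima that
    reach `threshold_abs`, 0.0 elsewhere."""
    out = np.zeros_like(channel, dtype=np.float32)
    pad = min_distance
    # Pad with -inf so border pixels compare only against real neighbours.
    padded = np.pad(channel, pad, mode="constant", constant_values=-np.inf)
    for i in range(channel.shape[0]):
        for j in range(channel.shape[1]):
            window = padded[i : i + 2 * pad + 1, j : j + 2 * pad + 1]
            if channel[i, j] >= threshold_abs and channel[i, j] == window.max():
                out[i, j] = 1.0
    return out

prob = np.zeros((8, 8), dtype=np.float32)
prob[2, 3] = 0.9  # clear peak above threshold
prob[6, 6] = 0.2  # local maximum, but below threshold
mask = mark_peaks(prob, threshold_abs=0.5)
print(int(mask.sum()), mask[2, 3])  # 1 1.0
```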

static preproc(image)[source]

Preprocess input image for inference.

Applies ImageNet normalization to the input image.

Parameters:

image (np.ndarray) – Input image as a NumPy array of shape (H, W, C) in uint8 format.

Returns:

Preprocessed image normalized to ImageNet statistics.

Return type:

np.ndarray

Example

>>> import numpy as np
>>> img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
>>> processed = KongNet.preproc(img)
>>> processed.shape
(256, 256, 3)
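The "ImageNet normalization" referenced above is the standard per-channel operation sketched below, using the usual ImageNet RGB statistics; this shows the formula only, and the exact dtype and scaling handled inside KongNet.preproc may differ.

```python
import numpy as np

# Standard ImageNet RGB statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def imagenet_normalize(image):
    """Scale a uint8 (H, W, C) image to [0, 1], then normalize per channel."""
    scaled = image.astype(np.float32) / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD

img = np.full((4, 4, 3), 128, dtype=np.uint8)
out = imagenet_normalize(img)
print(out.shape)  # (4, 4, 3)
```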