Source code for tiatoolbox.models.architecture.vanilla

"""Define vanilla CNNs with torch backbones, mainly for patch classification."""

from __future__ import annotations

from typing import TYPE_CHECKING

import timm
import torch
import torchvision.models as torch_models
from timm.layers import SwiGLUPacked
from torch import nn

from tiatoolbox.models.architecture.utils import argmax_last_axis
from tiatoolbox.models.models_abc import ModelABC

if TYPE_CHECKING:  # pragma: no cover
    import numpy as np
    from torchvision.models import WeightsEnum

# Registry of supported torchvision backbones. Every key is also the name of the
# factory function in `torchvision.models`, so the mapping is built by looking
# each name up on that module.
torch_cnn_backbone_dict = {
    name: getattr(torch_models, name)
    for name in (
        "alexnet",
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnext50_32x4d",
        "resnext101_32x8d",
        "wide_resnet50_2",
        "wide_resnet101_2",
        "densenet121",
        "densenet161",
        "densenet169",
        "densenet201",
        "inception_v3",
        "googlenet",
        "mobilenet_v2",
        "mobilenet_v3_large",
        "mobilenet_v3_small",
    )
}

# Registry of pathology foundation models served through timm.
# Each entry holds the timm model identifier under "model"; every other key is
# forwarded as a keyword argument to `timm.create_model`.
timm_arch_dict = {
    # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
    "UNI": {
        "model": "hf-hub:MahmoodLab/UNI",
        "init_values": 1e-5,
        "dynamic_img_size": True,
    },
    # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
    "prov-gigapath": {
        "model": "hf_hub:prov-gigapath/prov-gigapath",
    },
    # H-Optimus-0 tile encoder: https://huggingface.co/bioptimus/H-optimus-0
    "H-optimus-0": {
        "model": "hf-hub:bioptimus/H-optimus-0",
        "init_values": 1e-5,
        "dynamic_img_size": False,
    },
    # H-Optimus-1 tile encoder: https://huggingface.co/bioptimus/H-optimus-1
    "H-optimus-1": {
        "model": "hf-hub:bioptimus/H-optimus-1",
        "init_values": 1e-5,
        "dynamic_img_size": False,
    },
    # H0-mini tile encoder: https://huggingface.co/bioptimus/H0-mini
    "H0-mini": {
        "model": "hf-hub:bioptimus/H0-mini",
        "init_values": 1e-5,
        "dynamic_img_size": False,
        "mlp_layer": SwiGLUPacked,
        "act_layer": nn.SiLU,
    },
    # UNI2-h tile encoder: https://huggingface.co/MahmoodLab/UNI2-h
    "UNI2": {
        "model": "hf-hub:MahmoodLab/UNI2-h",
        "img_size": 224,
        "patch_size": 14,
        "depth": 24,
        "num_heads": 24,
        "init_values": 1e-5,
        "embed_dim": 1536,
        "mlp_ratio": 2.66667 * 2,
        "num_classes": 0,
        "no_embed_class": True,
        "mlp_layer": SwiGLUPacked,
        "act_layer": nn.SiLU,
        "reg_tokens": 8,
        "dynamic_img_size": True,
    },
    # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow
    "Virchow": {
        "model": "hf_hub:paige-ai/Virchow",
        "mlp_layer": SwiGLUPacked,
        "act_layer": nn.SiLU,
    },
    # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2
    "Virchow2": {
        "model": "hf_hub:paige-ai/Virchow2",
        "mlp_layer": SwiGLUPacked,
        "act_layer": nn.SiLU,
    },
    # Kaiko tile encoder:
    # https://huggingface.co/1aurent/vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms
    "kaiko": {
        "model": (
            "hf_hub:1aurent/"
            "vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms"
        ),
        "dynamic_img_size": True,
    },
}


def _get_architecture(
    arch_name: str,
    weights: str | WeightsEnum | None = None,
    **kwargs: dict,
) -> list[nn.Sequential, ...] | nn.Sequential:
    """Build a torchvision backbone with its classification head removed.

    The backbone factory is looked up in `torch_cnn_backbone_dict` and the
    resulting network is trimmed so that only the feature-extraction layers
    remain.

    Args:
        arch_name (str):
            Name of the architecture (e.g. 'resnet50', 'alexnet'). Must be a
            key of `torch_cnn_backbone_dict`.
        weights (str, WeightsEnum, or None):
            Value forwarded to the torchvision factory as `weights`.
            The default `None` avoids downloading ImageNet weights;
            use "DEFAULT" to initialise with ImageNet weights.
        **kwargs (dict):
            Extra keyword arguments forwarded to the torchvision factory.
            They are NOT forwarded for 'inception_v3' and 'googlenet', which
            are always created with `aux_logits=False, num_classes=1000`.

    Returns:
        list[nn.Sequential, ...] | nn.Sequential:
            The feature-extraction part of the network.

    Raises:
        ValueError:
            If `arch_name` is not supported.

    Example:
        >>> model = _get_architecture("resnet18")
        >>> print(model)

    """
    creator = torch_cnn_backbone_dict.get(arch_name)
    if creator is None:
        msg = f"Backbone `{arch_name}` is not supported."
        raise ValueError(msg)

    if any(tag in arch_name for tag in ("inception_v3", "googlenet")):
        model = creator(weights=weights, aux_logits=False, num_classes=1000)
        # Keep everything except the last three child modules
        # (presumably pooling / dropout / classifier -- confirm against torchvision).
        layers = list(model.children())
        return nn.Sequential(*layers[:-3])

    model = creator(weights=weights, **kwargs)

    # ResNet-style networks (incl. ResNeXt and wide ResNet) expose no `features`
    # attribute here, so unroll the children and strip the final GAP and FCN.
    if any(tag in arch_name for tag in ("resnet", "resnext")):
        layers = list(model.children())
        return nn.Sequential(*layers[:-2])

    # Every remaining supported family (DenseNet, AlexNet, MobileNet) is
    # handled through its `features` sub-module.
    return model.features


def _get_timm_architecture(
    arch_name: str,
    *,
    pretrained: bool,
) -> list[nn.Sequential, ...] | nn.Sequential:
    """Retrieve a timm model architecture.

    Names registered in `timm_arch_dict` are created with the keyword
    arguments stored in their registry entry. Any other name known to
    `timm.list_models()` is created with timm defaults and returned with its
    last child module removed.

    Args:
        arch_name (str):
            Name of the architecture (e.g. 'UNI', 'UNI2', 'H-optimus-0',
            'efficientnet_b0', etc.).
        pretrained (bool, keyword-only):
            Whether to load pretrained weights.

    Returns:
        list[nn.Sequential, ...] | nn.Sequential:
            A ready-to-use timm model.

    Raises:
        ValueError:
            If the backbone architecture `arch_name` is not supported.

    Example:
        >>> model = _get_timm_architecture("UNI", pretrained=True)
        >>> print(model)

    """
    if arch_name in timm_arch_dict:  # pragma: no cover
        # Coverage skipped timm API is tested using efficient U-Net.
        # Work on a shallow copy of the registry entry: popping "model" from
        # the shared module-level dict would make every subsequent request for
        # the same backbone fail with `KeyError: 'model'`.
        create_kwargs = dict(timm_arch_dict[arch_name])
        model_name = create_kwargs.pop("model")
        return timm.create_model(
            model_name,
            pretrained=pretrained,
            **create_kwargs,
        )

    if arch_name in timm.list_models():
        model = timm.create_model(arch_name, pretrained=pretrained)
        # Drop the last child module so that only the feature extractor remains.
        return nn.Sequential(*list(model.children())[:-1])

    msg = f"Backbone {arch_name} not supported. "
    raise ValueError(msg)


def _infer_batch(
    model: nn.Module,
    batch_data: torch.Tensor,
    device: str,
) -> np.ndarray:
    """Run a forward pass on one batch and return the result as NumPy.

    The batch is moved to `device`, cast to float32 and permuted from
    channels-last (NHWC) to channels-first (NCHW) before being fed to the
    model. The model is switched to evaluation mode; it is not moved to
    `device` by this function.

    Args:
        model (nn.Module):
            PyTorch defined model.
        batch_data (torch.Tensor):
            A batch of data generated by
            `torch.utils.data.DataLoader`, laid out as NHWC.
        device (str):
            Device the batch is transferred to (e.g. "cpu", "cuda").

    Returns:
        np.ndarray:
            The model predictions as a NumPy array.

    Example:
        >>> output = _infer_batch(model, batch_data, "cuda")
        >>> print(output)

    """
    # Transfer and cast in one step, then reorder NHWC -> NCHW.
    patches = batch_data.to(device=device, dtype=torch.float32)
    patches = patches.permute(0, 3, 1, 2).contiguous()

    # Evaluation mode, with gradient tracking disabled (not training).
    model.eval()
    with torch.inference_mode():
        prediction = model(patches)

    # The model output is expected to be a single tensor or scalar.
    return prediction.cpu().numpy()


class CNNModel(ModelABC):
    """Retrieve the model backbone and attach an extra FCN to perform classification.

    This class initializes a Convolutional Neural Network (CNN) model with a
    specified backbone and attaches a fully connected layer for classification
    tasks.

    Args:
        backbone (str):
            Name of the CNN model backbone (e.g., "resnet18", "densenet121").
        num_classes (int):
            Number of classes output by model. Defaults to 1.

    Attributes:
        num_classes (int):
            Number of classes output by the model.
        feat_extract (nn.Module):
            Backbone CNN model.
        pool (nn.Module):
            Type of pooling applied after feature extraction.
        classifier (nn.Module):
            Linear classifier module used to map the features to the output.

    Example:
        >>> model = CNNModel("resnet18", num_classes=2)
        >>> output = model(torch.randn(1, 3, 224, 224))
        >>> print(output.shape)

    """

    def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:
        """Initialize :class:`CNNModel`."""
        super().__init__()
        self.num_classes = num_classes
        # By default pretrained weights are not downloaded
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Best way to retrieve channel dynamically is passing a small forward pass.
        # Probe in evaluation mode and without autograd so that the random input
        # neither updates normalisation running statistics nor builds a graph;
        # the previous training/eval mode is restored afterwards.
        was_training = self.feat_extract.training
        self.feat_extract.eval()
        with torch.no_grad():
            prev_num_ch = self.feat_extract(torch.rand([2, 3, 96, 96])).shape[1]
        self.feat_extract.train(was_training)

        self.classifier = nn.Linear(prev_num_ch, num_classes)

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        Returns:
            torch.Tensor:
                Class probabilities, i.e. the classifier output after a
                softmax over the last axis.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        gap_feat = torch.flatten(gap_feat, 1)
        logit = self.classifier(gap_feat)
        return torch.softmax(logit, -1)

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model.

        This simply applies argmax along last axis of the input.

        Args:
            image (np.ndarray):
                The input image array.

        Returns:
            np.ndarray:
                The post-processed image array.

        """
        return argmax_last_axis(image=image)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        device: str = "cpu",
    ) -> np.ndarray:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.
        The work is delegated to the module-level `_infer_batch`.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Device the batch is transferred to. Default is "cpu".

        Returns:
            np.ndarray:
                The model predictions as a NumPy array.

        Example:
            >>> output = CNNModel.infer_batch(model, batch_data, "cuda")
            >>> print(output)

        """
        return _infer_batch(model=model, batch_data=batch_data, device=device)
class TimmModel(ModelABC):
    """Retrieve the tile encoder from timm.

    This is a wrapper for pretrained models within timm.

    Args:
        backbone (str):
            Model name. Currently, the tool supports following
            model names and their default associated weights from timm.

            - "efficientnet_b{i}" for i in [0, 1, ..., 7]
            - "UNI"
            - "prov-gigapath"
            - "UNI2"
            - "Virchow"
            - "Virchow2"
            - "kaiko"
            - "H-optimus-0"
            - "H-optimus-1"
            - "H0-mini"

        num_classes (int):
            Number of classes output by model.
        pretrained (bool, keyword-only):
            Whether to load pretrained weights.

    Attributes:
        num_classes (int):
            Number of classes output by the model.
        pretrained (bool):
            Whether to load pretrained weights.
        feat_extract (nn.Module):
            Backbone Timm model.
        classifier (nn.Module):
            Linear classifier module used to map the features to the output.

    Example:
        >>> model = TimmModel("UNI", pretrained=True)
        >>> output = model(torch.randn(1, 3, 224, 224))
        >>> print(output.shape)

    """

    def __init__(
        self: TimmModel,
        backbone: str,
        num_classes: int = 1,
        *,
        pretrained: bool,
    ) -> None:
        """Initialize :class:`TimmModel`."""
        super().__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        self.feat_extract = _get_timm_architecture(
            arch_name=backbone, pretrained=pretrained
        )

        # Best way to retrieve channel dynamically is passing a small forward pass.
        # Probe in evaluation mode and without autograd: in training mode the
        # random input would update normalisation running statistics (corrupting
        # pretrained ones) and an unnecessary autograd graph would be built.
        # The previous training/eval mode is restored afterwards.
        was_training = self.feat_extract.training
        self.feat_extract.eval()
        with torch.no_grad():
            prev_num_ch = self.feat_extract(torch.rand([2, 3, 224, 224])).shape[1]
        self.feat_extract.train(was_training)

        self.classifier = nn.Linear(prev_num_ch, num_classes)

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        Returns:
            torch.Tensor:
                Class probabilities, i.e. the classifier output after a
                softmax over the last axis.

        """
        feat = self.feat_extract(imgs)
        feat = torch.flatten(feat, 1)
        logit = self.classifier(feat)
        return torch.softmax(logit, -1)

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model.

        This simply applies argmax along last axis of the input.

        Args:
            image (np.ndarray):
                The input image array.

        Returns:
            np.ndarray:
                The post-processed image array.

        """
        return argmax_last_axis(image=image)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        device: str,
    ) -> np.ndarray:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.
        The work is delegated to the module-level `_infer_batch`.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Device the batch is transferred to.

        Returns:
            np.ndarray:
                The model predictions as a NumPy array.

        Example:
            >>> output = TimmModel.infer_batch(model, batch_data, "cuda")
            >>> print(output)

        """
        return _infer_batch(model=model, batch_data=batch_data, device=device)
class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch.

    Args:
        backbone (str):
            Model name. Currently, the tool supports following
            model names and their default associated weights from pytorch.

            - "alexnet"
            - "resnet18"
            - "resnet34"
            - "resnet50"
            - "resnet101"
            - "resnext50_32x4d"
            - "resnext101_32x8d"
            - "wide_resnet50_2"
            - "wide_resnet101_2"
            - "densenet121"
            - "densenet161"
            - "densenet169"
            - "densenet201"
            - "inception_v3"
            - "googlenet"
            - "mobilenet_v2"
            - "mobilenet_v3_large"
            - "mobilenet_v3_small"

    Attributes:
        feat_extract (nn.Module):
            Backbone CNN model.
        pool (nn.Module):
            Type of pooling applied after feature extraction.

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`."""
        super().__init__()
        # No `weights` argument is passed, so pretrained weights are not downloaded.
        self.feat_extract = _get_architecture(backbone)
        # Global average pooling down to a 1x1 spatial map.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        Returns:
            torch.Tensor:
                The extracted features, pooled and flattened from
                dimension 1 onwards.

        """
        pooled = self.pool(self.feat_extract(imgs))
        return torch.flatten(pooled, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        device: str,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.
        The work is delegated to the module-level `_infer_batch` and its
        result is wrapped in a single-element list.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Device the batch is transferred to.

        Returns:
            list[np.ndarray]:
                A one-element list holding the model output as a NumPy array.

        Example:
            >>> output = CNNBackbone.infer_batch(model, batch_data, "cuda")
            >>> print(output)

        """
        features = _infer_batch(model=model, batch_data=batch_data, device=device)
        return [features]
class TimmBackbone(ModelABC):
    """Retrieve tile encoders from timm.

    This is a wrapper for pretrained models within timm.

    Args:
        backbone (str):
            Model name. Supported model names include:

            - "efficientnet_b{i}" for i in [0, 1, ..., 7]
            - "UNI"
            - "prov-gigapath"
            - "UNI2"
            - "Virchow"
            - "Virchow2"
            - "kaiko"
            - "H-optimus-0"
            - "H-optimus-1"
            - "H0-mini"

        pretrained (bool, keyword-only):
            Whether to load pretrained weights.

    Attributes:
        feat_extract (nn.Module):
            Backbone timm model.

    Examples:
        >>> # Creating UNI tile encoder
        >>> model = TimmBackbone(backbone="UNI", pretrained=True)
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 224, 224)
        >>> features = model(samples)
        >>> features.shape  # feature vector
        torch.Size([4, 1024])

    """

    def __init__(self: TimmBackbone, backbone: str, *, pretrained: bool) -> None:
        """Initialize :class:`TimmBackbone`."""
        super().__init__()
        self.pretrained = pretrained
        self.feat_extract = _get_timm_architecture(
            arch_name=backbone,
            pretrained=pretrained,
        )

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        Returns:
            torch.Tensor:
                The extracted features, flattened from dimension 1 onwards.

        """
        return torch.flatten(self.feat_extract(imgs), 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        device: str,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.
        The work is delegated to the module-level `_infer_batch` and its
        result is wrapped in a single-element list.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Device the batch is transferred to.

        Returns:
            list[np.ndarray]:
                A one-element list holding the model output as a NumPy array.

        Example:
            >>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
            >>> print(output)

        """
        features = _infer_batch(model=model, batch_data=batch_data, device=device)
        return [features]