# Source code for tiatoolbox.models.architecture.timm_efficientnet

"""EfficientNet Encoder Implementation using timm.

This module provides an implementation of EfficientNet-based encoders for use in
semantic segmentation and other computer vision tasks. It leverages the `timm`
library for model components and adds encoder-specific functionality such as
custom input channels, dilation support, and configurable scaling parameters.

Key Components:
---------------
- patch_first_conv:
    Utility to modify the first convolution layer for arbitrary input channels.
- replace_strides_with_dilation:
    Utility to convert strides into dilations for atrous convolutions.
- EncoderMixin:
    Mixin class adding encoder-specific features like output channels and stride.
- EfficientNetBaseEncoder:
    Base encoder combining EfficientNet backbone with encoder functionality.
- EfficientNetEncoder:
    Configurable EfficientNet encoder supporting depth and channel scaling.
- timm_efficientnet_encoders:
    Dictionary of available EfficientNet encoder configurations and pretrained settings.

Features:
---------
- Supports arbitrary input channels (e.g., grayscale or multi-channel images).
- Allows conversion to dilated versions for semantic segmentation.
- Provides pretrained weights from multiple sources (ImageNet, AdvProp, Noisy Student).
- Implements scaling rules for EfficientNet architecture.

Example:
    >>> from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder
    >>> encoder = EfficientNetEncoder(
    ...     stage_idxs=[2, 3, 5],
    ...     out_channels=[3, 32, 24, 40, 112, 320],
    ...     channel_multiplier=1.0,
    ...     depth_multiplier=1.0,
    ...     drop_rate=0.2
    ... )
    >>> x = torch.randn(1, 3, 224, 224)
    >>> features = encoder(x)
    >>> [f.shape for f in features]
    [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...]

References:
    - Tan, Mingxing, and Quoc V. Le. "EfficientNet: Rethinking Model Scaling for
      Convolutional Neural Networks." arXiv preprint arXiv:1905.11946 (2019).
      URL: https://arxiv.org/abs/1905.11946

"""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Mapping, Sequence

import torch
from timm.layers.activations import Swish
from timm.models._efficientnet_builder import decode_arch_def, round_channels
from timm.models.efficientnet import EfficientNet
from torch import nn

# Valid encoder depth range: EfficientNet encoders expose at most 5
# downsampling stages (see EfficientNetBaseEncoder.forward).
MAX_DEPTH = 5
MIN_DEPTH = 1
# Standard RGB input; used to detect when the first conv must be patched
# for a non-RGB channel count (see EncoderMixin.set_in_channels).
DEFAULT_IN_CHANNELS = 3


def patch_first_conv(
    model: nn.Module,
    new_in_channels: int,
    default_in_channels: int = 3,
    *,
    pretrained: bool = True,
) -> None:
    """Adapt the model's first convolution to a new input channel count.

    Locates the first ``nn.Conv2d`` whose ``in_channels`` equals
    ``default_in_channels`` and rewrites its weight tensor so the model
    accepts ``new_in_channels`` input channels. Pretrained weights are
    reused where possible; otherwise the layer is re-initialized.

    Args:
        model (nn.Module): Model whose first convolution will be patched.
        new_in_channels (int): Desired number of input channels.
        default_in_channels (int): Channel count identifying the layer to
            patch. Defaults to 3.
        pretrained (bool): Reuse pretrained weights when True; otherwise
            re-initialize the layer randomly. Defaults to True.

    Notes:
        - ``new_in_channels == 1``: the RGB filters are summed into one.
        - ``new_in_channels > 1``: the original filters are tiled modulo
          ``default_in_channels`` and rescaled to preserve activation
          magnitude.

    Example:
        >>> patch_first_conv(model, new_in_channels=1, pretrained=True)

    """
    # Find the first matching convolution; nothing to do if absent.
    first_conv = next(
        (
            layer
            for layer in model.modules()
            if isinstance(layer, nn.Conv2d)
            and layer.in_channels == default_in_channels
        ),
        None,
    )
    if first_conv is None:
        return

    original_weight = first_conv.weight.detach()
    first_conv.in_channels = new_in_channels

    if not pretrained:
        # Discard pretrained filters: allocate a fresh weight tensor of the
        # right shape and let the layer re-initialize itself.
        blank = torch.Tensor(
            first_conv.out_channels,
            new_in_channels // first_conv.groups,
            *first_conv.kernel_size,
        )
        first_conv.weight = nn.parameter.Parameter(blank)
        first_conv.reset_parameters()
        return

    if new_in_channels == 1:
        # Grayscale: collapse the RGB filters into a single channel.
        first_conv.weight = nn.parameter.Parameter(
            original_weight.sum(1, keepdim=True)
        )
        return

    # Multi-channel: tile the original filters over the new channels and
    # rescale so the expected output magnitude is unchanged.
    tiled = torch.Tensor(
        first_conv.out_channels,
        new_in_channels // first_conv.groups,
        *first_conv.kernel_size,
    )
    for channel in range(new_in_channels):
        tiled[:, channel] = original_weight[:, channel % default_in_channels]
    first_conv.weight = nn.parameter.Parameter(
        tiled * (default_in_channels / new_in_channels)
    )
def replace_strides_with_dilation(module: nn.Module, dilation_rate: int) -> None:
    """Replace strides with dilation in Conv2d layers.

    Converts convolutional layers to use dilation instead of stride,
    enabling atrous convolutions for semantic segmentation tasks.

    Args:
        module (nn.Module): Module containing Conv2d layers to patch.
        dilation_rate (int): Dilation rate to apply to all Conv2d layers.

    Example:
        >>> replace_strides_with_dilation(model, dilation_rate=2)

    """
    for mod in module.modules():
        if isinstance(mod, nn.Conv2d):
            mod.stride = (1, 1)
            mod.dilation = (dilation_rate, dilation_rate)
            # Derive "same" padding per axis. The previous implementation
            # used the kernel height for both axes, which produced wrong
            # padding for non-square kernels.
            kh, kw = mod.kernel_size
            mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)
            # Workaround for EfficientNet: disable its static "same"-padding
            # wrapper so the explicit padding above takes effect.
            if hasattr(mod, "static_padding"):
                mod.static_padding = nn.Identity()  # type: ignore[attr-defined]
class EncoderMixin:
    """Mixin adding encoder-specific behaviour to a backbone.

    Provides:
    - Output-channel bookkeeping for encoder feature maps.
    - First-convolution patching for arbitrary input channel counts.
    - Conversion to a dilated (atrous) variant for segmentation tasks.

    Attributes:
        _depth (int): Encoder depth (number of stages).
        _in_channels (int): Number of input channels.
        _output_stride (int): Output stride of the encoder.
        _out_channels (list[int]): Output channel dimensions per depth
            level.

    Example:
        >>> encoder = EncoderMixin()
        >>> encoder.set_in_channels(1)

    """

    _is_torch_scriptable = True
    _is_torch_exportable = True
    _is_torch_compilable = True

    def __init__(self) -> None:
        """Initialize the mixin with encoder defaults.

        Defaults: depth 5, RGB input (3 channels), output stride 32, and
        an empty output-channel list.
        """
        self._depth = 5
        self._in_channels = 3
        self._output_stride = 32
        self._out_channels: list[int] = []

    @property
    def out_channels(self) -> list[int]:
        """Output channel dimensions of the encoder feature maps.

        Returns:
            list[int]: One entry per depth level, truncated to the
            configured depth.

        Example:
            >>> encoder.out_channels
            ... [3, 32, 64, 128, 256, 512]

        """
        depth_limit = self._depth + 1
        return self._out_channels[:depth_limit]

    @property
    def output_stride(self) -> int:
        """Effective output stride of the encoder.

        The effective stride is capped at 2**depth.

        Returns:
            int: Effective output stride.

        Example:
            >>> encoder.output_stride
            ... 32

        """
        return min(self._output_stride, 2**self._depth)

    def set_in_channels(self, in_channels: int, *, pretrained: bool = True) -> None:
        """Reconfigure the encoder for a different input channel count.

        Args:
            in_channels (int): Number of input channels.
            pretrained (bool): Whether to reuse pretrained weights.
                Defaults to True.

        Example:
            >>> encoder.set_in_channels(1, pretrained=False)

        """
        # RGB input needs no patching.
        if in_channels == DEFAULT_IN_CHANNELS:
            return
        self._in_channels = in_channels
        if self._out_channels[0] == DEFAULT_IN_CHANNELS:
            self._out_channels = [in_channels, *self._out_channels[1:]]
        # Type ignore needed because self is a mixin used with nn.Module.
        patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)  # type: ignore[arg-type]

    def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
        """Return encoder stages for dilation modification.

        Subclasses must override this to map output strides to the module
        sequences whose strides should become dilations.

        Returns:
            dict[int, Sequence[torch.nn.Module]]: Mapping from output
            stride to module sequences.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Example:
            >>> stages = encoder.get_stages()

        """
        raise NotImplementedError

    def make_dilated(self, output_stride: int) -> None:
        """Convert the encoder to a dilated variant for segmentation.

        Args:
            output_stride (int): Target output stride (must be 8 or 16).

        Raises:
            ValueError: If `output_stride` is not 8 or 16.

        Example:
            >>> encoder.make_dilated(output_stride=16)

        """
        if output_stride not in [8, 16]:
            msg = f"Output stride should be 16 or 8, got {output_stride}."
            raise ValueError(msg)
        for stage_stride, stage_modules in self.get_stages().items():
            # Only stages striding beyond the target need dilation.
            if stage_stride > output_stride:
                rate = stage_stride // output_stride
                for submodule in stage_modules:
                    replace_strides_with_dilation(submodule, rate)
def get_efficientnet_kwargs(
    channel_multiplier: float = 1.0,
    depth_multiplier: float = 1.0,
    drop_rate: float = 0.2,
) -> dict[str, Any]:
    """Generate configuration parameters for EfficientNet.

    Reference implementation:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet parameters (width, depth, resolution, dropout):
    - 'efficientnet-b0': (1.0, 1.0, 224, 0.2)
    - 'efficientnet-b1': (1.0, 1.1, 240, 0.2)
    - 'efficientnet-b2': (1.1, 1.2, 260, 0.3)
    - 'efficientnet-b3': (1.2, 1.4, 300, 0.3)
    - 'efficientnet-b4': (1.4, 1.8, 380, 0.4)
    - 'efficientnet-b5': (1.6, 2.2, 456, 0.4)
    - 'efficientnet-b6': (1.8, 2.6, 528, 0.5)
    - 'efficientnet-b7': (2.0, 3.1, 600, 0.5)
    - 'efficientnet-b8': (2.2, 3.6, 672, 0.5)
    - 'efficientnet-l2': (4.3, 5.3, 800, 0.5)

    Args:
        channel_multiplier (float): Multiplier for number of channels per
            layer. Defaults to 1.0.
        depth_multiplier (float): Multiplier for number of repeats per
            stage. Defaults to 1.0.
        drop_rate (float): Dropout rate. Defaults to 0.2.

    Returns:
        dict[str, Any]: Dictionary containing EfficientNet configuration
        parameters.

    Example:
        >>> kwargs = get_efficientnet_kwargs(
        ...     channel_multiplier=1.2,
        ...     depth_multiplier=1.4,
        ... )

    """
    # EfficientNet-B0 architecture definition in timm's string notation;
    # depth_multiplier scales the per-stage repeat counts.
    arch_def = [
        ["ds_r1_k3_s1_e1_c16_se0.25"],
        ["ir_r2_k3_s2_e6_c24_se0.25"],
        ["ir_r2_k5_s2_e6_c40_se0.25"],
        ["ir_r3_k3_s2_e6_c80_se0.25"],
        ["ir_r3_k5_s1_e6_c112_se0.25"],
        ["ir_r4_k5_s2_e6_c192_se0.25"],
        ["ir_r1_k3_s1_e6_c320_se0.25"],
    ]
    return {
        "block_args": decode_arch_def(arch_def, depth_multiplier),
        "num_features": round_channels(1280, channel_multiplier, 8, None),
        "stem_size": 32,
        "round_chs_fn": partial(round_channels, multiplier=channel_multiplier),
        "act_layer": Swish,
        "drop_rate": drop_rate,
        "drop_path_rate": 0.2,
    }
class EfficientNetBaseEncoder(EfficientNet, EncoderMixin):
    """Base class for EfficientNet encoder.

    Combines EfficientNet backbone from `timm` with encoder-specific
    functionality for feature extraction in segmentation and
    classification tasks.

    Features:
    - Supports configurable depth and output stride.
    - Provides intermediate feature maps for multi-scale processing.
    - Removes classifier for encoder-only usage.

    Raises:
        ValueError: If `depth` is not in range [1, 5].

    Example:
        >>> encoder = EfficientNetBaseEncoder(
        ...     stage_idxs=[2, 3, 5],
        ...     out_channels=[3, 32, 24, 40, 112, 320],
        ...     depth=5,
        ...     output_stride=32
        ... )
        >>> x = torch.randn(1, 3, 224, 224)
        >>> features = encoder(x)
        >>> [f.shape for f in features]
        ... [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...]

    """

    def __init__(
        self,
        stage_idxs: list[int],
        out_channels: list[int],
        depth: int = 5,
        output_stride: int = 32,
        **kwargs: dict[str, Any],
    ) -> None:
        """Initialize EfficientNetBaseEncoder.

        Args:
            stage_idxs (list[int]): Indices of stages for feature extraction.
            out_channels (list[int]): Output channels for each depth level.
            depth (int): Encoder depth (1-5). Defaults to 5.
            output_stride (int): Output stride of encoder. Defaults to 32.
            **kwargs (dict[str, Any]): Additional keyword arguments for
                EfficientNet initialization.

        Raises:
            ValueError: If `depth` is not in range [1, 5].

        """
        if depth > MAX_DEPTH or depth < MIN_DEPTH:
            # NOTE: the previous message used a backslash line-continuation
            # inside the f-string literal, which embedded a long run of
            # spaces in the rendered text; build it with implicit
            # concatenation instead.
            msg = (
                f"{self.__class__.__name__} depth should be in range "
                f"[1, 5], got {depth}"
            )
            raise ValueError(msg)
        super().__init__(**kwargs)

        self._stage_idxs = stage_idxs
        self._depth = depth
        self._in_channels = 3
        self._out_channels = out_channels
        self._output_stride = output_stride

        # Encoder-only usage: the classification head is never needed.
        del self.classifier

    def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
        """Return encoder stages for dilation modification.

        Provides mapping of output strides to corresponding module
        sequences, enabling conversion to dilated versions for
        segmentation tasks.

        Returns:
            dict[int, Sequence[torch.nn.Module]]: Dictionary mapping output
            stride to module sequences.

        Example:
            >>> stages = encoder.get_stages()
            >>> print(stages.keys())
            ... dict_keys([16, 32])

        """
        return {
            16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]],  # type: ignore[attr-defined]
            32: [self.blocks[self._stage_idxs[2] :]],  # type: ignore[attr-defined]
        }

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Forward pass through EfficientNet encoder.

        Extracts feature maps from multiple stages of the encoder for use
        in decoder networks or multi-scale processing.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            list[torch.Tensor]: List of feature maps from different encoder
            depths (the input itself is always the first element).

        Example:
            >>> x = torch.randn(1, 3, 224, 224)
            >>> features = encoder(x)
            >>> len(features)
            ... 6

        """
        features = [x]

        # Stage 1: stem convolution + batch norm.
        if self._depth >= 1:
            x = self.conv_stem(x)  # type: ignore[attr-defined]
            x = self.bn1(x)  # type: ignore[attr-defined]
            features.append(x)

        # Stages 2-5: fixed groupings of EfficientNet blocks; each group
        # ends at a resolution change so features are multi-scale.
        if self._depth >= 2:  # noqa: PLR2004
            x = self.blocks[0](x)  # type: ignore[attr-defined]
            x = self.blocks[1](x)  # type: ignore[attr-defined]
            features.append(x)

        if self._depth >= 3:  # noqa: PLR2004
            x = self.blocks[2](x)  # type: ignore[attr-defined]
            features.append(x)

        if self._depth >= 4:  # noqa: PLR2004
            x = self.blocks[3](x)  # type: ignore[attr-defined]
            x = self.blocks[4](x)  # type: ignore[attr-defined]
            features.append(x)

        if self._depth >= 5:  # noqa: PLR2004
            x = self.blocks[5](x)  # type: ignore[attr-defined]
            x = self.blocks[6](x)  # type: ignore[attr-defined]
            features.append(x)

        return features

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        **kwargs: bool,
    ) -> torch.nn.modules.module._IncompatibleKeys:
        """Load state dictionary, excluding classifier weights.

        Removes classifier weights from the state dictionary before
        loading, as the encoder does not include a classification head.

        Args:
            state_dict (Mapping[str, Any]): State dictionary to load.
            **kwargs (bool): Additional keyword arguments for
                `load_state_dict`.

        Returns:
            torch.nn.modules.module._IncompatibleKeys: Result of parent
            class `load_state_dict` method.

        Example:
            >>> encoder.load_state_dict(torch.load("efficientnet_weights.pth"))

        """
        # Create a mutable copy of the state dict to modify; the classifier
        # was deleted in __init__, so its weights must not be loaded.
        state_dict_copy = dict(state_dict)
        state_dict_copy.pop("classifier.bias", None)
        state_dict_copy.pop("classifier.weight", None)
        return super().load_state_dict(state_dict_copy, **kwargs)
class EfficientNetEncoder(EfficientNetBaseEncoder):
    """EfficientNet encoder with configurable scaling parameters.

    Extends `EfficientNetBaseEncoder` with depth and channel multipliers,
    enabling flexible encoder configurations for segmentation and
    classification tasks.

    Features:
    - Supports depth and channel scaling.
    - Provides pretrained weights for multiple variants.
    - Outputs multi-scale feature maps for downstream tasks.

    Example:
        >>> encoder = EfficientNetEncoder(
        ...     stage_idxs=[2, 3, 5],
        ...     out_channels=[3, 32, 24, 40, 112, 320],
        ...     channel_multiplier=1.0,
        ...     depth_multiplier=1.0,
        ...     drop_rate=0.2
        ... )
        >>> x = torch.randn(1, 3, 224, 224)
        >>> features = encoder(x)
        >>> [f.shape for f in features]
        ... [torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...]

    """

    def __init__(
        self,
        stage_idxs: list[int],
        out_channels: list[int],
        depth: int = 5,
        channel_multiplier: float = 1.0,
        depth_multiplier: float = 1.0,
        drop_rate: float = 0.2,
        output_stride: int = 32,
    ) -> None:
        """Initialize EfficientNetEncoder.

        Builds the timm EfficientNet configuration from the scaling
        parameters and delegates construction to the base encoder.

        Args:
            stage_idxs (list[int]): Indices of stages for feature extraction.
            out_channels (list[int]): Output channels for each depth level.
            depth (int): Encoder depth (1-5). Defaults to 5.
            channel_multiplier (float): Channel scaling factor. Defaults to 1.0.
            depth_multiplier (float): Depth scaling factor. Defaults to 1.0.
            drop_rate (float): Dropout rate. Defaults to 0.2.
            output_stride (int): Output stride of encoder. Defaults to 32.

        """
        # Translate the scaling knobs into timm's EfficientNet kwargs.
        backbone_kwargs = get_efficientnet_kwargs(
            channel_multiplier,
            depth_multiplier,
            drop_rate,
        )
        super().__init__(
            stage_idxs=stage_idxs,
            depth=depth,
            out_channels=out_channels,
            output_stride=output_stride,
            **backbone_kwargs,
        )
# Registry of available EfficientNet encoder configurations.
# Each entry maps an encoder name to:
#   - "encoder": the encoder class to instantiate,
#   - "pretrained_settings": Hugging Face Hub repo_id/revision pairs for
#     each pretrained weight source (ImageNet, AdvProp, Noisy Student),
#   - "params": constructor arguments for the encoder class.
timm_efficientnet_encoders = {
    "timm-efficientnet-b0": {
        "encoder": EfficientNetEncoder,
        "pretrained_settings": {
            "imagenet": {
                "repo_id": "smp-hub/timm-efficientnet-b0.imagenet",
                "revision": "8419e9cc19da0b68dcd7bb12f19b7c92407ad7c4",
            },
            "advprop": {
                "repo_id": "smp-hub/timm-efficientnet-b0.advprop",
                "revision": "a5870af2d24ce79e0cc7fae2bbd8e0a21fcfa6d8",
            },
            "noisy-student": {
                "repo_id": "smp-hub/timm-efficientnet-b0.noisy-student",
                "revision": "bea8b0ff726a50e48774d2d360c5fb1ac4815836",
            },
        },
        "params": {
            # B0 scaling: no width/depth multiplication, 0.2 dropout.
            "out_channels": [3, 32, 24, 40, 112, 320],
            "stage_idxs": [2, 3, 5],
            "channel_multiplier": 1.0,
            "depth_multiplier": 1.0,
            "drop_rate": 0.2,
        },
    },
}