EfficientNetBaseEncoder¶

class EfficientNetBaseEncoder(stage_idxs, out_channels, depth=5, output_stride=32, **kwargs)[source]¶

Base class for EfficientNet encoder.

Combines EfficientNet backbone from timm with encoder-specific functionality for feature extraction in segmentation and classification tasks.

Features:
  • Supports configurable depth and output stride.

  • Provides intermediate feature maps for multi-scale processing.

  • Removes classifier for encoder-only usage.

Raises:

ValueError – If depth is not in range [1, 5].
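The depth check described above can be sketched as a standalone helper (hypothetical, not part of the class API):

```python
def validate_depth(depth: int) -> int:
    """Raise ValueError when depth falls outside the documented [1, 5] range."""
    if not 1 <= depth <= 5:
        raise ValueError(f"depth must be in range [1, 5], got {depth}")
    return depth
```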

Example

>>> encoder = EfficientNetBaseEncoder(
...     stage_idxs=[2, 3, 5],
...     out_channels=[3, 32, 24, 40, 112, 320],
...     depth=5,
...     output_stride=32
... )
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> features = encoder(x)
>>> [f.shape for f in features]
[torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...]
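The shapes in the example follow a simple pattern: each successive stage roughly halves the spatial resolution, and the encoder returns one feature map per depth level plus the stride-1 input. A sketch of that arithmetic, assuming clean power-of-two halving (true for the 224×224 example; the helper is illustrative only):

```python
def expected_feature_shapes(out_channels, depth=5, size=224):
    """(channels, H, W) per returned feature map, assuming each stage
    halves the spatial size — the pattern shown in the example above."""
    return [(c, size // 2**i, size // 2**i)
            for i, c in enumerate(out_channels[: depth + 1])]
```

For the `out_channels` list in the example, this yields shapes from (3, 224, 224) at stride 1 down to (320, 7, 7) at stride 32.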

Initialize EfficientNetBaseEncoder.

Parameters:
  • stage_idxs (list[int]) – Indices of stages for feature extraction.

  • out_channels (list[int]) – Output channels for each depth level.

  • depth (int) – Encoder depth (1-5). Defaults to 5.

  • output_stride (int) – Output stride of encoder. Defaults to 32.

  • **kwargs (dict[str, Any]) – Additional keyword arguments for EfficientNet initialization.

Raises:

ValueError – If depth is not in range [1, 5].

Methods

forward – Forward pass through EfficientNet encoder.

get_stages – Return encoder stages for dilation modification.

load_state_dict – Load state dictionary, excluding classifier weights.

Attributes

training

forward(x)[source]¶

Forward pass through EfficientNet encoder.

Extracts feature maps from multiple stages of the encoder for use in decoder networks or multi-scale processing.

Parameters:

x (torch.Tensor) – Input tensor of shape (N, C, H, W).

Returns:

List of feature maps from different encoder depths.

Return type:

list[torch.Tensor]

Example

>>> x = torch.randn(1, 3, 224, 224)
>>> features = encoder(x)
>>> len(features)
6
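The stage-by-stage collection pattern can be illustrated with a toy stand-in (module layout and channel counts here are illustrative, not the real internals):

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Minimal stand-in for the forward pass described above: run each
    stage in order and keep every intermediate output."""
    def __init__(self):
        super().__init__()
        self.stages = nn.ModuleList([
            nn.Identity(),                               # stride 1
            nn.Conv2d(3, 32, 3, stride=2, padding=1),    # stride 2
            nn.Conv2d(32, 24, 3, stride=2, padding=1),   # stride 4
        ])

    def forward(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features
```

Each entry in the returned list corresponds to one encoder depth, with spatial resolution halving at every strided stage.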
get_stages()[source]¶

Return encoder stages for dilation modification.

Provides mapping of output strides to corresponding module sequences, enabling conversion to dilated versions for segmentation tasks.

Returns:

Dictionary mapping output stride to module sequences.

Return type:

dict[int, Sequence[torch.nn.Module]]

Example

>>> stages = encoder.get_stages()
>>> print(stages.keys())
dict_keys([16, 32])
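What `get_stages()` enables can be sketched with a small helper that turns a stride-2 convolution into a stride-1 dilated one — the usual trick for raising encoder output resolution in segmentation. The helper is illustrative, not the library's API:

```python
import torch
import torch.nn as nn

def dilate_conv(conv: nn.Conv2d, dilation: int) -> None:
    """Replace striding with dilation in place; padding is grown so the
    output spatial size is preserved for odd kernel sizes."""
    conv.stride = (1, 1)
    conv.dilation = (dilation, dilation)
    kh, kw = conv.kernel_size
    conv.padding = ((kh // 2) * dilation, (kw // 2) * dilation)
```

Applied to the modules mapped under output stride 32, a transformation like this converts a stride-32 encoder into a stride-16 one without retraining.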
load_state_dict(state_dict, **kwargs)[source]¶

Load state dictionary, excluding classifier weights.

Removes classifier weights from the state dictionary before loading, as the encoder does not include a classification head.

Parameters:
  • state_dict (Mapping[str, Any]) – State dictionary to load.

  • **kwargs (bool) – Additional keyword arguments forwarded to the parent torch.nn.Module.load_state_dict (e.g. strict).

Returns:

Result of parent class load_state_dict method.

Return type:

torch.nn.modules.module._IncompatibleKeys

Example

>>> encoder.load_state_dict(torch.load("efficientnet_weights.pth"))
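The classifier-stripping step can be sketched as a plain dictionary filter (the exact key prefix removed by the real implementation is an assumption here):

```python
def strip_classifier(state_dict):
    """Drop classification-head entries so a checkpoint trained with a
    classifier loads cleanly into the headless encoder."""
    return {k: v for k, v in state_dict.items()
            if not k.startswith("classifier.")}
```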