EfficientNetBaseEncoder¶
- class EfficientNetBaseEncoder(stage_idxs, out_channels, depth=5, output_stride=32, **kwargs)[source]¶
Base class for EfficientNet encoder.
Combines EfficientNet backbone from timm with encoder-specific functionality for feature extraction in segmentation and classification tasks.
- Features:
Supports configurable depth and output stride.
Provides intermediate feature maps for multi-scale processing.
Removes classifier for encoder-only usage.
- Raises:
ValueError – If depth is not in range [1, 5].
Example
>>> encoder = EfficientNetBaseEncoder(
...     stage_idxs=[2, 3, 5],
...     out_channels=[3, 32, 24, 40, 112, 320],
...     depth=5,
...     output_stride=32,
... )
>>> x = torch.randn(1, 3, 224, 224)
>>> features = encoder(x)
>>> [f.shape for f in features]
[torch.Size([1, 3, 224, 224]), torch.Size([1, 32, 112, 112]), ...]
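The example above pairs `out_channels` with `depth`: an encoder of depth d exposes d + 1 feature maps, one per entry of `out_channels` up to that depth. A minimal sketch of this relationship (`channels_for_depth` is a hypothetical helper, not part of the API):

```python
def channels_for_depth(out_channels, depth):
    """Channels of the feature maps an encoder of the given depth returns.

    Hypothetical illustration: the example above shows depth 5 producing
    six feature maps, one per entry of out_channels, so depth d keeps
    the first d + 1 entries.
    """
    return out_channels[: depth + 1]


print(channels_for_depth([3, 32, 24, 40, 112, 320], 3))  # [3, 32, 24, 40]
```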
Initialize EfficientNetBaseEncoder.
- Parameters:
stage_idxs (list[int]) – Indices of stages for feature extraction.
out_channels (list[int]) – Output channels for each depth level.
depth (int) – Encoder depth (1-5). Defaults to 5.
output_stride (int) – Output stride of encoder. Defaults to 32.
**kwargs (dict[str, Any]) – Additional keyword arguments for EfficientNet initialization.
- Raises:
ValueError – If depth is not in range [1, 5].
Methods
Forward pass through EfficientNet encoder.
Return encoder stages for dilation modification.
Load state dictionary, excluding classifier weights.
Attributes
training
- forward(x)[source]¶
Forward pass through EfficientNet encoder.
Extracts feature maps from multiple stages of the encoder for use in decoder networks or multi-scale processing.
- Parameters:
x (torch.Tensor) – Input tensor of shape (N, C, H, W).
- Returns:
List of feature maps from different encoder depths.
- Return type:
list[torch.Tensor]
Example
>>> x = torch.randn(1, 3, 224, 224)
>>> features = encoder(x)
>>> len(features)
6
- get_stages()[source]¶
Return encoder stages for dilation modification.
Provides mapping of output strides to corresponding module sequences, enabling conversion to dilated versions for segmentation tasks.
- Returns:
Dictionary mapping output stride to module sequences.
- Return type:
dict[int, Sequence[torch.nn.Module]]
Example
>>> stages = encoder.get_stages()
>>> print(stages.keys())
dict_keys([16, 32])
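The output-stride bookkeeping behind get_stages can be illustrated with plain arithmetic, under the assumption that each listed stage normally downsamples by 2 and that converting a stage to its dilated version removes its downsampling (`effective_output_stride` is a hypothetical helper for illustration):

```python
from functools import reduce
from operator import mul


def effective_output_stride(stage_strides, dilated_stages=()):
    """Product of per-stage strides, treating dilated stages as stride 1.

    Hypothetical sketch: get_stages() maps output strides (e.g. 16 and 32)
    to the module sequences that would be dilated to reach a smaller
    overall stride.
    """
    strides = [1 if i in dilated_stages else s
               for i, s in enumerate(stage_strides)]
    return reduce(mul, strides, 1)


# Five stride-2 stages give the default output stride of 32.
print(effective_output_stride([2, 2, 2, 2, 2]))                      # 32
# Dilating the last stage instead of striding yields output stride 16.
print(effective_output_stride([2, 2, 2, 2, 2], dilated_stages={4}))  # 16
```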
- load_state_dict(state_dict, **kwargs)[source]¶
Load state dictionary, excluding classifier weights.
Removes classifier weights from the state dictionary before loading, as the encoder does not include a classification head.
- Parameters:
state_dict (dict[str, Any]) – State dictionary to load; may contain classifier weights, which are removed before loading.
**kwargs – Additional keyword arguments passed to the parent class load_state_dict method.
- Returns:
Result of parent class load_state_dict method.
- Return type:
torch.nn.modules.module._IncompatibleKeys
Example
>>> encoder.load_state_dict(torch.load("efficientnet_weights.pth"))
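The classifier-stripping step can be sketched as plain dictionary filtering. The key prefix below is an assumption for illustration; the actual classifier key names depend on the timm EfficientNet implementation:

```python
def strip_classifier_keys(state_dict, prefixes=("classifier.",)):
    # Drop entries whose key starts with a classifier prefix (assumed
    # naming), since the encoder has no classification head to load
    # those weights into.
    return {k: v for k, v in state_dict.items()
            if not any(k.startswith(p) for p in prefixes)}


weights = {"conv_stem.weight": 1, "classifier.weight": 2, "classifier.bias": 3}
print(sorted(strip_classifier_keys(weights)))  # ['conv_stem.weight']
```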