UnetPlusPlusDecoder
- class UnetPlusPlusDecoder(encoder_channels, decoder_channels, n_blocks=5)
UNet++ decoder with dense skip connections.
This class implements the decoder portion of the UNet++ architecture. It reconstructs high-resolution feature maps from encoder outputs using multiple decoder blocks and dense connections between intermediate layers.
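The dense skip connections can be illustrated with the nested-node formulation from the UNet++ paper. Node X^{i,j} sits at depth i and layer j. It concatenates every earlier node at the same depth, X^{i,0} through X^{i,j-1}, with the upsampled output of X^{i+1,j-1}. The pure-Python sketch below only enumerates that wiring and builds no layers. It assumes this class follows the paper's grid, and the helper names `node_inputs` and `decoder_grid` are hypothetical, not part of this class's API.

```python
def node_inputs(i: int, j: int) -> list[tuple[int, int]]:
    """Return the (depth, layer) nodes whose outputs feed node X^{i,j}.

    Follows the UNet++ paper's nested-node formulation (an assumption about
    this class): encoder nodes (j == 0) read the next-shallower encoder node,
    and decoder nodes (j > 0) concatenate all same-depth predecessors with
    the upsampled output of the node one level deeper.
    """
    if j == 0:
        # Encoder backbone: X^{0,0} reads the input image directly;
        # X^{i,0} for i > 0 is computed from the downsampled X^{i-1,0}.
        return [] if i == 0 else [(i - 1, 0)]
    same_depth = [(i, k) for k in range(j)]  # dense skips X^{i,0} .. X^{i,j-1}
    from_below = [(i + 1, j - 1)]            # upsampled X^{i+1,j-1}
    return same_depth + from_below


def decoder_grid(depth: int) -> dict[tuple[int, int], list[tuple[int, int]]]:
    """Map every decoder node (i, j), with j > 0 and i + j <= depth, to its inputs."""
    return {
        (i, j): node_inputs(i, j)
        for j in range(1, depth + 1)
        for i in range(depth - j + 1)
    }


grid = decoder_grid(4)
print(len(grid))     # 10 decoder nodes for four downsampling stages
print(grid[(0, 4)])  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3)]
```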
- Raises:
ValueError – If the number of decoder blocks does not match the length of decoder_channels.
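- Parameters:
encoder_channels – Number of channels in each encoder feature map, ordered from shallow to deep, e.g. [3, 32, 64, 128, 256, 512].
decoder_channels – Number of output channels for each decoder stage, e.g. [256, 128, 64, 32, 16].
n_blocks – Number of decoder blocks; must equal len(decoder_channels). Defaults to 5.
- Attributes:
- blocks
Dictionary of decoder blocks, keyed by depth and layer index.
- Type:
nn.ModuleDict
- center
Center block (currently nn.Identity).
- Type:
nn.Module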
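Because nn.ModuleDict only accepts string keys, and module names may not contain ".", a dictionary organized by depth and layer index has to encode each (depth, layer) pair as a name. The sketch below assumes a hypothetical `x_{depth}_{layer}` naming scheme purely for illustration; the keys this class actually uses are not documented here. A plain dict stands in for nn.ModuleDict so the sketch runs without PyTorch.

```python
def block_key(depth: int, layer: int) -> str:
    """Encode a (depth, layer) grid position as a string key.

    Hypothetical naming scheme: nn.ModuleDict keys must be strings and may not
    contain '.', so underscores separate the two indices.
    """
    return f"x_{depth}_{layer}"


def parse_block_key(key: str) -> tuple[int, int]:
    """Invert block_key, e.g. 'x_2_1' -> (2, 1)."""
    prefix, depth, layer = key.split("_")
    if prefix != "x":
        raise ValueError(f"unexpected block key: {key!r}")
    return int(depth), int(layer)


# One key per decoder node (i, j) with j > 0 and i + j <= depth; the values
# are placeholder strings where the real decoder would hold nn.Module blocks.
depth = 4
blocks = {
    block_key(i, j): f"block({i}, {j})"
    for j in range(1, depth + 1)
    for i in range(depth - j + 1)
}
print(list(blocks)[:3])          # ['x_0_1', 'x_1_1', 'x_2_1']
print(parse_block_key("x_2_1"))  # (2, 1)
```

With PyTorch installed, the same strings could be passed directly as keys to nn.ModuleDict.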
Example
>>> import torch
>>> decoder = UnetPlusPlusDecoder(
...     encoder_channels=[3, 32, 64, 128, 256, 512],
...     decoder_channels=[256, 128, 64, 32, 16],
...     n_blocks=5,
... )
>>> # Generate dummy feature maps for testing
>>> features = [
...     torch.randn(1, c, 64 // (2**i), 64 // (2**i))
...     for i, c in enumerate([3, 32, 64, 128, 256, 512])
... ]
>>> output = decoder(features)
>>> output.shape
torch.Size([1, 16, 64, 64])
Initialize UnetPlusPlusDecoder.
Sets up the decoder blocks and dense connections for the UNet++ architecture.
- Parameters:
encoder_channels, decoder_channels, n_blocks – Same as in the class signature.
- Raises:
ValueError – If n_blocks does not match the length of decoder_channels.
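The check described under Raises can be sketched as a small standalone function. Where and how the class performs this check is an assumption, and `validate_decoder_config` is a hypothetical name, not part of the class.

```python
from typing import Sequence


def validate_decoder_config(decoder_channels: Sequence[int], n_blocks: int = 5) -> None:
    """Raise ValueError if n_blocks disagrees with len(decoder_channels).

    Hypothetical helper mirroring the documented Raises clause.
    """
    if n_blocks != len(decoder_channels):
        raise ValueError(
            f"n_blocks={n_blocks} does not match "
            f"len(decoder_channels)={len(decoder_channels)}"
        )


# Five channel entries with the default n_blocks=5: passes silently.
validate_decoder_config([256, 128, 64, 32, 16], n_blocks=5)

try:
    validate_decoder_config([256, 128, 64], n_blocks=5)
except ValueError as err:
    print(err)  # n_blocks=5 does not match len(decoder_channels)=3
```

With the default n_blocks=5, decoder_channels therefore needs exactly five entries.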
Methods
forward(features) – Forward pass through the UNet++ decoder.
Attributes
training
- forward(features)
Forward pass through the UNet++ decoder.
Reconstructs high-resolution feature maps from encoder outputs using dense skip connections and multiple decoder blocks.
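Because features is ordered from shallow to deep, a decoder of this kind typically starts from the deepest (lowest-resolution) map and works back toward the shallowest. The sketch below checks that input contract on shape tuples alone. It assumes that the channel counts must equal encoder_channels and that each deeper map halves the height and width, as in the class-level example. `check_features` is a hypothetical helper, not part of this class.

```python
from typing import Sequence

Shape = tuple[int, int, int, int]  # (batch, channels, height, width)


def check_features(shapes: Sequence[Shape], encoder_channels: Sequence[int]) -> list[Shape]:
    """Validate shallow-to-deep feature shapes and return them deepest first.

    Assumptions (not confirmed by the docs): channel counts must equal
    encoder_channels, and each deeper map halves the height and width.
    """
    if [s[1] for s in shapes] != list(encoder_channels):
        raise ValueError("feature channels do not match encoder_channels")
    for shallow, deep in zip(shapes, shapes[1:]):
        if (deep[2] * 2, deep[3] * 2) != (shallow[2], shallow[3]):
            raise ValueError(f"expected {deep} to be half the size of {shallow}")
    # Decoding starts from the deepest (lowest-resolution) map.
    return list(reversed(shapes))


encoder_channels = [3, 32, 64, 128, 256, 512]
# Same shapes as the class-level example: 64x64 input, halved at each stage.
shapes = [(1, c, 64 // 2**i, 64 // 2**i) for i, c in enumerate(encoder_channels)]
print(check_features(shapes, encoder_channels)[0])  # (1, 512, 2, 2)
```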
- Parameters:
features (list[torch.Tensor]) – List of feature maps from the encoder, ordered from shallow to deep.
- Returns:
Decoded output tensor with spatial resolution restored.
- Return type: