UnetPlusPlusDecoder

class UnetPlusPlusDecoder(encoder_channels, decoder_channels, n_blocks=5)

UNet++ decoder with dense skip connections.

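This class implements the decoder portion of the UNet++ architecture. It reconstructs high-resolution feature maps from the encoder outputs using multiple decoder blocks joined by nested, dense skip connections: each intermediate decoder node receives the outputs of all preceding nodes at the same resolution together with an upsampled feature map from the level below.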
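In the UNet++ formulation, decoder node X(i, j) (resolution level i, nesting step j ≥ 1) concatenates every earlier node at the same level, X(i, 0) through X(i, j-1), with the upsampled node X(i+1, j-1) from one level deeper. The X(i, 0) nodes are the encoder feature maps. The pure-Python sketch below only enumerates these connections so the topology is visible. The helper name `unetpp_node_inputs` and the X(i, j) labels are illustrative and are not part of this class's API.

```python
def unetpp_node_inputs(levels: int) -> dict[tuple[int, int], list[str]]:
    """Map each decoder node X(i, j), j >= 1, to the inputs it concatenates.

    X(i, j) receives every earlier node at the same level, X(i, 0) ... X(i, j-1),
    plus the upsampled X(i+1, j-1) from one level deeper. X(i, 0) are the
    encoder outputs, so they are not listed as keys here.
    """
    inputs: dict[tuple[int, int], list[str]] = {}
    for j in range(1, levels):          # nesting step (column of the grid)
        for i in range(levels - j):     # resolution level (row of the grid)
            same_level = [f"X({i},{k})" for k in range(j)]
            deeper = f"Up(X({i + 1},{j - 1}))"
            inputs[(i, j)] = same_level + [deeper]
    return inputs


if __name__ == "__main__":
    for node, sources in unetpp_node_inputs(levels=5).items():
        print(f"X{node} <- concat({', '.join(sources)})")
```

With levels=5 this yields 10 decoder nodes; the deepest nesting, X(0, 4), fuses four same-level maps with one upsampled map.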

Raises:

ValueError – If the number of decoder blocks does not match the length of decoder_channels.
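The check behind this error can be sketched as follows. `check_decoder_config` is a hypothetical helper and its message text is an assumption; only the condition itself (n_blocks must equal len(decoder_channels)) comes from this page.

```python
from collections.abc import Sequence


def check_decoder_config(decoder_channels: Sequence[int], n_blocks: int) -> None:
    """Hypothetical stand-in for the constructor's argument check."""
    if n_blocks != len(decoder_channels):
        raise ValueError(
            f"n_blocks is {n_blocks}, but decoder_channels has "
            f"{len(decoder_channels)} entries"
        )


check_decoder_config([256, 128, 64, 32, 16], n_blocks=5)  # passes silently
try:
    check_decoder_config([256, 128, 64], n_blocks=5)
except ValueError as err:
    print(err)  # n_blocks is 5, but decoder_channels has 3 entries
```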

Parameters:
  • encoder_channels (Sequence[int])

  • decoder_channels (Sequence[int])

  • n_blocks (int)

blocks

Dictionary of decoder blocks organized by depth and layer index.

Type:

nn.ModuleDict
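This page does not document the key format of `blocks`. Purely as an illustration, the sketch below assumes keys of the form "x_{depth_idx}_{layer_idx}", a decoder depth of n_blocks - 1, and one extra final block at full resolution with no skip input. The helper `block_keys` and this layout are hypothetical and may not match the real ModuleDict.

```python
def block_keys(n_blocks: int) -> list[str]:
    """Enumerate hypothetical ModuleDict keys of the form "x_{depth}_{layer}".

    Assumes depth = n_blocks - 1, that layer index l in [0, depth) holds
    nodes at depth indices 0..l, and that one final full-resolution block
    "x_0_{depth}" closes the decoder.
    """
    depth = n_blocks - 1
    keys = [
        f"x_{depth_idx}_{layer_idx}"
        for layer_idx in range(depth)
        for depth_idx in range(layer_idx + 1)
    ]
    keys.append(f"x_0_{depth}")  # final block without a skip connection
    return keys


print(len(block_keys(5)))  # 11
print(block_keys(3))       # ['x_0_0', 'x_0_1', 'x_1_1', 'x_0_2']
```

Under these assumptions the block count grows quadratically with depth (depth·(depth+1)/2 nested nodes plus one final block), which is the main memory cost of dense skip pathways compared with a plain U-Net decoder.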

center

Center block (currently nn.Identity, so it passes its input through unchanged).

Type:

nn.Module

depth

Depth of the decoder network.

Type:

int

Example

>>> import torch
>>> decoder = UnetPlusPlusDecoder(
...     encoder_channels=[3, 32, 64, 128, 256, 512],
...     decoder_channels=[256, 128, 64, 32, 16],
...     n_blocks=5
... )
>>> # Generate dummy feature maps for testing
>>> features = [
...     torch.randn(1, c, 64 // (2**i), 64 // (2**i))
...     for i, c in enumerate([3, 32, 64, 128, 256, 512])
... ]
>>> output = decoder(features)
>>> output.shape
torch.Size([1, 16, 64, 64])

Initialize UnetPlusPlusDecoder.

Sets up the decoder blocks and dense connections for the UNet++ architecture.

Parameters:
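  • encoder_channels (Sequence[int]) – Channel counts of the encoder feature maps, ordered from shallow to deep to match the features passed to forward.

  • decoder_channels (Sequence[int]) – Output channel count of each decoder block; its length must equal n_blocks. The last entry (16 in the example above) is the channel count of the returned tensor.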

  • n_blocks (int) – Number of decoder blocks. Defaults to 5.

Raises:

ValueError – If n_blocks does not match the length of decoder_channels.

Methods

forward – Forward pass through the UNet++ decoder.

Attributes

training – Whether the module is in training mode (inherited from torch.nn.Module).

forward(features)

Forward pass through the UNet++ decoder.

Reconstructs high-resolution feature maps from encoder outputs using dense skip connections and multiple decoder blocks.

Parameters:

features (list[torch.Tensor]) – List of feature maps from the encoder, ordered from shallow to deep.

Returns:

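Decoded output tensor with decoder_channels[-1] channels, restored to the spatial resolution of the shallowest feature map (torch.Size([1, 16, 64, 64]) in the class example).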

Return type:

torch.Tensor
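To make the dense skip pathways concrete, here is a self-contained toy decoder in PyTorch. It is a sketch of the general UNet++ idea under simplifying assumptions, not this class's implementation: `ToyNestedDecoder` is a hypothetical name, every node is a single 3x3 convolution plus ReLU, nearest-neighbour upsampling doubles the resolution between levels, the encoder features are used directly as X(i, 0), and every node at level i keeps that level's channel count. A final 1x1 convolution maps the top-level output to the requested width.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyNestedDecoder(nn.Module):
    """Illustrative UNet++-style decoder (hypothetical, not the documented class).

    Node X(i, j), j >= 1, concatenates X(i, 0..j-1) with the 2x-upsampled
    X(i+1, j-1) and applies a 3x3 conv + ReLU. X(i, 0) are the encoder
    features; every node at level i keeps channels[i] channels.
    """

    def __init__(self, channels: list[int], out_channels: int) -> None:
        super().__init__()
        self.levels = len(channels)
        self.nodes = nn.ModuleDict()
        for j in range(1, self.levels):
            for i in range(self.levels - j):
                # j same-level inputs of channels[i] plus one upsampled input.
                in_ch = j * channels[i] + channels[i + 1]
                self.nodes[f"{i}_{j}"] = nn.Sequential(
                    nn.Conv2d(in_ch, channels[i], kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                )
        # 1x1 projection of the top-level output to the requested width.
        self.head = nn.Conv2d(channels[0], out_channels, kernel_size=1)

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        # x[(i, j)] holds node X(i, j); column 0 is the encoder output.
        x = {(i, 0): f for i, f in enumerate(features)}
        for j in range(1, self.levels):
            for i in range(self.levels - j):
                up = F.interpolate(x[(i + 1, j - 1)], scale_factor=2, mode="nearest")
                same_level = [x[(i, k)] for k in range(j)]
                x[(i, j)] = self.nodes[f"{i}_{j}"](torch.cat(same_level + [up], dim=1))
        return self.head(x[(0, self.levels - 1)])


channels = [8, 16, 32, 64]
feats = [torch.randn(1, c, 32 // 2**i, 32 // 2**i) for i, c in enumerate(channels)]
decoder = ToyNestedDecoder(channels, out_channels=4)
print(decoder(feats).shape)  # torch.Size([1, 4, 32, 32])
```

The nodes are filled column by column, so every X(i+1, j-1) and X(i, k < j) already exists when X(i, j) is computed; the output keeps the resolution of the shallowest feature map, as the Returns description above states for the real decoder.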