patch_first_conv

patch_first_conv(model, new_in_channels, default_in_channels=3, *, pretrained=True)

Update the first convolution layer for a new input channel size.

This function updates the first convolutional layer of a model in place so that it accepts an arbitrary number of input channels. It can either reuse the pretrained weights or initialize the weights randomly.

Parameters:
  • model (nn.Module) – The neural network model whose first convolution layer will be patched.

  • new_in_channels (int) – Number of input channels for the new first layer.

  • default_in_channels (int) – Original number of input channels. Defaults to 3.

  • pretrained (bool) – Keyword-only. Whether to reuse the existing pretrained weights; if False, the weights are initialized randomly. Defaults to True.

Return type:

None

Notes

  • If new_in_channels is 1 or 2 → reuse the original weights.

  • If new_in_channels > 3 → initialize the weights with Kaiming normal initialization.
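The rules above can be sketched in PyTorch roughly as follows. This is a minimal, hypothetical re-implementation (`patch_first_conv_sketch` is an illustrative name, not part of any library), assuming the first layer is an `nn.Conv2d` with `groups=1`. How the 1- and 2-channel cases reuse the weights (summing over input channels, and keeping the first two channels rescaled by default_in_channels / 2), and leaving the model unchanged when the channel count already matches, are assumptions rather than documented behavior.

```python
import torch
import torch.nn as nn


def patch_first_conv_sketch(
    model: nn.Module,
    new_in_channels: int,
    default_in_channels: int = 3,
    *,
    pretrained: bool = True,
) -> None:
    """Hypothetical sketch: patch the first Conv2d of `model` in place."""
    # Find the first Conv2d whose input width equals the original channel count.
    conv = next(
        (m for m in model.modules()
         if isinstance(m, nn.Conv2d) and m.in_channels == default_in_channels),
        None,
    )
    if conv is None:
        raise ValueError(f"no Conv2d with {default_in_channels} input channels found")
    if new_in_channels == default_in_channels:
        return  # assumption: nothing to change

    # Assumes groups == 1, so the weight shape is (out, in, kH, kW).
    old_weight = conv.weight.detach()

    if pretrained and new_in_channels == 1:
        # Reuse: collapse the original filters into one channel by summing them.
        new_weight = old_weight.sum(dim=1, keepdim=True)
    elif pretrained and new_in_channels == 2 and default_in_channels >= 2:
        # Reuse (assumption): keep the first two channels, rescaled to
        # roughly compensate for the dropped channel(s).
        new_weight = old_weight[:, :2] * (default_in_channels / 2)
    else:
        # More channels, or pretrained=False: Kaiming-normal random init.
        new_weight = torch.empty(
            conv.out_channels, new_in_channels, *conv.kernel_size,
            dtype=old_weight.dtype, device=old_weight.device,
        )
        nn.init.kaiming_normal_(new_weight, nonlinearity="relu")

    conv.in_channels = new_in_channels
    conv.weight = nn.Parameter(new_weight)
```

For example, patching `nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())` with `new_in_channels=1` leaves a first layer whose weight has shape `(8, 1, 3, 3)`, so the model then accepts single-channel input. Because the function returns None, all changes happen on the model object passed in.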

Example

>>> patch_first_conv(model, new_in_channels=1, pretrained=True)
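The single-channel case in the example can reuse pretrained RGB weights sensibly because convolution is linear in its input: convolving a grayscale image replicated across three channels with the RGB filters gives the same result as convolving the single channel with those filters summed over the input-channel axis. The check below demonstrates this identity with random stand-in weights (no pretrained model is loaded). It illustrates one common reuse strategy and is not a statement of exactly how `patch_first_conv` combines the weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for pretrained RGB stem weights: (out_channels=8, in_channels=3, 3, 3).
# float64 keeps rounding differences between the two paths negligible.
rgb_weight = torch.randn(8, 3, 3, 3, dtype=torch.float64)
gray = torch.randn(1, 1, 32, 32, dtype=torch.float64)  # one single-channel image

# Feed the grayscale image, replicated across three channels, to the RGB filters...
out_rgb = F.conv2d(gray.repeat(1, 3, 1, 1), rgb_weight)
# ...and compare with a one-channel convolution using the channel-summed filters.
out_gray = F.conv2d(gray, rgb_weight.sum(dim=1, keepdim=True))

print(out_gray.shape)                     # torch.Size([1, 8, 30, 30])
print(torch.allclose(out_rgb, out_gray))  # True
```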