patch_first_conv
- patch_first_conv(model, new_in_channels, default_in_channels=3, *, pretrained=True)
Update the first convolution layer for a new input channel size.
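This function updates the first convolutional layer of a model so that it accepts new_in_channels input channels instead of default_in_channels. Depending on pretrained and new_in_channels, it either reuses the existing pretrained weights or initializes the weights randomly. The model is modified in place.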
- Parameters:
model (nn.Module) – The neural network model whose first convolution layer will be patched.
new_in_channels (int) – Number of input channels for the new first layer.
default_in_channels (int) – Original number of input channels. Defaults to 3.
pretrained (bool) – Whether to reuse pretrained weights. Defaults to True.
- Return type:
None
Notes
If new_in_channels is 1 or 2 → the original weights are reused.
If new_in_channels > 3 → the weights are initialized using Kaiming normal initialization.
Example
>>> patch_first_conv(model, new_in_channels=1, pretrained=True)
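The behaviour described in the Notes can be sketched in plain PyTorch. The function below is a hypothetical stand-in named patch_first_conv_sketch, not the library's implementation. It assumes that the "first" convolution is the first nn.Conv2d whose in_channels equals default_in_channels, that groups == 1, and that "reusing" weights means keeping the leading channels and rescaling them. The actual function may locate the layer or combine the weights differently.

```python
import torch
import torch.nn as nn


def patch_first_conv_sketch(model, new_in_channels, default_in_channels=3, *, pretrained=True):
    """Illustrative stand-in only; mirrors the documented signature."""
    # Assumption: the first Conv2d (in module order) with the expected
    # number of input channels is the layer to patch.
    conv = next(
        m for m in model.modules()
        if isinstance(m, nn.Conv2d) and m.in_channels == default_in_channels
    )
    if new_in_channels == default_in_channels:
        return  # nothing to change

    old_weight = conv.weight.detach()  # shape: (out, in, kH, kW) when groups == 1

    if pretrained and new_in_channels < default_in_channels:
        # Assumed reuse strategy: keep the leading channels and rescale so the
        # summed response stays at a similar magnitude.
        scale = default_in_channels / new_in_channels
        new_weight = old_weight[:, :new_in_channels].clone() * scale
    else:
        # Random initialization with Kaiming normal.
        new_weight = torch.empty(conv.out_channels, new_in_channels, *conv.kernel_size)
        nn.init.kaiming_normal_(new_weight, mode="fan_out", nonlinearity="relu")

    conv.in_channels = new_in_channels
    conv.weight = nn.Parameter(new_weight)


# Toy model used only for this illustration.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
patch_first_conv_sketch(model, new_in_channels=1, pretrained=True)
out = model(torch.randn(2, 1, 16, 16))  # single-channel input is now accepted
```

After patching, the layer's weight has shape (out_channels, new_in_channels, kH, kW), so checking model[0].weight.shape is a quick way to confirm that the change took effect.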