wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,504 @@
+ """
+ ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
+ ========================================================================
+
+ ConvNeXt V2 improves upon V1 by replacing LayerScale with Global Response
+ Normalization (GRN), which enhances inter-channel feature competition.
+
+ **Key Changes from V1**:
+ - GRN layer replaces LayerScale
+ - Better compatibility with masked autoencoder pretraining
+ - Prevents feature collapse in deep networks
+
+ **Variants**:
+ - convnext_v2_tiny: 28M params, depths [3,3,9,3], dims [96,192,384,768]
+ - convnext_v2_small: 50M params, depths [3,3,27,3], dims [96,192,384,768]
+ - convnext_v2_base: 89M params, depths [3,3,27,3], dims [128,256,512,1024]
+ - convnext_v2_tiny_pretrained: 2D only, ImageNet weights (torchvision ConvNeXt V1 backbone)
+
+ **Supports**: 1D, 2D, 3D inputs
+
+ Reference:
+     Woo, S., et al. (2023). ConvNeXt V2: Co-designing and Scaling ConvNets
+     with Masked Autoencoders. CVPR 2023.
+     https://arxiv.org/abs/2301.00808
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models._timm_utils import (
+     LayerNormNd,
+     build_regression_head,
+     get_conv_layer,
+     get_grn_layer,
+     get_pool_layer,
+ )
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ # Type alias for spatial shapes
+ SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
+
+ __all__ = [
+     "ConvNeXtV2Base",
+     "ConvNeXtV2BaseLarge",
+     "ConvNeXtV2Small",
+     "ConvNeXtV2Tiny",
+     "ConvNeXtV2TinyPretrained",
+ ]
+
+
+ # =============================================================================
+ # CONVNEXT V2 BLOCK
+ # =============================================================================
+
+
+ class ConvNeXtV2Block(nn.Module):
+     """
+     ConvNeXt V2 Block with GRN instead of LayerScale.
+
+     Architecture:
+         Input → DwConv → LayerNorm → Linear → GELU → GRN → Linear → Residual
+
+     The GRN layer is the key difference from V1: it sits after the dimension
+     expansion in the MLP, where V1 used LayerScale.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         spatial_dim: int,
+         drop_path: float = 0.0,
+         mlp_ratio: float = 4.0,
+     ):
+         super().__init__()
+         self.spatial_dim = spatial_dim
+
+         Conv = get_conv_layer(spatial_dim)
+         GRN = get_grn_layer(spatial_dim)
+
+         # Depthwise convolution
+         kernel_size = 7
+         padding = 3
+         self.dwconv = Conv(
+             dim, dim, kernel_size=kernel_size, padding=padding, groups=dim
+         )
+
+         # LayerNorm (applied in forward with permutation)
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+
+         # MLP with expansion
+         hidden_dim = int(dim * mlp_ratio)
+         self.pwconv1 = nn.Linear(dim, hidden_dim)  # Expansion
+         self.act = nn.GELU()
+         self.grn = GRN(hidden_dim)  # GRN after expansion (key V2 change)
+         self.pwconv2 = nn.Linear(hidden_dim, dim)  # Projection
+
+         # Stochastic depth
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+
+         # Depthwise conv
+         x = self.dwconv(x)
+
+         # Move channels to last for LayerNorm and Linear layers
+         if self.spatial_dim == 1:
+             x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+         elif self.spatial_dim == 2:
+             x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+         else:  # 3D
+             x = x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+
+         # Move back to channels-first for GRN
+         if self.spatial_dim == 1:
+             x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+         elif self.spatial_dim == 2:
+             x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+         else:  # 3D
+             x = x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+         # Apply GRN (the key V2 innovation)
+         x = self.grn(x)
+
+         # Move to channels-last for final projection
+         if self.spatial_dim == 1:
+             x = x.permute(0, 2, 1)
+         elif self.spatial_dim == 2:
+             x = x.permute(0, 2, 3, 1)
+         else:
+             x = x.permute(0, 2, 3, 4, 1)
+
+         x = self.pwconv2(x)
+
+         # Move back to channels-first
+         if self.spatial_dim == 1:
+             x = x.permute(0, 2, 1)
+         elif self.spatial_dim == 2:
+             x = x.permute(0, 3, 1, 2)
+         else:
+             x = x.permute(0, 4, 1, 2, 3)
+
+         x = residual + self.drop_path(x)
+         return x
+
+
+ class DropPath(nn.Module):
+     """Stochastic Depth (drop path) regularization."""
+
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.drop_prob == 0.0 or not self.training:
+             return x
+
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+         random_tensor.floor_()
+         return x.div(keep_prob) * random_tensor
+
+
+ # =============================================================================
+ # CONVNEXT V2 BASE CLASS
+ # =============================================================================
+
+
+ class ConvNeXtV2Base(BaseModel):
+     """
+     ConvNeXt V2 base class for regression.
+
+     Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
+     Uses GRN (Global Response Normalization) instead of LayerScale.
+     """
+
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         out_size: int,
+         depths: list[int],
+         dims: list[int],
+         drop_path_rate: float = 0.0,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         self.dim = len(in_shape)
+         self.depths = depths
+         self.dims = dims
+
+         Conv = get_conv_layer(self.dim)
+         Pool = get_pool_layer(self.dim)
+
+         # Stem: aggressive downsampling (4x stride like ConvNeXt)
+         self.stem = nn.Sequential(
+             Conv(1, dims[0], kernel_size=4, stride=4),
+             LayerNormNd(dims[0], self.dim),
+         )
+
+         # Stochastic depth decay rule
+         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+         # Build stages
+         self.stages = nn.ModuleList()
+         self.downsamples = nn.ModuleList()
+         cur = 0
+
+         for i in range(len(depths)):
+             # Stage: sequence of ConvNeXt V2 blocks
+             stage = nn.Sequential(
+                 *[
+                     ConvNeXtV2Block(
+                         dim=dims[i],
+                         spatial_dim=self.dim,
+                         drop_path=dp_rates[cur + j],
+                     )
+                     for j in range(depths[i])
+                 ]
+             )
+             self.stages.append(stage)
+             cur += depths[i]
+
+             # Downsample between stages (except after last)
+             if i < len(depths) - 1:
+                 downsample = nn.Sequential(
+                     LayerNormNd(dims[i], self.dim),
+                     Conv(dims[i], dims[i + 1], kernel_size=2, stride=2),
+                 )
+                 self.downsamples.append(downsample)
+
+         # Global pooling and head
+         self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
+         self.global_pool = Pool(1)
+         self.head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(dims[-1], dims[-1] // 2),
+             nn.GELU(),
+             nn.Dropout(dropout_rate * 0.5),
+             nn.Linear(dims[-1] // 2, out_size),
+         )
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize weights with truncated normal."""
+         for m in self.modules():
+             if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Args:
+             x: Input tensor (B, 1, *in_shape)
+
+         Returns:
+             Output tensor (B, out_size)
+         """
+         x = self.stem(x)
+
+         for i, stage in enumerate(self.stages):
+             x = stage(x)
+             if i < len(self.downsamples):
+                 x = self.downsamples[i](x)
+
+         # Global pooling
+         x = self.global_pool(x)
+         x = x.flatten(1)
+
+         # Final norm and head
+         x = self.norm(x)
+         x = self.head(x)
+
+         return x
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         return {
+             "depths": [3, 3, 9, 3],
+             "dims": [96, 192, 384, 768],
+             "drop_path_rate": 0.1,
+             "dropout_rate": 0.3,
+         }
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("convnext_v2_tiny")
+ class ConvNeXtV2Tiny(ConvNeXtV2Base):
+     """
+     ConvNeXt V2 Tiny: ~27.9M backbone parameters.
+
+     Depths [3,3,9,3], Dims [96,192,384,768].
+     Supports 1D, 2D, 3D inputs.
+
+     Example:
+         >>> model = ConvNeXtV2Tiny(in_shape=(64, 64), out_size=3)
+         >>> x = torch.randn(4, 1, 64, 64)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 9, 3],
+             dims=[96, 192, 384, 768],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"ConvNeXtV2_Tiny({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
+
+
+ @register_model("convnext_v2_small")
+ class ConvNeXtV2Small(ConvNeXtV2Base):
+     """
+     ConvNeXt V2 Small: ~49.6M backbone parameters.
+
+     Depths [3,3,27,3], Dims [96,192,384,768].
+     Supports 1D, 2D, 3D inputs.
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 27, 3],
+             dims=[96, 192, 384, 768],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"ConvNeXtV2_Small({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
+
+
+ @register_model("convnext_v2_base")
+ class ConvNeXtV2BaseLarge(ConvNeXtV2Base):
+     """
+     ConvNeXt V2 Base: ~87.7M backbone parameters.
+
+     Depths [3,3,27,3], Dims [128,256,512,1024].
+     Supports 1D, 2D, 3D inputs.
+     """
+
+     def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             depths=[3, 3, 27, 3],
+             dims=[128, 256, 512, 1024],
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"ConvNeXtV2_Base({self.dim}D, in_shape={self.in_shape}, "
+             f"out_size={self.out_size})"
+         )
+
+
+ # =============================================================================
+ # PRETRAINED VARIANT (2D ONLY)
+ # =============================================================================
+
+
+ @register_model("convnext_v2_tiny_pretrained")
+ class ConvNeXtV2TinyPretrained(BaseModel):
+     """
+     ConvNeXt V2 Tiny with ImageNet pretrained weights (2D only).
+
+     Uses torchvision's ConvNeXt implementation (V1 architecture; see the
+     note in __init__) with:
+     - Adapted input layer for single-channel input
+     - Replaced classifier for regression
+
+     Args:
+         in_shape: (H, W) input shape (2D only)
+         out_size: Number of regression targets
+         pretrained: Whether to load pretrained weights
+         freeze_backbone: Whether to freeze backbone for fine-tuning
+         dropout_rate: Dropout rate for the regression head
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(
+                 f"ConvNeXtV2TinyPretrained requires 2D input (H, W), "
+                 f"got {len(in_shape)}D"
+             )
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+
+         # Try to load from torchvision (if available)
+         try:
+             from torchvision.models import (
+                 ConvNeXt_Tiny_Weights,
+                 convnext_tiny,
+             )
+
+             weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
+             self.backbone = convnext_tiny(weights=weights)
+
+             # Note: torchvision's ConvNeXt is V1, not V2
+             # For true V2, we'd need custom implementation or timm
+             # This is a fallback using V1 architecture
+
+         except ImportError:
+             raise ImportError(
+                 "torchvision is required for pretrained ConvNeXt. "
+                 "Install with: pip install torchvision"
+             )
+
+         # Adapt input layer (3 channels -> 1 channel)
+         self._adapt_input_channels()
+
+         # Replace classifier with regression head
+         # Keep the LayerNorm2d (idx 0) and Flatten (idx 1), only replace Linear (idx 2)
+         in_features = self.backbone.classifier[2].in_features
+         new_head = build_regression_head(in_features, out_size, dropout_rate)
+
+         # Build new classifier keeping LayerNorm2d and Flatten
+         self.backbone.classifier = nn.Sequential(
+             self.backbone.classifier[0],  # LayerNorm2d
+             self.backbone.classifier[1],  # Flatten
+             new_head,  # Our regression head
+         )
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt first conv layer for single-channel input."""
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             bias=old_conv.bias is not None,
+         )
+
+         if self.pretrained:
+             with torch.no_grad():
+                 # Average RGB weights for grayscale
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                 if old_conv.bias is not None:
+                     new_conv.bias.copy_(old_conv.bias)
+
+         self.backbone.features[0][0] = new_conv
+
+     def _freeze_backbone(self):
+         """Freeze all backbone parameters except classifier."""
+         for name, param in self.backbone.named_parameters():
+             if "classifier" not in name:
+                 param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.backbone(x)
+
+     def __repr__(self) -> str:
+         return (
+             f"ConvNeXtV2_Tiny_Pretrained(in_shape={self.in_shape}, "
+             f"out_size={self.out_size}, pretrained={self.pretrained})"
+         )
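The GRN layer used above is imported from `wavedl.models._timm_utils` and is not part of this diff. For readers following the `self.grn` calls, the sketch below shows what Global Response Normalization computes in the channels-first 2D case, based on the ConvNeXt V2 paper cited in the module docstring. The class name `GRN2d` and the epsilon default are illustrative assumptions, not the actual `get_grn_layer` implementation.

```python
import torch
import torch.nn as nn


class GRN2d(nn.Module):
    """Global Response Normalization for channels-first (B, C, H, W) tensors.

    Hypothetical sketch following Woo et al. (2023); the layer used in the
    diff comes from wavedl.models._timm_utils.get_grn_layer.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # G(x): aggregate each channel's spatial response with an L2 norm
        gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)      # (B, C, 1, 1)
        # N(x): divisive normalization across channels (feature competition)
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)    # (B, C, 1, 1)
        # Calibrate the input and keep a residual path
        return self.gamma * (x * nx) + self.beta + x
```

Because `gamma` and `beta` are initialized to zero, the layer starts out as an identity mapping, so inserting it after the MLP expansion does not perturb the block at initialization.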
wavedl/models/densenet.py CHANGED
@@ -11,8 +11,8 @@ Features: feature reuse, efficient gradient flow, compact model.
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d

  **Variants**:
- - densenet121: Standard (121 layers, ~8M params for 2D)
- - densenet169: Deeper (169 layers, ~14M params for 2D)
+ - densenet121: Standard (121 layers, ~7.0M backbone params for 2D)
+ - densenet169: Deeper (169 layers, ~12.5M backbone params for 2D)

  References:
      Huang, G., et al. (2017). Densely Connected Convolutional Networks.
@@ -262,7 +262,7 @@ class DenseNet121(DenseNetBase):
      """
      DenseNet-121: Standard variant with 121 layers.

-     ~8M parameters (2D). Good for: Balanced accuracy, efficient training.
+     ~7.0M backbone parameters (2D). Good for: Balanced accuracy, efficient training.

      Args:
          in_shape: (L,), (H, W), or (D, H, W)
@@ -285,7 +285,7 @@ class DenseNet169(DenseNetBase):
      """
      DenseNet-169: Deeper variant with 169 layers.

-     ~14M parameters (2D). Good for: Higher capacity, more complex patterns.
+     ~12.5M backbone parameters (2D). Good for: Higher capacity, more complex patterns.

      Args:
          in_shape: (L,), (H, W), or (D, H, W)
@@ -320,7 +320,7 @@ class DenseNet121Pretrained(BaseModel):
      """
      DenseNet-121 with ImageNet pretrained weights (2D only).

-     ~8M parameters. Good for: Transfer learning with efficient feature reuse.
+     ~7.0M backbone parameters. Good for: Transfer learning with efficient feature reuse.

      Args:
          in_shape: (H, W) image dimensions
@@ -6,9 +6,9 @@ Wrapper around torchvision's EfficientNet with a regression head.
  Provides optional ImageNet pretrained weights for transfer learning.

  **Variants**:
- - efficientnet_b0: Smallest, fastest (~4.7M params)
- - efficientnet_b1: Light (~7.2M params)
- - efficientnet_b2: Balanced (~8.4M params)
+ - efficientnet_b0: Smallest, fastest (~4.0M backbone params)
+ - efficientnet_b1: Light (~6.5M backbone params)
+ - efficientnet_b2: Balanced (~7.7M backbone params)

  **Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.

@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
              self._freeze_backbone()

      def _adapt_input_channels(self):
-         """Modify first conv to handle single-channel input by expanding to 3ch."""
-         # We'll handle this in forward by repeating channels
-         pass
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         # EfficientNet stem conv is at: features[0][0]
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             # Initialize with mean of pretrained RGB weights
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.features[0][0] = new_conv

      def _freeze_backbone(self):
          """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
@@ -152,7 +169,7 @@ class EfficientNetB0(EfficientNetBase):
      """
      EfficientNet-B0: Smallest, most efficient variant.

-     ~5.3M parameters. Good for: Quick training, limited compute, baseline.
+     ~4.0M backbone parameters. Good for: Quick training, limited compute, baseline.

      Args:
          in_shape: (H, W) image dimensions
@@ -183,7 +200,7 @@ class EfficientNetB1(EfficientNetBase):
      """
      EfficientNet-B1: Slightly larger variant.

-     ~7.8M parameters. Good for: Better accuracy with moderate compute.
+     ~6.5M backbone parameters. Good for: Better accuracy with moderate compute.

      Args:
          in_shape: (H, W) image dimensions
@@ -214,7 +231,7 @@ class EfficientNetB2(EfficientNetBase):
      """
      EfficientNet-B2: Best balance of size and performance.

-     ~9.1M parameters. Good for: High accuracy without excessive compute.
+     ~7.7M backbone parameters. Good for: High accuracy without excessive compute.

      Args:
          in_shape: (H, W) image dimensions
@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
              nn.Linear(regression_hidden // 2, out_size),
          )

-         # Optionally freeze backbone for fine-tuning
+         # Adapt first conv for single-channel input (3× memory savings vs expand)
+         self._adapt_input_channels()
+
+         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
          if freeze_backbone:
              self._freeze_backbone()

+     def _adapt_input_channels(self):
+         """Modify first conv to accept single-channel input.
+
+         Instead of expanding 1→3 channels in forward (which triples memory),
+         we replace the first conv layer with a 1-channel version and initialize
+         weights as the mean of the pretrained RGB filters.
+         """
+         old_conv = self.backbone.features[0][0]
+         new_conv = nn.Conv2d(
+             1,  # Single channel input
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             dilation=old_conv.dilation,
+             groups=old_conv.groups,
+             padding_mode=old_conv.padding_mode,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+         self.backbone.features[0][0] = new_conv
+
      def _freeze_backbone(self):
          """Freeze all backbone parameters except the classifier."""
          for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@ class EfficientNetV2Base(BaseModel):
          Forward pass.

          Args:
-             x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+             x: Input tensor of shape (B, 1, H, W)

          Returns:
              Output tensor of shape (B, out_size)
          """
-         # Expand single channel to 3 channels for pretrained weights compatibility
-         if x.size(1) == 1:
-             x = x.expand(-1, 3, -1, -1)
-
          return self.backbone(x)

      @classmethod
@@ -176,7 +199,7 @@ class EfficientNetV2S(EfficientNetV2Base):
      """
      EfficientNetV2-S: Small variant, recommended default.

-     ~21.5M parameters. Best balance of speed and accuracy for most tasks.
+     ~20.2M backbone parameters. Best balance of speed and accuracy for most tasks.
      2× faster training than EfficientNet-B4 with better accuracy.

      Recommended for:
@@ -217,7 +240,7 @@ class EfficientNetV2M(EfficientNetV2Base):
      """
      EfficientNetV2-M: Medium variant for higher accuracy.

-     ~54.1M parameters. Use when accuracy is more important than speed.
+     ~52.9M backbone parameters. Use when accuracy is more important than speed.

      Recommended for:
      - Large datasets (>50k samples)
@@ -257,7 +280,7 @@ class EfficientNetV2L(EfficientNetV2Base):
      """
      EfficientNetV2-L: Large variant for maximum accuracy.

-     ~118.5M parameters. Use only with large datasets and sufficient compute.
+     ~117.2M backbone parameters. Use only with large datasets and sufficient compute.

      Recommended for:
      - Very large datasets (>100k samples)
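The `_adapt_input_channels` change repeated in the EfficientNet and EfficientNetV2 wrappers above follows one pattern: replace the RGB stem conv with a 1-channel conv whose weights are the channel-mean of the pretrained filters, instead of expanding the input to 3 channels in `forward`. As a rough standalone illustration, here is the same idea applied to a stock torchvision EfficientNet-B0 outside of wavedl. This is a sketch only (assuming torchvision is installed); the wavedl classes additionally swap the classifier for a regression head and optionally freeze the backbone.

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0

# Load the backbone; pass a Weights enum instead of None to start from ImageNet filters.
model = efficientnet_b0(weights=None)
old_conv = model.features[0][0]          # stem conv: Conv2d(3, 32, kernel_size=3, stride=2, ...)

# Build a 1-channel replacement with the same geometry as the original stem conv.
new_conv = nn.Conv2d(
    1,                                    # single-channel (grayscale) input
    old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None,
)
with torch.no_grad():
    # Collapse the RGB filters to one channel by averaging, as in the diff above.
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
model.features[0][0] = new_conv

x = torch.randn(2, 1, 224, 224)           # (B, 1, H, W) is now accepted directly
print(model(x).shape)                      # torch.Size([2, 1000]) with the stock classifier
```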