wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,488 @@
1
+ """
2
+ ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
3
+ ========================================================================
4
+
5
+ ConvNeXt V2 improves upon V1 by replacing LayerScale with Global Response
6
+ Normalization (GRN), which enhances inter-channel feature competition.
7
+
8
+ **Key Changes from V1**:
9
+ - GRN layer replaces LayerScale
10
+ - Better compatibility with masked autoencoder pretraining
11
+ - Prevents feature collapse in deep networks
12
+
13
+ **Variants**:
14
+ - convnext_v2_tiny: 28M params, depths [3,3,9,3], dims [96,192,384,768]
15
+ - convnext_v2_small: 50M params, depths [3,3,27,3], dims [96,192,384,768]
16
+ - convnext_v2_base: 89M params, depths [3,3,27,3], dims [128,256,512,1024]
17
+ - convnext_v2_tiny_pretrained: 2D only, ImageNet weights
18
+
19
+ **Supports**: 1D, 2D, 3D inputs
20
+
21
+ Reference:
22
+ Woo, S., et al. (2023). ConvNeXt V2: Co-designing and Scaling ConvNets
23
+ with Masked Autoencoders. CVPR 2023.
24
+ https://arxiv.org/abs/2301.00808
25
+
26
+ Author: Ductho Le (ductho.le@outlook.com)
27
+ """
28
+
29
+ from typing import Any
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+
34
+ from wavedl.models._pretrained_utils import (
35
+ LayerNormNd,
36
+ build_regression_head,
37
+ get_conv_layer,
38
+ get_grn_layer,
39
+ get_pool_layer,
40
+ )
41
+ from wavedl.models.base import BaseModel, SpatialShape
42
+ from wavedl.models.registry import register_model
43
+
44
+
45
+ __all__ = [
46
+ "ConvNeXtV2Base",
47
+ "ConvNeXtV2BaseLarge",
48
+ "ConvNeXtV2Small",
49
+ "ConvNeXtV2Tiny",
50
+ "ConvNeXtV2TinyPretrained",
51
+ ]
52
+
53
+
54
+ # =============================================================================
55
+ # CONVNEXT V2 BLOCK
56
+ # =============================================================================
57
+
58
+
59
+ class ConvNeXtV2Block(nn.Module):
60
+ """
61
+ ConvNeXt V2 Block with GRN instead of LayerScale.
62
+
63
+ Architecture:
64
+ Input → DwConv → LayerNorm → Linear → GELU → GRN → Linear → Residual
65
+
66
+ The GRN layer is the key difference from V1, placed after the
67
+ dimension-expansion in the MLP, replacing LayerScale.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ dim: int,
73
+ spatial_dim: int,
74
+ drop_path: float = 0.0,
75
+ mlp_ratio: float = 4.0,
76
+ ):
77
+ super().__init__()
78
+ self.spatial_dim = spatial_dim
79
+
80
+ Conv = get_conv_layer(spatial_dim)
81
+ GRN = get_grn_layer(spatial_dim)
82
+
83
+ # Depthwise convolution
84
+ kernel_size = 7
85
+ padding = 3
86
+ self.dwconv = Conv(
87
+ dim, dim, kernel_size=kernel_size, padding=padding, groups=dim
88
+ )
89
+
90
+ # LayerNorm (applied in forward with permutation)
91
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
92
+
93
+ # MLP with expansion
94
+ hidden_dim = int(dim * mlp_ratio)
95
+ self.pwconv1 = nn.Linear(dim, hidden_dim) # Expansion
96
+ self.act = nn.GELU()
97
+ self.grn = GRN(hidden_dim) # GRN after expansion (key V2 change)
98
+ self.pwconv2 = nn.Linear(hidden_dim, dim) # Projection
99
+
100
+ # Stochastic depth
101
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
102
+
103
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
104
+ residual = x
105
+
106
+ # Depthwise conv
107
+ x = self.dwconv(x)
108
+
109
+ # Move channels to last for LayerNorm and Linear layers
110
+ if self.spatial_dim == 1:
111
+ x = x.permute(0, 2, 1) # (B, C, L) -> (B, L, C)
112
+ elif self.spatial_dim == 2:
113
+ x = x.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
114
+ else: # 3D
115
+ x = x.permute(0, 2, 3, 4, 1) # (B, C, D, H, W) -> (B, D, H, W, C)
116
+
117
+ x = self.norm(x)
118
+ x = self.pwconv1(x)
119
+ x = self.act(x)
120
+
121
+ # Move back to channels-first for GRN
122
+ if self.spatial_dim == 1:
123
+ x = x.permute(0, 2, 1) # (B, L, C) -> (B, C, L)
124
+ elif self.spatial_dim == 2:
125
+ x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
126
+ else: # 3D
127
+ x = x.permute(0, 4, 1, 2, 3) # (B, D, H, W, C) -> (B, C, D, H, W)
128
+
129
+ # Apply GRN (the key V2 innovation)
130
+ x = self.grn(x)
131
+
132
+ # Move to channels-last for final projection
133
+ if self.spatial_dim == 1:
134
+ x = x.permute(0, 2, 1)
135
+ elif self.spatial_dim == 2:
136
+ x = x.permute(0, 2, 3, 1)
137
+ else:
138
+ x = x.permute(0, 2, 3, 4, 1)
139
+
140
+ x = self.pwconv2(x)
141
+
142
+ # Move back to channels-first
143
+ if self.spatial_dim == 1:
144
+ x = x.permute(0, 2, 1)
145
+ elif self.spatial_dim == 2:
146
+ x = x.permute(0, 3, 1, 2)
147
+ else:
148
+ x = x.permute(0, 4, 1, 2, 3)
149
+
150
+ x = residual + self.drop_path(x)
151
+ return x
152
+
153
+
154
+ class DropPath(nn.Module):
155
+ """Stochastic Depth (drop path) regularization."""
156
+
157
+ def __init__(self, drop_prob: float = 0.0):
158
+ super().__init__()
159
+ self.drop_prob = drop_prob
160
+
161
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
162
+ if self.drop_prob == 0.0 or not self.training:
163
+ return x
164
+
165
+ keep_prob = 1 - self.drop_prob
166
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
167
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
168
+ random_tensor.floor_()
169
+ return x.div(keep_prob) * random_tensor
170
+
171
+
172
+ # =============================================================================
173
+ # CONVNEXT V2 BASE CLASS
174
+ # =============================================================================
175
+
176
+
177
+ class ConvNeXtV2Base(BaseModel):
178
+ """
179
+ ConvNeXt V2 base class for regression.
180
+
181
+ Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
182
+ Uses GRN (Global Response Normalization) instead of LayerScale.
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ in_shape: SpatialShape,
188
+ out_size: int,
189
+ depths: list[int],
190
+ dims: list[int],
191
+ drop_path_rate: float = 0.0,
192
+ dropout_rate: float = 0.3,
193
+ **kwargs,
194
+ ):
195
+ super().__init__(in_shape, out_size)
196
+
197
+ self.dim = len(in_shape)
198
+ self.depths = depths
199
+ self.dims = dims
200
+
201
+ Conv = get_conv_layer(self.dim)
202
+ Pool = get_pool_layer(self.dim)
203
+
204
+ # Stem: aggressive downsampling (4x stride like ConvNeXt)
205
+ self.stem = nn.Sequential(
206
+ Conv(1, dims[0], kernel_size=4, stride=4),
207
+ LayerNormNd(dims[0], self.dim),
208
+ )
209
+
210
+ # Stochastic depth decay rule
211
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
212
+
213
+ # Build stages
214
+ self.stages = nn.ModuleList()
215
+ self.downsamples = nn.ModuleList()
216
+ cur = 0
217
+
218
+ for i in range(len(depths)):
219
+ # Stage: sequence of ConvNeXt V2 blocks
220
+ stage = nn.Sequential(
221
+ *[
222
+ ConvNeXtV2Block(
223
+ dim=dims[i],
224
+ spatial_dim=self.dim,
225
+ drop_path=dp_rates[cur + j],
226
+ )
227
+ for j in range(depths[i])
228
+ ]
229
+ )
230
+ self.stages.append(stage)
231
+ cur += depths[i]
232
+
233
+ # Downsample between stages (except after last)
234
+ if i < len(depths) - 1:
235
+ downsample = nn.Sequential(
236
+ LayerNormNd(dims[i], self.dim),
237
+ Conv(dims[i], dims[i + 1], kernel_size=2, stride=2),
238
+ )
239
+ self.downsamples.append(downsample)
240
+
241
+ # Global pooling and head
242
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
243
+ self.global_pool = Pool(1)
244
+ self.head = nn.Sequential(
245
+ nn.Dropout(dropout_rate),
246
+ nn.Linear(dims[-1], dims[-1] // 2),
247
+ nn.GELU(),
248
+ nn.Dropout(dropout_rate * 0.5),
249
+ nn.Linear(dims[-1] // 2, out_size),
250
+ )
251
+
252
+ # Initialize weights
253
+ self._init_weights()
254
+
255
+ def _init_weights(self):
256
+ """Initialize weights with truncated normal."""
257
+ for m in self.modules():
258
+ if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
259
+ nn.init.trunc_normal_(m.weight, std=0.02)
260
+ if m.bias is not None:
261
+ nn.init.zeros_(m.bias)
262
+ elif isinstance(m, nn.LayerNorm):
263
+ nn.init.ones_(m.weight)
264
+ nn.init.zeros_(m.bias)
265
+
266
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
267
+ """
268
+ Forward pass.
269
+
270
+ Args:
271
+ x: Input tensor (B, 1, *in_shape)
272
+
273
+ Returns:
274
+ Output tensor (B, out_size)
275
+ """
276
+ x = self.stem(x)
277
+
278
+ for i, stage in enumerate(self.stages):
279
+ x = stage(x)
280
+ if i < len(self.downsamples):
281
+ x = self.downsamples[i](x)
282
+
283
+ # Global pooling
284
+ x = self.global_pool(x)
285
+ x = x.flatten(1)
286
+
287
+ # Final norm and head
288
+ x = self.norm(x)
289
+ x = self.head(x)
290
+
291
+ return x
292
+
293
+ @classmethod
294
+ def get_default_config(cls) -> dict[str, Any]:
295
+ return {
296
+ "depths": [3, 3, 9, 3],
297
+ "dims": [96, 192, 384, 768],
298
+ "drop_path_rate": 0.1,
299
+ "dropout_rate": 0.3,
300
+ }
301
+
302
+
303
+ # =============================================================================
304
+ # REGISTERED VARIANTS
305
+ # =============================================================================
306
+
307
+
308
+ @register_model("convnext_v2_tiny")
309
+ class ConvNeXtV2Tiny(ConvNeXtV2Base):
310
+ """
311
+ ConvNeXt V2 Tiny: ~27.9M backbone parameters.
312
+
313
+ Depths [3,3,9,3], Dims [96,192,384,768].
314
+ Supports 1D, 2D, 3D inputs.
315
+
316
+ Example:
317
+ >>> model = ConvNeXtV2Tiny(in_shape=(64, 64), out_size=3)
318
+ >>> x = torch.randn(4, 1, 64, 64)
319
+ >>> out = model(x) # (4, 3)
320
+ """
321
+
322
+ def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
323
+ super().__init__(
324
+ in_shape=in_shape,
325
+ out_size=out_size,
326
+ depths=[3, 3, 9, 3],
327
+ dims=[96, 192, 384, 768],
328
+ **kwargs,
329
+ )
330
+
331
+ def __repr__(self) -> str:
332
+ return (
333
+ f"ConvNeXtV2_Tiny({self.dim}D, in_shape={self.in_shape}, "
334
+ f"out_size={self.out_size})"
335
+ )
336
+
337
+
338
+ @register_model("convnext_v2_small")
339
+ class ConvNeXtV2Small(ConvNeXtV2Base):
340
+ """
341
+ ConvNeXt V2 Small: ~49.6M backbone parameters.
342
+
343
+ Depths [3,3,27,3], Dims [96,192,384,768].
344
+ Supports 1D, 2D, 3D inputs.
345
+ """
346
+
347
+ def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
348
+ super().__init__(
349
+ in_shape=in_shape,
350
+ out_size=out_size,
351
+ depths=[3, 3, 27, 3],
352
+ dims=[96, 192, 384, 768],
353
+ **kwargs,
354
+ )
355
+
356
+ def __repr__(self) -> str:
357
+ return (
358
+ f"ConvNeXtV2_Small({self.dim}D, in_shape={self.in_shape}, "
359
+ f"out_size={self.out_size})"
360
+ )
361
+
362
+
363
+ @register_model("convnext_v2_base")
364
+ class ConvNeXtV2BaseLarge(ConvNeXtV2Base):
365
+ """
366
+ ConvNeXt V2 Base: ~87.7M backbone parameters.
367
+
368
+ Depths [3,3,27,3], Dims [128,256,512,1024].
369
+ Supports 1D, 2D, 3D inputs.
370
+ """
371
+
372
+ def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
373
+ super().__init__(
374
+ in_shape=in_shape,
375
+ out_size=out_size,
376
+ depths=[3, 3, 27, 3],
377
+ dims=[128, 256, 512, 1024],
378
+ **kwargs,
379
+ )
380
+
381
+ def __repr__(self) -> str:
382
+ return (
383
+ f"ConvNeXtV2_Base({self.dim}D, in_shape={self.in_shape}, "
384
+ f"out_size={self.out_size})"
385
+ )
386
+
387
+
388
+ # =============================================================================
389
+ # PRETRAINED VARIANT (2D ONLY)
390
+ # =============================================================================
391
+
392
+
393
+ @register_model("convnext_v2_tiny_pretrained")
394
+ class ConvNeXtV2TinyPretrained(BaseModel):
395
+ """
396
+ ConvNeXt V2 Tiny with ImageNet pretrained weights (2D only).
397
+
398
+ Uses torchvision's ConvNeXt V2 implementation with:
399
+ - Adapted input layer for single-channel input
400
+ - Replaced classifier for regression
401
+
402
+ Args:
403
+ in_shape: (H, W) input shape (2D only)
404
+ out_size: Number of regression targets
405
+ pretrained: Whether to load pretrained weights
406
+ freeze_backbone: Whether to freeze backbone for fine-tuning
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ in_shape: tuple[int, int],
412
+ out_size: int,
413
+ pretrained: bool = True,
414
+ freeze_backbone: bool = False,
415
+ dropout_rate: float = 0.3,
416
+ **kwargs,
417
+ ):
418
+ super().__init__(in_shape, out_size)
419
+
420
+ if len(in_shape) != 2:
421
+ raise ValueError(
422
+ f"ConvNeXtV2TinyPretrained requires 2D input (H, W), "
423
+ f"got {len(in_shape)}D"
424
+ )
425
+
426
+ self.pretrained = pretrained
427
+ self.freeze_backbone = freeze_backbone
428
+
429
+ # Try to load from torchvision (if available)
430
+ try:
431
+ from torchvision.models import (
432
+ ConvNeXt_Tiny_Weights,
433
+ convnext_tiny,
434
+ )
435
+
436
+ weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
437
+ self.backbone = convnext_tiny(weights=weights)
438
+
439
+ # Note: torchvision's ConvNeXt is V1, not V2
440
+ # For true V2, we'd need custom implementation or timm
441
+ # This is a fallback using V1 architecture
442
+
443
+ except ImportError:
444
+ raise ImportError(
445
+ "torchvision is required for pretrained ConvNeXt. "
446
+ "Install with: pip install torchvision"
447
+ )
448
+
449
+ # Adapt input layer (3 channels -> 1 channel)
450
+ self._adapt_input_channels()
451
+
452
+ # Replace classifier with regression head
453
+ # Keep the LayerNorm2d (idx 0) and Flatten (idx 1), only replace Linear (idx 2)
454
+ in_features = self.backbone.classifier[2].in_features
455
+ new_head = build_regression_head(in_features, out_size, dropout_rate)
456
+
457
+ # Build new classifier keeping LayerNorm2d and Flatten
458
+ self.backbone.classifier = nn.Sequential(
459
+ self.backbone.classifier[0], # LayerNorm2d
460
+ self.backbone.classifier[1], # Flatten
461
+ new_head, # Our regression head
462
+ )
463
+
464
+ if freeze_backbone:
465
+ self._freeze_backbone()
466
+
467
+ def _adapt_input_channels(self):
468
+ """Adapt first conv layer for single-channel input."""
469
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
470
+
471
+ adapt_first_conv_for_single_channel(
472
+ self.backbone, "features.0.0", pretrained=self.pretrained
473
+ )
474
+
475
+ def _freeze_backbone(self):
476
+ """Freeze all backbone parameters except classifier."""
477
+ for name, param in self.backbone.named_parameters():
478
+ if "classifier" not in name:
479
+ param.requires_grad = False
480
+
481
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
482
+ return self.backbone(x)
483
+
484
+ def __repr__(self) -> str:
485
+ return (
486
+ f"ConvNeXtV2_Tiny_Pretrained(in_shape={self.in_shape}, "
487
+ f"out_size={self.out_size}, pretrained={self.pretrained})"
488
+ )
wavedl/models/densenet.py CHANGED
@@ -11,8 +11,8 @@ Features: feature reuse, efficient gradient flow, compact model.
11
11
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
12
12
 
13
13
  **Variants**:
14
- - densenet121: Standard (121 layers, ~8M params for 2D)
15
- - densenet169: Deeper (169 layers, ~14M params for 2D)
14
+ - densenet121: Standard (121 layers, ~7.0M backbone params for 2D)
15
+ - densenet169: Deeper (169 layers, ~12.5M backbone params for 2D)
16
16
 
17
17
  References:
18
18
  Huang, G., et al. (2017). Densely Connected Convolutional Networks.
@@ -26,14 +26,10 @@ from typing import Any
26
26
  import torch
27
27
  import torch.nn as nn
28
28
 
29
- from wavedl.models.base import BaseModel
29
+ from wavedl.models.base import BaseModel, SpatialShape
30
30
  from wavedl.models.registry import register_model
31
31
 
32
32
 
33
- # Type alias for spatial shapes
34
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
35
-
36
-
37
33
  def _get_layers(dim: int):
38
34
  """Get dimension-appropriate layer classes."""
39
35
  if dim == 1:
@@ -262,7 +258,7 @@ class DenseNet121(DenseNetBase):
262
258
  """
263
259
  DenseNet-121: Standard variant with 121 layers.
264
260
 
265
- ~8M parameters (2D). Good for: Balanced accuracy, efficient training.
261
+ ~7.0M backbone parameters (2D). Good for: Balanced accuracy, efficient training.
266
262
 
267
263
  Args:
268
264
  in_shape: (L,), (H, W), or (D, H, W)
@@ -285,7 +281,7 @@ class DenseNet169(DenseNetBase):
285
281
  """
286
282
  DenseNet-169: Deeper variant with 169 layers.
287
283
 
288
- ~14M parameters (2D). Good for: Higher capacity, more complex patterns.
284
+ ~12.5M backbone parameters (2D). Good for: Higher capacity, more complex patterns.
289
285
 
290
286
  Args:
291
287
  in_shape: (L,), (H, W), or (D, H, W)
@@ -320,7 +316,7 @@ class DenseNet121Pretrained(BaseModel):
320
316
  """
321
317
  DenseNet-121 with ImageNet pretrained weights (2D only).
322
318
 
323
- ~8M parameters. Good for: Transfer learning with efficient feature reuse.
319
+ ~7.0M backbone parameters. Good for: Transfer learning with efficient feature reuse.
324
320
 
325
321
  Args:
326
322
  in_shape: (H, W) image dimensions
@@ -374,20 +370,11 @@ class DenseNet121Pretrained(BaseModel):
374
370
  )
375
371
 
376
372
  # Modify first conv for single-channel input
377
- old_conv = self.backbone.features.conv0
378
- self.backbone.features.conv0 = nn.Conv2d(
379
- 1,
380
- old_conv.out_channels,
381
- kernel_size=old_conv.kernel_size,
382
- stride=old_conv.stride,
383
- padding=old_conv.padding,
384
- bias=False,
373
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
374
+
375
+ adapt_first_conv_for_single_channel(
376
+ self.backbone, "features.conv0", pretrained=pretrained
385
377
  )
386
- if pretrained:
387
- with torch.no_grad():
388
- self.backbone.features.conv0.weight = nn.Parameter(
389
- old_conv.weight.mean(dim=1, keepdim=True)
390
- )
391
378
 
392
379
  if freeze_backbone:
393
380
  self._freeze_backbone()
@@ -6,9 +6,9 @@ Wrapper around torchvision's EfficientNet with a regression head.
6
6
  Provides optional ImageNet pretrained weights for transfer learning.
7
7
 
8
8
  **Variants**:
9
- - efficientnet_b0: Smallest, fastest (~4.7M params)
10
- - efficientnet_b1: Light (~7.2M params)
11
- - efficientnet_b2: Balanced (~8.4M params)
9
+ - efficientnet_b0: Smallest, fastest (~4.0M backbone params)
10
+ - efficientnet_b1: Light (~6.5M backbone params)
11
+ - efficientnet_b2: Balanced (~7.7M backbone params)
12
12
 
13
13
  **Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
14
14
 
@@ -169,7 +169,7 @@ class EfficientNetB0(EfficientNetBase):
169
169
  """
170
170
  EfficientNet-B0: Smallest, most efficient variant.
171
171
 
172
- ~5.3M parameters. Good for: Quick training, limited compute, baseline.
172
+ ~4.0M backbone parameters. Good for: Quick training, limited compute, baseline.
173
173
 
174
174
  Args:
175
175
  in_shape: (H, W) image dimensions
@@ -200,7 +200,7 @@ class EfficientNetB1(EfficientNetBase):
200
200
  """
201
201
  EfficientNet-B1: Slightly larger variant.
202
202
 
203
- ~7.8M parameters. Good for: Better accuracy with moderate compute.
203
+ ~6.5M backbone parameters. Good for: Better accuracy with moderate compute.
204
204
 
205
205
  Args:
206
206
  in_shape: (H, W) image dimensions
@@ -231,7 +231,7 @@ class EfficientNetB2(EfficientNetBase):
231
231
  """
232
232
  EfficientNet-B2: Best balance of size and performance.
233
233
 
234
- ~9.1M parameters. Good for: High accuracy without excessive compute.
234
+ ~7.7M backbone parameters. Good for: High accuracy without excessive compute.
235
235
 
236
236
  Args:
237
237
  in_shape: (H, W) image dimensions