wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/resnet.py CHANGED
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
11
11
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
12
12
 
13
13
  **Variants**:
14
- - resnet18: Lightweight, fast training (~11M params)
15
- - resnet34: Balanced capacity (~21M params)
16
- - resnet50: Higher capacity with bottleneck blocks (~25M params)
14
+ - resnet18: Lightweight, fast training (~11.2M backbone params)
15
+ - resnet34: Balanced capacity (~21.3M backbone params)
16
+ - resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)
17
17
 
18
18
  References:
19
19
  He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
@@ -27,14 +27,10 @@ from typing import Any
27
27
  import torch
28
28
  import torch.nn as nn
29
29
 
30
- from wavedl.models.base import BaseModel
30
+ from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
31
31
  from wavedl.models.registry import register_model
32
32
 
33
33
 
34
- # Type alias for spatial shapes
35
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
36
-
37
-
38
34
  def _get_conv_layers(
39
35
  dim: int,
40
36
  ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -49,36 +45,6 @@ def _get_conv_layers(
49
45
  raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
50
46
 
51
47
 
52
- def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
53
- """
54
- Get valid num_groups for GroupNorm that divides num_channels evenly.
55
-
56
- Args:
57
- num_channels: Number of channels to normalize
58
- preferred_groups: Preferred number of groups (default: 32)
59
-
60
- Returns:
61
- Valid num_groups that divides num_channels
62
-
63
- Raises:
64
- ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
65
- """
66
- # Try preferred groups first, then decrease
67
- for groups in [preferred_groups, 16, 8, 4, 2, 1]:
68
- if groups <= num_channels and num_channels % groups == 0:
69
- return groups
70
-
71
- # Fallback: find any valid divisor
72
- for groups in range(min(32, num_channels), 0, -1):
73
- if num_channels % groups == 0:
74
- return groups
75
-
76
- raise ValueError(
77
- f"Cannot find valid num_groups for {num_channels} channels. "
78
- f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
79
- )
80
-
81
-
82
48
  class BasicBlock(nn.Module):
83
49
  """
84
50
  Basic residual block for ResNet-18/34.
@@ -107,12 +73,12 @@ class BasicBlock(nn.Module):
107
73
  padding=1,
108
74
  bias=False,
109
75
  )
110
- self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
76
+ self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
111
77
  self.relu = nn.ReLU(inplace=True)
112
78
  self.conv2 = Conv(
113
79
  out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
114
80
  )
115
- self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
81
+ self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
116
82
  self.downsample = downsample
117
83
 
118
84
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -155,7 +121,7 @@ class Bottleneck(nn.Module):
155
121
 
156
122
  # 1x1 reduce
157
123
  self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
158
- self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
124
+ self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
159
125
 
160
126
  # 3x3 conv
161
127
  self.conv2 = Conv(
@@ -166,14 +132,16 @@ class Bottleneck(nn.Module):
166
132
  padding=1,
167
133
  bias=False,
168
134
  )
169
- self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
135
+ self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
170
136
 
171
137
  # 1x1 expand
172
138
  self.conv3 = Conv(
173
139
  out_channels, out_channels * self.expansion, kernel_size=1, bias=False
174
140
  )
175
141
  expanded_channels = out_channels * self.expansion
176
- self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
142
+ self.gn3 = nn.GroupNorm(
143
+ compute_num_groups(expanded_channels), expanded_channels
144
+ )
177
145
 
178
146
  self.relu = nn.ReLU(inplace=True)
179
147
  self.downsample = downsample
@@ -229,7 +197,7 @@ class ResNetBase(BaseModel):
229
197
 
230
198
  # Stem: 7x7 conv (or equivalent for 1D/3D)
231
199
  self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
232
- self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
200
+ self.gn1 = nn.GroupNorm(compute_num_groups(base_width), base_width)
233
201
  self.relu = nn.ReLU(inplace=True)
234
202
  self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
235
203
 
@@ -275,7 +243,7 @@ class ResNetBase(BaseModel):
275
243
  bias=False,
276
244
  ),
277
245
  nn.GroupNorm(
278
- _get_num_groups(out_channels * block.expansion),
246
+ compute_num_groups(out_channels * block.expansion),
279
247
  out_channels * block.expansion,
280
248
  ),
281
249
  )
@@ -495,21 +463,11 @@ class PretrainedResNetBase(BaseModel):
495
463
 
496
464
  # Modify first conv for single-channel input
497
465
  # Original: Conv2d(3, 64, ...) → New: Conv2d(1, 64, ...)
498
- old_conv = self.backbone.conv1
499
- self.backbone.conv1 = nn.Conv2d(
500
- 1,
501
- old_conv.out_channels,
502
- kernel_size=old_conv.kernel_size,
503
- stride=old_conv.stride,
504
- padding=old_conv.padding,
505
- bias=False,
466
+ from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
467
+
468
+ adapt_first_conv_for_single_channel(
469
+ self.backbone, "conv1", pretrained=pretrained
506
470
  )
507
- # Initialize new conv with mean of pretrained weights
508
- if pretrained:
509
- with torch.no_grad():
510
- self.backbone.conv1.weight = nn.Parameter(
511
- old_conv.weight.mean(dim=1, keepdim=True)
512
- )
513
471
 
514
472
  # Optionally freeze backbone
515
473
  if freeze_backbone:
@@ -534,7 +492,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
534
492
  """
535
493
  ResNet-18 with ImageNet pretrained weights (2D only).
536
494
 
537
- ~11M parameters. Good for: Transfer learning, faster convergence.
495
+ ~11.2M backbone parameters. Good for: Transfer learning, faster convergence.
538
496
 
539
497
  Args:
540
498
  in_shape: (H, W) image dimensions
@@ -563,7 +521,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
563
521
  """
564
522
  ResNet-50 with ImageNet pretrained weights (2D only).
565
523
 
566
- ~25M parameters. Good for: High accuracy with transfer learning.
524
+ ~23.5M backbone parameters. Good for: High accuracy with transfer learning.
567
525
 
568
526
  Args:
569
527
  in_shape: (H, W) image dimensions