wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/__init__.py CHANGED
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
6
6
  enabling dynamic model selection via command-line arguments.
7
7
 
8
8
  **Dimensionality Coverage**:
9
- - 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
10
- - 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
11
- EfficientNet, MobileNetV3, RegNet, Swin
12
- - 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, DenseNet
9
+ - 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
10
+ - 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
11
+ EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
12
+ CAFormer, PoolFormer, Vision Mamba
13
+ - 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet
13
14
 
14
15
  Usage:
15
16
  from wavedl.models import get_model, list_models, MODEL_REGISTRY
@@ -46,9 +47,19 @@ from .base import BaseModel
46
47
  # Import model implementations (triggers registration via decorators)
47
48
  from .cnn import CNN
48
49
  from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
50
+
51
+ # New models (v1.6+)
52
+ from .convnext_v2 import (
53
+ ConvNeXtV2Base,
54
+ ConvNeXtV2BaseLarge,
55
+ ConvNeXtV2Small,
56
+ ConvNeXtV2Tiny,
57
+ ConvNeXtV2TinyPretrained,
58
+ )
49
59
  from .densenet import DenseNet121, DenseNet169
50
60
  from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
51
61
  from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
62
+ from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
52
63
  from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
53
64
  from .registry import (
54
65
  MODEL_REGISTRY,
@@ -66,6 +77,33 @@ from .unet import UNetRegression
66
77
  from .vit import ViTBase_, ViTSmall, ViTTiny
67
78
 
68
79
 
80
+ # Optional timm-based models (imported conditionally)
81
+ try:
82
+ from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
83
+ from .efficientvit import (
84
+ EfficientViTB0,
85
+ EfficientViTB1,
86
+ EfficientViTB2,
87
+ EfficientViTB3,
88
+ EfficientViTL1,
89
+ EfficientViTL2,
90
+ EfficientViTM0,
91
+ EfficientViTM1,
92
+ EfficientViTM2,
93
+ )
94
+ from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
95
+ from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
96
+ from .unireplknet import (
97
+ UniRepLKNetBaseLarge,
98
+ UniRepLKNetSmall,
99
+ UniRepLKNetTiny,
100
+ )
101
+
102
+ _HAS_TIMM_MODELS = True
103
+ except ImportError:
104
+ _HAS_TIMM_MODELS = False
105
+
106
+
69
107
  # Export public API (sorted alphabetically per RUF022)
70
108
  # See module docstring for dimensionality support details
71
109
  __all__ = [
@@ -77,6 +115,11 @@ __all__ = [
77
115
  "ConvNeXtBase_",
78
116
  "ConvNeXtSmall",
79
117
  "ConvNeXtTiny",
118
+ "ConvNeXtV2Base",
119
+ "ConvNeXtV2BaseLarge",
120
+ "ConvNeXtV2Small",
121
+ "ConvNeXtV2Tiny",
122
+ "ConvNeXtV2TinyPretrained",
80
123
  "DenseNet121",
81
124
  "DenseNet169",
82
125
  "EfficientNetB0",
@@ -85,6 +128,7 @@ __all__ = [
85
128
  "EfficientNetV2L",
86
129
  "EfficientNetV2M",
87
130
  "EfficientNetV2S",
131
+ "Mamba1D",
88
132
  "MobileNetV3Large",
89
133
  "MobileNetV3Small",
90
134
  "RegNetY1_6GF",
@@ -105,8 +149,40 @@ __all__ = [
105
149
  "ViTBase_",
106
150
  "ViTSmall",
107
151
  "ViTTiny",
152
+ "VimBase",
153
+ "VimSmall",
154
+ "VimTiny",
108
155
  "build_model",
109
156
  "get_model",
110
157
  "list_models",
111
158
  "register_model",
112
159
  ]
160
+
161
+ # Add timm-based models to __all__ if available
162
+ if _HAS_TIMM_MODELS:
163
+ __all__.extend(
164
+ [
165
+ "CaFormerS18",
166
+ "CaFormerS36",
167
+ "EfficientViTB0",
168
+ "EfficientViTB1",
169
+ "EfficientViTB2",
170
+ "EfficientViTB3",
171
+ "EfficientViTL1",
172
+ "EfficientViTL2",
173
+ "EfficientViTM0",
174
+ "EfficientViTM1",
175
+ "EfficientViTM2",
176
+ "FastViTS12",
177
+ "FastViTSA12",
178
+ "FastViTT8",
179
+ "FastViTT12",
180
+ "MaxViTBaseLarge",
181
+ "MaxViTSmall",
182
+ "MaxViTTiny",
183
+ "PoolFormerS12",
184
+ "UniRepLKNetBaseLarge",
185
+ "UniRepLKNetSmall",
186
+ "UniRepLKNetTiny",
187
+ ]
188
+ )
@@ -0,0 +1,366 @@
1
+ """
2
+ Shared Utilities for Model Architectures
3
+ =========================================
4
+
5
+ Common components used across multiple models:
6
+ - GRN (Global Response Normalization) for ConvNeXt V2
7
+ - Dimension-agnostic layer factories
8
+ - Regression head builders
9
+ - Input channel adaptation for pretrained models
10
+
11
+ Author: Ductho Le (ductho.le@outlook.com)
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ # =============================================================================
20
+ # DIMENSION-AGNOSTIC LAYER FACTORIES
21
+ # =============================================================================
22
+
23
+
24
+ def get_conv_layer(dim: int) -> type[nn.Module]:
25
+ """Get dimension-appropriate Conv class."""
26
+ layers = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
27
+ if dim not in layers:
28
+ raise ValueError(f"Unsupported dimension: {dim}")
29
+ return layers[dim]
30
+
31
+
32
+ def get_norm_layer(dim: int) -> type[nn.Module]:
33
+ """Get dimension-appropriate BatchNorm class."""
34
+ layers = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
35
+ if dim not in layers:
36
+ raise ValueError(f"Unsupported dimension: {dim}")
37
+ return layers[dim]
38
+
39
+
40
+ def get_pool_layer(dim: int) -> type[nn.Module]:
41
+ """Get dimension-appropriate AdaptiveAvgPool class."""
42
+ layers = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d, 3: nn.AdaptiveAvgPool3d}
43
+ if dim not in layers:
44
+ raise ValueError(f"Unsupported dimension: {dim}")
45
+ return layers[dim]
46
+
47
+
48
+ # =============================================================================
49
+ # GLOBAL RESPONSE NORMALIZATION (GRN) - ConvNeXt V2
50
+ # =============================================================================
51
+
52
+
53
+ class GRN1d(nn.Module):
54
+ """
55
+ Global Response Normalization for 1D inputs.
56
+
57
+ GRN enhances inter-channel feature competition and promotes diversity.
58
+ Replaces LayerScale in ConvNeXt V2.
59
+
60
+ Reference: ConvNeXt V2 (CVPR 2023)
61
+ """
62
+
63
+ def __init__(self, dim: int, eps: float = 1e-6):
64
+ super().__init__()
65
+ self.gamma = nn.Parameter(torch.zeros(1, dim, 1))
66
+ self.beta = nn.Parameter(torch.zeros(1, dim, 1))
67
+ self.eps = eps
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ # x: (B, C, L)
71
+ Gx = torch.norm(x, p=2, dim=2, keepdim=True) # (B, C, 1)
72
+ Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps) # (B, C, 1)
73
+ return self.gamma * (x * Nx) + self.beta + x
74
+
75
+
76
+ class GRN2d(nn.Module):
77
+ """
78
+ Global Response Normalization for 2D inputs.
79
+
80
+ GRN enhances inter-channel feature competition and promotes diversity.
81
+ Replaces LayerScale in ConvNeXt V2.
82
+
83
+ Reference: ConvNeXt V2 (CVPR 2023)
84
+ """
85
+
86
+ def __init__(self, dim: int, eps: float = 1e-6):
87
+ super().__init__()
88
+ self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
89
+ self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
90
+ self.eps = eps
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ # x: (B, C, H, W)
94
+ Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) # (B, C, 1, 1)
95
+ Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps) # (B, C, 1, 1)
96
+ return self.gamma * (x * Nx) + self.beta + x
97
+
98
+
99
+ class GRN3d(nn.Module):
100
+ """
101
+ Global Response Normalization for 3D inputs.
102
+
103
+ GRN enhances inter-channel feature competition and promotes diversity.
104
+ Replaces LayerScale in ConvNeXt V2.
105
+
106
+ Reference: ConvNeXt V2 (CVPR 2023)
107
+ """
108
+
109
+ def __init__(self, dim: int, eps: float = 1e-6):
110
+ super().__init__()
111
+ self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
112
+ self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
113
+ self.eps = eps
114
+
115
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
116
+ # x: (B, C, D, H, W)
117
+ Gx = torch.norm(x, p=2, dim=(2, 3, 4), keepdim=True) # (B, C, 1, 1, 1)
118
+ Nx = Gx / (Gx.mean(dim=1, keepdim=True) + self.eps) # (B, C, 1, 1, 1)
119
+ return self.gamma * (x * Nx) + self.beta + x
120
+
121
+
122
+ def get_grn_layer(dim: int) -> type[nn.Module]:
123
+ """Get dimension-appropriate GRN class."""
124
+ layers = {1: GRN1d, 2: GRN2d, 3: GRN3d}
125
+ if dim not in layers:
126
+ raise ValueError(f"Unsupported dimension: {dim}")
127
+ return layers[dim]
128
+
129
+
130
+ # =============================================================================
131
+ # LAYER NORMALIZATION (Channels Last for CNNs)
132
+ # =============================================================================
133
+
134
+
135
+ class LayerNormNd(nn.Module):
136
+ """
137
+ LayerNorm that works with channels-first tensors of any dimension.
138
+ Applies normalization over the channel dimension.
139
+ """
140
+
141
+ def __init__(self, normalized_shape: int, dim: int, eps: float = 1e-6):
142
+ super().__init__()
143
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
144
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
145
+ self.eps = eps
146
+ self.dim = dim
147
+ self.normalized_shape = (normalized_shape,)
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ # Move channels to last, apply LN, move back
151
+ if self.dim == 1:
152
+ # (B, C, L) -> (B, L, C) -> LN -> (B, C, L)
153
+ x = x.permute(0, 2, 1)
154
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
155
+ x = x.permute(0, 2, 1)
156
+ elif self.dim == 2:
157
+ # (B, C, H, W) -> (B, H, W, C) -> LN -> (B, C, H, W)
158
+ x = x.permute(0, 2, 3, 1)
159
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
160
+ x = x.permute(0, 3, 1, 2)
161
+ elif self.dim == 3:
162
+ # (B, C, D, H, W) -> (B, D, H, W, C) -> LN -> (B, C, D, H, W)
163
+ x = x.permute(0, 2, 3, 4, 1)
164
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
165
+ x = x.permute(0, 4, 1, 2, 3)
166
+ return x
167
+
168
+
169
+ # =============================================================================
170
+ # REGRESSION HEAD BUILDERS
171
+ # =============================================================================
172
+
173
+
174
+ def build_regression_head(
175
+ in_features: int,
176
+ out_size: int,
177
+ dropout_rate: float = 0.3,
178
+ hidden_dim: int = 512,
179
+ ) -> nn.Sequential:
180
+ """
181
+ Build a standard regression head for pretrained models.
182
+
183
+ Args:
184
+ in_features: Input feature dimension
185
+ out_size: Number of regression targets
186
+ dropout_rate: Dropout rate
187
+ hidden_dim: Hidden layer dimension
188
+
189
+ Returns:
190
+ nn.Sequential regression head
191
+ """
192
+ return nn.Sequential(
193
+ nn.Dropout(dropout_rate),
194
+ nn.Linear(in_features, hidden_dim),
195
+ nn.SiLU(inplace=True),
196
+ nn.Dropout(dropout_rate * 0.5),
197
+ nn.Linear(hidden_dim, hidden_dim // 2),
198
+ nn.SiLU(inplace=True),
199
+ nn.Linear(hidden_dim // 2, out_size),
200
+ )
201
+
202
+
203
+ def adapt_input_channels(
204
+ conv_layer: nn.Module,
205
+ new_in_channels: int = 1,
206
+ pretrained: bool = True,
207
+ ) -> nn.Module:
208
+ """
209
+ Adapt a convolutional layer for different input channels.
210
+
211
+ For pretrained models, averages RGB weights to grayscale.
212
+
213
+ Args:
214
+ conv_layer: Original conv layer (expects 3 input channels)
215
+ new_in_channels: New number of input channels (default: 1)
216
+ pretrained: Whether to adapt pretrained weights
217
+
218
+ Returns:
219
+ New conv layer with adapted input channels
220
+ """
221
+ if isinstance(conv_layer, nn.Conv2d):
222
+ new_conv = nn.Conv2d(
223
+ new_in_channels,
224
+ conv_layer.out_channels,
225
+ kernel_size=conv_layer.kernel_size,
226
+ stride=conv_layer.stride,
227
+ padding=conv_layer.padding,
228
+ bias=conv_layer.bias is not None,
229
+ )
230
+ if pretrained and conv_layer.in_channels == 3:
231
+ with torch.no_grad():
232
+ # Average RGB weights
233
+ new_conv.weight.copy_(conv_layer.weight.mean(dim=1, keepdim=True))
234
+ if conv_layer.bias is not None:
235
+ new_conv.bias.copy_(conv_layer.bias)
236
+ return new_conv
237
+ else:
238
+ raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
239
+
240
+
241
+ def adapt_first_conv_for_single_channel(
242
+ module: nn.Module,
243
+ conv_path: str,
244
+ pretrained: bool = True,
245
+ ) -> None:
246
+ """
247
+ Adapt the first convolutional layer of a pretrained model for single-channel input.
248
+
249
+ This is a convenience function for torchvision-style models where the path
250
+ to the first conv layer is known. It modifies the model in-place.
251
+
252
+ For pretrained models, the RGB weights are averaged to create grayscale weights,
253
+ which provides a reasonable initialization for single-channel inputs.
254
+
255
+ Args:
256
+ module: The model or submodule containing the conv layer
257
+ conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
258
+ pretrained: Whether to adapt pretrained weights by averaging RGB channels
259
+
260
+ Example:
261
+ >>> # For torchvision ResNet
262
+ >>> adapt_first_conv_for_single_channel(
263
+ ... model.backbone, "conv1", pretrained=True
264
+ ... )
265
+ >>> # For torchvision ConvNeXt
266
+ >>> adapt_first_conv_for_single_channel(
267
+ ... model.backbone, "features.0.0", pretrained=True
268
+ ... )
269
+ >>> # For torchvision DenseNet
270
+ >>> adapt_first_conv_for_single_channel(
271
+ ... model.backbone, "features.conv0", pretrained=True
272
+ ... )
273
+ """
274
+ # Navigate to parent and get the conv layer
275
+ parts = conv_path.split(".")
276
+ parent = module
277
+ for part in parts[:-1]:
278
+ if part.isdigit():
279
+ parent = parent[int(part)]
280
+ else:
281
+ parent = getattr(parent, part)
282
+
283
+ # Get the final attribute name and the old conv
284
+ final_attr = parts[-1]
285
+ if final_attr.isdigit():
286
+ old_conv = parent[int(final_attr)]
287
+ else:
288
+ old_conv = getattr(parent, final_attr)
289
+
290
+ # Create and set the new conv
291
+ new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
292
+
293
+ if final_attr.isdigit():
294
+ parent[int(final_attr)] = new_conv
295
+ else:
296
+ setattr(parent, final_attr, new_conv)
297
+
298
+
299
+ def find_and_adapt_input_convs(
300
+ backbone: nn.Module,
301
+ pretrained: bool = True,
302
+ adapt_all: bool = False,
303
+ ) -> int:
304
+ """
305
+ Find and adapt Conv2d layers with 3 input channels for single-channel input.
306
+
307
+ This is useful for timm-style models where the exact path to the first
308
+ conv layer may vary or where multiple layers need adaptation.
309
+
310
+ Args:
311
+ backbone: The backbone model to adapt
312
+ pretrained: Whether to adapt pretrained weights by averaging RGB channels
313
+ adapt_all: If True, adapt all Conv2d layers with 3 input channels.
314
+ If False (default), only adapt the first one found.
315
+
316
+ Returns:
317
+ Number of layers adapted
318
+
319
+ Example:
320
+ >>> # For timm models (adapt first conv only)
321
+ >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
322
+ >>> # For models with multiple input convs (e.g., FastViT)
323
+ >>> count = find_and_adapt_input_convs(
324
+ ... model.backbone, pretrained=True, adapt_all=True
325
+ ... )
326
+ """
327
+ adapted_count = 0
328
+
329
+ for name, module in backbone.named_modules():
330
+ if not hasattr(module, "in_channels") or module.in_channels != 3:
331
+ continue
332
+
333
+ # Check if this is a wrapper with inner .conv attribute
334
+ if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
335
+ old_conv = module.conv
336
+ module.conv = adapt_input_channels(
337
+ old_conv, new_in_channels=1, pretrained=pretrained
338
+ )
339
+ adapted_count += 1
340
+
341
+ elif isinstance(module, nn.Conv2d):
342
+ # Direct Conv2d - need to replace it in parent
343
+ parts = name.split(".")
344
+ parent = backbone
345
+ for part in parts[:-1]:
346
+ if part.isdigit():
347
+ parent = parent[int(part)]
348
+ else:
349
+ parent = getattr(parent, part)
350
+
351
+ child_name = parts[-1]
352
+ new_conv = adapt_input_channels(
353
+ module, new_in_channels=1, pretrained=pretrained
354
+ )
355
+
356
+ if child_name.isdigit():
357
+ parent[int(child_name)] = new_conv
358
+ else:
359
+ setattr(parent, child_name, new_conv)
360
+
361
+ adapted_count += 1
362
+
363
+ if not adapt_all and adapted_count > 0:
364
+ break
365
+
366
+ return adapted_count
wavedl/models/base.py CHANGED
@@ -15,6 +15,54 @@ import torch
15
15
  import torch.nn as nn
16
16
 
17
17
 
18
+ # =============================================================================
19
+ # TYPE ALIASES
20
+ # =============================================================================
21
+
22
+ # Spatial shape type aliases for model input dimensions
23
+ SpatialShape1D = tuple[int]
24
+ SpatialShape2D = tuple[int, int]
25
+ SpatialShape3D = tuple[int, int, int]
26
+ SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
27
+
28
+
29
+ # =============================================================================
30
+ # UTILITY FUNCTIONS
31
+ # =============================================================================
32
+
33
+
34
+ def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
35
+ """
36
+ Compute valid num_groups for GroupNorm that divides num_channels evenly.
37
+
38
+ GroupNorm requires num_channels to be divisible by num_groups. This function
39
+ finds the largest valid divisor up to preferred_groups.
40
+
41
+ Args:
42
+ num_channels: Number of channels to normalize (must be positive)
43
+ preferred_groups: Preferred number of groups (default: 32)
44
+
45
+ Returns:
46
+ Valid num_groups that satisfies num_channels % num_groups == 0
47
+
48
+ Example:
49
+ >>> compute_num_groups(64) # Returns 32
50
+ >>> compute_num_groups(48) # Returns 16 (48 % 32 != 0)
51
+ >>> compute_num_groups(7) # Returns 1 (prime number)
52
+ """
53
+ # Try preferred groups first, then common divisors
54
+ for groups in [preferred_groups, 16, 8, 4, 2, 1]:
55
+ if groups <= num_channels and num_channels % groups == 0:
56
+ return groups
57
+
58
+ # Fallback: find any valid divisor (always returns at least 1)
59
+ for groups in range(min(32, num_channels), 0, -1):
60
+ if num_channels % groups == 0:
61
+ return groups
62
+
63
+ return 1 # Always valid
64
+
65
+
18
66
  class BaseModel(nn.Module, ABC):
19
67
  """
20
68
  Abstract base class for all regression models.