wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/{hpc.py → launcher.py} +135 -61
  4. wavedl/models/__init__.py +28 -0
  5. wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
  6. wavedl/models/base.py +48 -0
  7. wavedl/models/caformer.py +1 -1
  8. wavedl/models/cnn.py +2 -27
  9. wavedl/models/convnext.py +5 -18
  10. wavedl/models/convnext_v2.py +6 -22
  11. wavedl/models/densenet.py +5 -18
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +6 -39
  15. wavedl/models/mamba.py +44 -24
  16. wavedl/models/maxvit.py +51 -48
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +14 -56
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +1 -5
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +3 -3
  26. wavedl/train.py +1427 -1430
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/losses.py +216 -216
  30. wavedl/utils/optimizers.py +216 -216
  31. wavedl/utils/schedulers.py +251 -251
  32. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
  33. wavedl-1.6.2.dist-info/RECORD +46 -0
  34. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
  35. wavedl-1.6.0.dist-info/RECORD +0 -44
  36. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
  37. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
  38. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/{hpc.py → launcher.py} RENAMED
@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 """
-WaveDL HPC Training Launcher.
+WaveDL Training Launcher.
 
-This module provides a Python-based HPC training launcher that wraps accelerate
-for distributed training on High-Performance Computing clusters.
+This module provides a universal training launcher that wraps accelerate
+for distributed training. It works seamlessly on both:
+- Local machines (uses standard cache locations)
+- HPC clusters (uses local caching, offline WandB)
+
+The environment is auto-detected based on scheduler variables (SLURM, PBS, etc.)
+and home directory writability.
 
 Usage:
-    wavedl-hpc --model cnn --data_path train.npz --num_gpus 4
+    # Local machine or HPC - same command!
+    wavedl-train --model cnn --data_path train.npz --output_dir results
+
+    # Multi-GPU is automatic (uses all available GPUs)
+    wavedl-train --model resnet18 --data_path train.npz --num_gpus 4
 
 Example SLURM script:
     #!/bin/bash
@@ -14,7 +23,7 @@ Example SLURM script:
     #SBATCH --gpus-per-node=4
     #SBATCH --time=12:00:00
 
-    wavedl-hpc --model cnn --data_path /scratch/data.npz --compile
+    wavedl-train --model cnn --data_path /scratch/data.npz --compile
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
@@ -53,78 +62,138 @@ def detect_gpus() -> int:
     return 1
 
 
-def setup_hpc_environment() -> None:
-    """Configure environment variables for HPC systems.
+def is_hpc_environment() -> bool:
+    """Detect if running on an HPC cluster.
+
+    Checks for:
+    1. Common HPC scheduler environment variables (SLURM, PBS, LSF, SGE, Cobalt)
+    2. Non-writable home directory (common on HPC systems)
 
-    Handles restricted home directories (e.g., Compute Canada) and
-    offline logging configurations. Always uses CWD-based TORCH_HOME
-    since compute nodes typically lack internet access.
+    Returns:
+        True if HPC environment detected, False otherwise.
     """
-    # Use CWD for cache base since HPC compute nodes typically lack internet
-    cache_base = os.getcwd()
+    # Check for common HPC scheduler environment variables
+    hpc_indicators = [
+        "SLURM_JOB_ID",  # SLURM
+        "PBS_JOBID",  # PBS/Torque
+        "LSB_JOBID",  # LSF
+        "SGE_TASK_ID",  # Sun Grid Engine
+        "COBALT_JOBID",  # Cobalt
+    ]
+    if any(var in os.environ for var in hpc_indicators):
+        return True
 
-    # TORCH_HOME always set to CWD - compute nodes need pre-cached weights
-    os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
-    Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+    # Check if home directory is not writable (common on HPC)
+    home = os.path.expanduser("~")
+    return not os.access(home, os.W_OK)
 
-    # Triton/Inductor caches - prevents permission errors with --compile
-    # These MUST be set before any torch.compile calls
-    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
-    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
-    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
-    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
 
-    # Check if home is writable for other caches
-    home = os.path.expanduser("~")
-    home_writable = os.access(home, os.W_OK)
-
-    # Other caches only if home is not writable
-    if not home_writable:
-        os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
-        os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
-        os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
-
-        # Ensure directories exist
-        for env_var in [
-            "MPLCONFIGDIR",
-            "FONTCONFIG_CACHE",
-            "XDG_CACHE_HOME",
-        ]:
-            Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
-
-    # WandB configuration (offline by default for HPC)
-    os.environ.setdefault("WANDB_MODE", "offline")
-    os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
-    os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
-    os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
-
-    # Suppress non-critical warnings
+def setup_environment() -> None:
+    """Configure environment for HPC or local machine.
+
+    Automatically detects the environment and configures accordingly:
+    - HPC: Uses CWD-based caching, offline WandB (compute nodes lack internet)
+    - Local: Uses standard cache locations (~/.cache), doesn't override WandB
+    """
+    is_hpc = is_hpc_environment()
+
+    if is_hpc:
+        # HPC: use CWD-based caching (compute nodes lack internet)
+        cache_base = os.getcwd()
+
+        # TORCH_HOME set to CWD - compute nodes need pre-cached weights
+        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+        Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+        # Triton/Inductor caches - prevents permission errors with --compile
+        os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+        os.environ.setdefault(
+            "TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache"
+        )
+        Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+        Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+        # Check if home is writable for other caches
+        home = os.path.expanduser("~")
+        home_writable = os.access(home, os.W_OK)
+
+        # Other caches only if home is not writable
+        if not home_writable:
+            os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
+            os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
+            os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
+
+            for env_var in [
+                "MPLCONFIGDIR",
+                "FONTCONFIG_CACHE",
+                "XDG_CACHE_HOME",
+            ]:
+                Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
+
+        # WandB configuration (offline by default for HPC)
+        os.environ.setdefault("WANDB_MODE", "offline")
+        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+
+        print("🖥️ HPC environment detected - using local caching")
+    else:
+        # Local machine: use standard locations, don't override user settings
+        # TORCH_HOME defaults to ~/.cache/torch (PyTorch default)
+        # WANDB_MODE defaults to online (WandB default)
+        print("💻 Local environment detected - using standard cache locations")
+
+    # Suppress non-critical warnings (both environments)
     os.environ.setdefault(
        "PYTHONWARNINGS",
        "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
    )
 
 
+def handle_fast_path_args() -> int | None:
+    """Handle utility flags that don't need accelerate launch.
+
+    Returns:
+        Exit code if handled (0 for success), None if should continue to full launch.
+    """
+    # --list_models: print models and exit immediately
+    if "--list_models" in sys.argv:
+        from wavedl.models import list_models
+
+        print("Available models:")
+        for name in list_models():
+            print(f"  {name}")
+        return 0
+
+    return None  # Continue to full launch
+
+
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
-    """Parse HPC-specific arguments, pass remaining to wavedl.train."""
+    """Parse launcher-specific arguments, pass remaining to wavedl.train."""
     parser = argparse.ArgumentParser(
-        description="WaveDL HPC Training Launcher",
+        description="WaveDL Training Launcher (works on local machines and HPC clusters)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 Examples:
-    # Basic training with auto-detected GPUs
-    wavedl-hpc --model cnn --data_path train.npz --epochs 100
+    # Basic training (auto-detects GPUs and environment)
+    wavedl-train --model cnn --data_path train.npz --output_dir results
 
-    # Specify GPU count and mixed precision
-    wavedl-hpc --model cnn --data_path train.npz --num_gpus 4 --mixed_precision bf16
+    # Specify GPU count explicitly
+    wavedl-train --model cnn --data_path train.npz --num_gpus 4
 
     # Full configuration
-    wavedl-hpc --model resnet18 --data_path train.npz --num_gpus 8 \\
-        --batch_size 256 --lr 1e-3 --compile --output_dir ./results
+    wavedl-train --model resnet18 --data_path train.npz --batch_size 256 \\
+        --lr 1e-3 --epochs 100 --compile --output_dir ./results
+
+    # List available models
+    wavedl-train --list_models
 
-Environment Variables:
-    WANDB_MODE      WandB mode: offline|online (default: offline)
-    SLURM_TMPDIR    Temp directory for HPC systems
+Environment Detection:
+    The launcher automatically detects your environment:
+    - HPC (SLURM, PBS, etc.): Uses local caching, offline WandB
+    - Local machine: Uses standard cache locations (~/.cache)
+
+    For full training options, see: python -m wavedl.train --help
 """,
     )
 
@@ -204,7 +273,7 @@ def print_summary(
     print("Common issues:")
     print("  - Missing data file (check --data_path)")
    print("  - Insufficient GPU memory (reduce --batch_size)")
-    print("  - Invalid model name (run: python train.py --list_models)")
+    print("  - Invalid model name (run: wavedl-train --list_models)")
     print()
 
     print("=" * 40)
@@ -212,12 +281,17 @@ def print_summary(
 
 
 def main() -> int:
-    """Main entry point for wavedl-hpc command."""
+    """Main entry point for wavedl-train command."""
+    # Fast path for utility flags (avoid accelerate launch overhead)
+    exit_code = handle_fast_path_args()
+    if exit_code is not None:
+        return exit_code
+
     # Parse arguments
     args, train_args = parse_args()
 
-    # Setup HPC environment
-    setup_hpc_environment()
+    # Setup environment (smart detection)
+    setup_environment()
 
     # Check if wavedl package is importable
     try:
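
For illustration, the new detection logic can be exercised directly; a minimal sketch, not part of the released diff, assuming the renamed module imports as wavedl.launcher and that WANDB_MODE is unset beforehand (the "123456" job id is hypothetical):

    import os

    from wavedl.launcher import is_hpc_environment, setup_environment

    # Any scheduler variable flips detection to HPC mode
    os.environ["SLURM_JOB_ID"] = "123456"  # hypothetical job id
    assert is_hpc_environment()

    # Under HPC mode, caches are routed to the CWD and WandB defaults to
    # offline; on a local machine both are left at their standard defaults
    setup_environment()
    print(os.environ["WANDB_MODE"])  # -> "offline"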
wavedl/models/__init__.py CHANGED
@@ -80,8 +80,24 @@ from .vit import ViTBase_, ViTSmall, ViTTiny
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .efficientvit import (
+        EfficientViTB0,
+        EfficientViTB1,
+        EfficientViTB2,
+        EfficientViTB3,
+        EfficientViTL1,
+        EfficientViTL2,
+        EfficientViTM0,
+        EfficientViTM1,
+        EfficientViTM2,
+    )
     from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
     from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+    from .unireplknet import (
+        UniRepLKNetBaseLarge,
+        UniRepLKNetSmall,
+        UniRepLKNetTiny,
+    )
 
     _HAS_TIMM_MODELS = True
 except ImportError:
@@ -148,6 +164,15 @@ if _HAS_TIMM_MODELS:
         [
             "CaFormerS18",
             "CaFormerS36",
+            "EfficientViTB0",
+            "EfficientViTB1",
+            "EfficientViTB2",
+            "EfficientViTB3",
+            "EfficientViTL1",
+            "EfficientViTL2",
+            "EfficientViTM0",
+            "EfficientViTM1",
+            "EfficientViTM2",
             "FastViTS12",
             "FastViTSA12",
             "FastViTT8",
@@ -156,5 +181,8 @@ if _HAS_TIMM_MODELS:
             "MaxViTSmall",
             "MaxViTTiny",
             "PoolFormerS12",
+            "UniRepLKNetBaseLarge",
+            "UniRepLKNetSmall",
+            "UniRepLKNetTiny",
         ]
     )
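
Because the new EfficientViT and UniRepLKNet families ride on the optional timm import above, their availability can be probed through the package's __all__; a small sketch using only the public wavedl.models API:

    from wavedl import models

    # These names are appended to __all__ only when the optional
    # timm-backed imports succeed (_HAS_TIMM_MODELS above)
    if "EfficientViTB0" in models.__all__:
        print("timm extras available:", models.EfficientViTB0)
    else:
        print("install timm to enable EfficientViT/UniRepLKNet models")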
wavedl/models/{_timm_utils.py → _pretrained_utils.py} RENAMED
@@ -236,3 +236,131 @@ def adapt_input_channels(
         return new_conv
     else:
         raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
+
+
+def adapt_first_conv_for_single_channel(
+    module: nn.Module,
+    conv_path: str,
+    pretrained: bool = True,
+) -> None:
+    """
+    Adapt the first convolutional layer of a pretrained model for single-channel input.
+
+    This is a convenience function for torchvision-style models where the path
+    to the first conv layer is known. It modifies the model in-place.
+
+    For pretrained models, the RGB weights are averaged to create grayscale weights,
+    which provides a reasonable initialization for single-channel inputs.
+
+    Args:
+        module: The model or submodule containing the conv layer
+        conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+
+    Example:
+        >>> # For torchvision ResNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "conv1", pretrained=True
+        ... )
+        >>> # For torchvision ConvNeXt
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.0.0", pretrained=True
+        ... )
+        >>> # For torchvision DenseNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.conv0", pretrained=True
+        ... )
+    """
+    # Navigate to parent and get the conv layer
+    parts = conv_path.split(".")
+    parent = module
+    for part in parts[:-1]:
+        if part.isdigit():
+            parent = parent[int(part)]
+        else:
+            parent = getattr(parent, part)
+
+    # Get the final attribute name and the old conv
+    final_attr = parts[-1]
+    if final_attr.isdigit():
+        old_conv = parent[int(final_attr)]
+    else:
+        old_conv = getattr(parent, final_attr)
+
+    # Create and set the new conv
+    new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
+
+    if final_attr.isdigit():
+        parent[int(final_attr)] = new_conv
+    else:
+        setattr(parent, final_attr, new_conv)
+
+
+def find_and_adapt_input_convs(
+    backbone: nn.Module,
+    pretrained: bool = True,
+    adapt_all: bool = False,
+) -> int:
+    """
+    Find and adapt Conv2d layers with 3 input channels for single-channel input.
+
+    This is useful for timm-style models where the exact path to the first
+    conv layer may vary or where multiple layers need adaptation.
+
+    Args:
+        backbone: The backbone model to adapt
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+        adapt_all: If True, adapt all Conv2d layers with 3 input channels.
+            If False (default), only adapt the first one found.
+
+    Returns:
+        Number of layers adapted
+
+    Example:
+        >>> # For timm models (adapt first conv only)
+        >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
+        >>> # For models with multiple input convs (e.g., FastViT)
+        >>> count = find_and_adapt_input_convs(
+        ...     model.backbone, pretrained=True, adapt_all=True
+        ... )
+    """
+    adapted_count = 0
+
+    for name, module in backbone.named_modules():
+        if not hasattr(module, "in_channels") or module.in_channels != 3:
+            continue
+
+        # Check if this is a wrapper with inner .conv attribute
+        if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
+            old_conv = module.conv
+            module.conv = adapt_input_channels(
+                old_conv, new_in_channels=1, pretrained=pretrained
+            )
+            adapted_count += 1
+
+        elif isinstance(module, nn.Conv2d):
+            # Direct Conv2d - need to replace it in parent
+            parts = name.split(".")
+            parent = backbone
+            for part in parts[:-1]:
+                if part.isdigit():
+                    parent = parent[int(part)]
+                else:
+                    parent = getattr(parent, part)
+
+            child_name = parts[-1]
+            new_conv = adapt_input_channels(
+                module, new_in_channels=1, pretrained=pretrained
+            )
+
+            if child_name.isdigit():
+                parent[int(child_name)] = new_conv
+            else:
+                setattr(parent, child_name, new_conv)
+
+            adapted_count += 1
+
+        if not adapt_all and adapted_count > 0:
+            break
+
+    return adapted_count
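
The docstring examples above translate directly to a plain torchvision model; a minimal sketch, assuming torchvision is installed (and that DEFAULT weights download on first use):

    import torch
    import torchvision.models as tvm

    from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel

    # Replace the RGB stem with a single-channel conv; per the docstring,
    # pretrained=True averages the RGB kernels so the initialization carries over
    net = tvm.resnet18(weights=tvm.ResNet18_Weights.DEFAULT)
    adapt_first_conv_for_single_channel(net, "conv1", pretrained=True)

    x = torch.randn(2, 1, 224, 224)  # grayscale batch
    print(net(x).shape)  # torch.Size([2, 1000])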
wavedl/models/base.py CHANGED
@@ -15,6 +15,54 @@ import torch
 import torch.nn as nn
 
 
+# =============================================================================
+# TYPE ALIASES
+# =============================================================================
+
+# Spatial shape type aliases for model input dimensions
+SpatialShape1D = tuple[int]
+SpatialShape2D = tuple[int, int]
+SpatialShape3D = tuple[int, int, int]
+SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
+
+
+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+
+def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Compute valid num_groups for GroupNorm that divides num_channels evenly.
+
+    GroupNorm requires num_channels to be divisible by num_groups. This function
+    finds the largest valid divisor up to preferred_groups.
+
+    Args:
+        num_channels: Number of channels to normalize (must be positive)
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that satisfies num_channels % num_groups == 0
+
+    Example:
+        >>> compute_num_groups(64)  # Returns 32
+        >>> compute_num_groups(48)  # Returns 16 (48 % 32 != 0)
+        >>> compute_num_groups(7)   # Returns 1 (prime number)
+    """
+    # Try preferred groups first, then common divisors
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor (always returns at least 1)
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    return 1  # Always valid
+
+
 class BaseModel(nn.Module, ABC):
     """
     Abstract base class for all regression models.
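
For reference, the new helper plugs straight into nn.GroupNorm, which rejects any num_groups that does not divide num_channels; a minimal sketch:

    import torch.nn as nn

    from wavedl.models.base import compute_num_groups

    # 48 channels: 32 doesn't divide 48, so the helper falls back to 16
    channels = 48
    norm = nn.GroupNorm(compute_num_groups(channels), channels)
    print(norm)  # GroupNorm(16, 48, eps=1e-05, affine=True)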
wavedl/models/caformer.py CHANGED
@@ -33,7 +33,7 @@ Author: Ductho Le (ductho.le@outlook.com)
 import torch
 import torch.nn as nn
 
-from wavedl.models._timm_utils import build_regression_head
+from wavedl.models._pretrained_utils import build_regression_head
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
wavedl/models/cnn.py CHANGED
@@ -24,14 +24,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layers(
     dim: int,
 ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )
 
-    @staticmethod
-    def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
-        """
-        Compute valid num_groups for GroupNorm that divides num_channels.
-
-        Finds the largest divisor of num_channels that is <= target_groups,
-        or falls back to 1 if no suitable divisor exists.
-
-        Args:
-            num_channels: Number of channels (must be positive)
-            target_groups: Desired number of groups (default: 4)
-
-        Returns:
-            Valid num_groups that satisfies num_channels % num_groups == 0
-        """
-        # Try target_groups down to 1, return first valid divisor
-        for g in range(min(target_groups, num_channels), 0, -1):
-            if num_channels % g == 0:
-                return g
-        return 1  # Fallback (always valid)
-
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
         Returns:
             Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
         """
-        num_groups = self._compute_num_groups(out_channels, target_groups=4)
+        num_groups = compute_num_groups(out_channels, preferred_groups=4)
 
         layers = [
            self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
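
The conv blocks now call the shared helper with preferred_groups=4; a quick sketch of the values it yields for a few example channel counts (assuming the base.py implementation shown above):

    from wavedl.models.base import compute_num_groups

    print(compute_num_groups(32, preferred_groups=4))  # -> 4
    print(compute_num_groups(64, preferred_groups=4))  # -> 4
    print(compute_num_groups(6, preferred_groups=4))   # -> 2 (4, 16, 8 don't divide 6)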
wavedl/models/convnext.py CHANGED
@@ -28,14 +28,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layer(dim: int) -> type[nn.Module]:
     """Get dimension-appropriate Conv class."""
     if dim == 1:
@@ -468,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-        old_conv = self.backbone.features[0][0]
-        self.backbone.features[0][0] = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=pretrained
         )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features[0][0].weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()
wavedl/models/convnext_v2.py CHANGED
@@ -31,20 +31,17 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models._timm_utils import (
+from wavedl.models._pretrained_utils import (
     LayerNormNd,
     build_regression_head,
     get_conv_layer,
     get_grn_layer,
     get_pool_layer,
 )
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
 __all__ = [
     "ConvNeXtV2Base",
     "ConvNeXtV2BaseLarge",
@@ -469,24 +466,11 @@ class ConvNeXtV2TinyPretrained(BaseModel):
 
     def _adapt_input_channels(self):
         """Adapt first conv layer for single-channel input."""
-        old_conv = self.backbone.features[0][0]
-        new_conv = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
-        )
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
 
-        if self.pretrained:
-            with torch.no_grad():
-                # Average RGB weights for grayscale
-                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
-                if old_conv.bias is not None:
-                    new_conv.bias.copy_(old_conv.bias)
-
-        self.backbone.features[0][0] = new_conv
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=self.pretrained
+        )
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except classifier."""