wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/{hpc.py → launcher.py} +135 -61
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1427 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
- wavedl-1.6.2.dist-info/RECORD +46 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/{hpc.py → launcher.py}
RENAMED
@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 """
-WaveDL …
+WaveDL Training Launcher.
 
-This module provides a …
-for distributed training …
+This module provides a universal training launcher that wraps accelerate
+for distributed training. It works seamlessly on both:
+- Local machines (uses standard cache locations)
+- HPC clusters (uses local caching, offline WandB)
+
+The environment is auto-detected based on scheduler variables (SLURM, PBS, etc.)
+and home directory writability.
 
 Usage:
-    …
+    # Local machine or HPC - same command!
+    wavedl-train --model cnn --data_path train.npz --output_dir results
+
+    # Multi-GPU is automatic (uses all available GPUs)
+    wavedl-train --model resnet18 --data_path train.npz --num_gpus 4
 
 Example SLURM script:
     #!/bin/bash
@@ -14,7 +23,7 @@ Example SLURM script:
 #SBATCH --gpus-per-node=4
 #SBATCH --time=12:00:00
 
-wavedl-…
+wavedl-train --model cnn --data_path /scratch/data.npz --compile
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
@@ -53,78 +62,138 @@ def detect_gpus() -> int:
     return 1
 
 
-def …
-    """…
+def is_hpc_environment() -> bool:
+    """Detect if running on an HPC cluster.
+
+    Checks for:
+    1. Common HPC scheduler environment variables (SLURM, PBS, LSF, SGE, Cobalt)
+    2. Non-writable home directory (common on HPC systems)
 
-    …
-    …
-    since compute nodes typically lack internet access.
+    Returns:
+        True if HPC environment detected, False otherwise.
     """
-    # …
-    …
+    # Check for common HPC scheduler environment variables
+    hpc_indicators = [
+        "SLURM_JOB_ID",  # SLURM
+        "PBS_JOBID",  # PBS/Torque
+        "LSB_JOBID",  # LSF
+        "SGE_TASK_ID",  # Sun Grid Engine
+        "COBALT_JOBID",  # Cobalt
+    ]
+    if any(var in os.environ for var in hpc_indicators):
+        return True
 
-    # …
-    os.…
-    …
+    # Check if home directory is not writable (common on HPC)
+    home = os.path.expanduser("~")
+    return not os.access(home, os.W_OK)
 
-    # Triton/Inductor caches - prevents permission errors with --compile
-    # These MUST be set before any torch.compile calls
-    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
-    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
-    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
-    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
 
-… (old lines 77-101: rest of the old HPC-only setup function, truncated in the source rendering)
+def setup_environment() -> None:
+    """Configure environment for HPC or local machine.
+
+    Automatically detects the environment and configures accordingly:
+    - HPC: Uses CWD-based caching, offline WandB (compute nodes lack internet)
+    - Local: Uses standard cache locations (~/.cache), doesn't override WandB
+    """
+    is_hpc = is_hpc_environment()
+
+    if is_hpc:
+        # HPC: use CWD-based caching (compute nodes lack internet)
+        cache_base = os.getcwd()
+
+        # TORCH_HOME set to CWD - compute nodes need pre-cached weights
+        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+        Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+        # Triton/Inductor caches - prevents permission errors with --compile
+        os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+        os.environ.setdefault(
+            "TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache"
+        )
+        Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+        Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+        # Check if home is writable for other caches
+        home = os.path.expanduser("~")
+        home_writable = os.access(home, os.W_OK)
+
+        # Other caches only if home is not writable
+        if not home_writable:
+            os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
+            os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
+            os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
+
+            for env_var in [
+                "MPLCONFIGDIR",
+                "FONTCONFIG_CACHE",
+                "XDG_CACHE_HOME",
+            ]:
+                Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
+
+        # WandB configuration (offline by default for HPC)
+        os.environ.setdefault("WANDB_MODE", "offline")
+        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+
+        print("🖥️ HPC environment detected - using local caching")
+    else:
+        # Local machine: use standard locations, don't override user settings
+        # TORCH_HOME defaults to ~/.cache/torch (PyTorch default)
+        # WANDB_MODE defaults to online (WandB default)
+        print("💻 Local environment detected - using standard cache locations")
+
+    # Suppress non-critical warnings (both environments)
     os.environ.setdefault(
         "PYTHONWARNINGS",
         "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
     )
 
 
+def handle_fast_path_args() -> int | None:
+    """Handle utility flags that don't need accelerate launch.
+
+    Returns:
+        Exit code if handled (0 for success), None if should continue to full launch.
+    """
+    # --list_models: print models and exit immediately
+    if "--list_models" in sys.argv:
+        from wavedl.models import list_models
+
+        print("Available models:")
+        for name in list_models():
+            print(f"  {name}")
+        return 0
+
+    return None  # Continue to full launch
+
+
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
-    """Parse …
+    """Parse launcher-specific arguments, pass remaining to wavedl.train."""
     parser = argparse.ArgumentParser(
-        description="WaveDL …
+        description="WaveDL Training Launcher (works on local machines and HPC clusters)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 Examples:
-    # Basic training
-    wavedl-…
+    # Basic training (auto-detects GPUs and environment)
+    wavedl-train --model cnn --data_path train.npz --output_dir results
 
-    # Specify GPU count
-    wavedl-…
+    # Specify GPU count explicitly
+    wavedl-train --model cnn --data_path train.npz --num_gpus 4
 
     # Full configuration
-    wavedl-…
-    …
+    wavedl-train --model resnet18 --data_path train.npz --batch_size 256 \\
+        --lr 1e-3 --epochs 100 --compile --output_dir ./results
+
+    # List available models
+    wavedl-train --list_models
 
-Environment …
-…
-…
+Environment Detection:
+    The launcher automatically detects your environment:
+    - HPC (SLURM, PBS, etc.): Uses local caching, offline WandB
+    - Local machine: Uses standard cache locations (~/.cache)
+
+    For full training options, see: python -m wavedl.train --help
 """,
     )
 
@@ -204,7 +273,7 @@ def print_summary(
     print("Common issues:")
     print(" - Missing data file (check --data_path)")
     print(" - Insufficient GPU memory (reduce --batch_size)")
-    print(" - Invalid model name (run: …
+    print(" - Invalid model name (run: wavedl-train --list_models)")
     print()
 
     print("=" * 40)
@@ -212,12 +281,17 @@ def print_summary(
 
 
 def main() -> int:
-    """Main entry point for wavedl-…
+    """Main entry point for wavedl-train command."""
+    # Fast path for utility flags (avoid accelerate launch overhead)
+    exit_code = handle_fast_path_args()
+    if exit_code is not None:
+        return exit_code
+
     # Parse arguments
     args, train_args = parse_args()
 
-    # Setup …
-    …
+    # Setup environment (smart detection)
+    setup_environment()
 
     # Check if wavedl package is importable
     try:
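A minimal, self-contained sketch (not from the package) restating the detection logic in the hunk above, to show how it behaves; the SLURM_JOB_ID value is illustrative:

    import os

    def is_hpc_environment() -> bool:
        # Same checks as the new launcher.py: scheduler variables, then home writability
        hpc_indicators = ["SLURM_JOB_ID", "PBS_JOBID", "LSB_JOBID", "SGE_TASK_ID", "COBALT_JOBID"]
        if any(var in os.environ for var in hpc_indicators):
            return True
        return not os.access(os.path.expanduser("~"), os.W_OK)

    os.environ["SLURM_JOB_ID"] = "12345"  # simulate running inside a SLURM allocation
    assert is_hpc_environment()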
wavedl/models/__init__.py
CHANGED
@@ -80,8 +80,24 @@ from .vit import ViTBase_, ViTSmall, ViTTiny
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .efficientvit import (
+        EfficientViTB0,
+        EfficientViTB1,
+        EfficientViTB2,
+        EfficientViTB3,
+        EfficientViTL1,
+        EfficientViTL2,
+        EfficientViTM0,
+        EfficientViTM1,
+        EfficientViTM2,
+    )
     from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
     from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+    from .unireplknet import (
+        UniRepLKNetBaseLarge,
+        UniRepLKNetSmall,
+        UniRepLKNetTiny,
+    )
 
     _HAS_TIMM_MODELS = True
 except ImportError:
@@ -148,6 +164,15 @@ if _HAS_TIMM_MODELS:
         [
             "CaFormerS18",
             "CaFormerS36",
+            "EfficientViTB0",
+            "EfficientViTB1",
+            "EfficientViTB2",
+            "EfficientViTB3",
+            "EfficientViTL1",
+            "EfficientViTL2",
+            "EfficientViTM0",
+            "EfficientViTM1",
+            "EfficientViTM2",
             "FastViTS12",
             "FastViTSA12",
             "FastViTT8",
@@ -156,5 +181,8 @@ if _HAS_TIMM_MODELS:
             "MaxViTSmall",
             "MaxViTTiny",
             "PoolFormerS12",
+            "UniRepLKNetBaseLarge",
+            "UniRepLKNetSmall",
+            "UniRepLKNetTiny",
         ]
     )
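The new EfficientViT and UniRepLKNet variants are exported only when timm is importable. A quick sketch, assuming wavedl 1.6.2 is installed, of how the conditional export surfaces downstream:

    try:
        from wavedl.models import EfficientViTB0, UniRepLKNetTiny  # added in 1.6.2
        print("timm-based models available")
    except ImportError:
        print("timm not installed; the optional models are absent from wavedl.models")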
wavedl/models/{_timm_utils.py → _pretrained_utils.py}
RENAMED

@@ -236,3 +236,131 @@ def adapt_input_channels(
         return new_conv
     else:
         raise NotImplementedError(f"Unsupported layer type: {type(conv_layer)}")
+
+
+def adapt_first_conv_for_single_channel(
+    module: nn.Module,
+    conv_path: str,
+    pretrained: bool = True,
+) -> None:
+    """
+    Adapt the first convolutional layer of a pretrained model for single-channel input.
+
+    This is a convenience function for torchvision-style models where the path
+    to the first conv layer is known. It modifies the model in-place.
+
+    For pretrained models, the RGB weights are averaged to create grayscale weights,
+    which provides a reasonable initialization for single-channel inputs.
+
+    Args:
+        module: The model or submodule containing the conv layer
+        conv_path: Dot-separated path to the conv layer (e.g., "conv1", "features.0.0")
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+
+    Example:
+        >>> # For torchvision ResNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "conv1", pretrained=True
+        ... )
+        >>> # For torchvision ConvNeXt
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.0.0", pretrained=True
+        ... )
+        >>> # For torchvision DenseNet
+        >>> adapt_first_conv_for_single_channel(
+        ...     model.backbone, "features.conv0", pretrained=True
+        ... )
+    """
+    # Navigate to parent and get the conv layer
+    parts = conv_path.split(".")
+    parent = module
+    for part in parts[:-1]:
+        if part.isdigit():
+            parent = parent[int(part)]
+        else:
+            parent = getattr(parent, part)
+
+    # Get the final attribute name and the old conv
+    final_attr = parts[-1]
+    if final_attr.isdigit():
+        old_conv = parent[int(final_attr)]
+    else:
+        old_conv = getattr(parent, final_attr)
+
+    # Create and set the new conv
+    new_conv = adapt_input_channels(old_conv, new_in_channels=1, pretrained=pretrained)
+
+    if final_attr.isdigit():
+        parent[int(final_attr)] = new_conv
+    else:
+        setattr(parent, final_attr, new_conv)
+
+
+def find_and_adapt_input_convs(
+    backbone: nn.Module,
+    pretrained: bool = True,
+    adapt_all: bool = False,
+) -> int:
+    """
+    Find and adapt Conv2d layers with 3 input channels for single-channel input.
+
+    This is useful for timm-style models where the exact path to the first
+    conv layer may vary or where multiple layers need adaptation.
+
+    Args:
+        backbone: The backbone model to adapt
+        pretrained: Whether to adapt pretrained weights by averaging RGB channels
+        adapt_all: If True, adapt all Conv2d layers with 3 input channels.
+            If False (default), only adapt the first one found.
+
+    Returns:
+        Number of layers adapted
+
+    Example:
+        >>> # For timm models (adapt first conv only)
+        >>> count = find_and_adapt_input_convs(model.backbone, pretrained=True)
+        >>> # For models with multiple input convs (e.g., FastViT)
+        >>> count = find_and_adapt_input_convs(
+        ...     model.backbone, pretrained=True, adapt_all=True
+        ... )
+    """
+    adapted_count = 0
+
+    for name, module in backbone.named_modules():
+        if not hasattr(module, "in_channels") or module.in_channels != 3:
+            continue
+
+        # Check if this is a wrapper with inner .conv attribute
+        if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
+            old_conv = module.conv
+            module.conv = adapt_input_channels(
+                old_conv, new_in_channels=1, pretrained=pretrained
+            )
+            adapted_count += 1
+
+        elif isinstance(module, nn.Conv2d):
+            # Direct Conv2d - need to replace it in parent
+            parts = name.split(".")
+            parent = backbone
+            for part in parts[:-1]:
+                if part.isdigit():
+                    parent = parent[int(part)]
+                else:
+                    parent = getattr(parent, part)
+
+            child_name = parts[-1]
+            new_conv = adapt_input_channels(
+                module, new_in_channels=1, pretrained=pretrained
+            )
+
+            if child_name.isdigit():
+                parent[int(child_name)] = new_conv
+            else:
+                setattr(parent, child_name, new_conv)
+
+            adapted_count += 1
+
+        if not adapt_all and adapted_count > 0:
+            break
+
+    return adapted_count
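A usage sketch for the two new helpers, assuming torchvision is available (_pretrained_utils is a private module; this mirrors how the in-package callers below use it):

    import torch
    from torchvision.models import resnet18
    from wavedl.models._pretrained_utils import (
        adapt_first_conv_for_single_channel,
        find_and_adapt_input_convs,
    )

    # Known conv path (torchvision-style): adapt ResNet's stem in place
    model = resnet18(weights=None)
    adapt_first_conv_for_single_channel(model, "conv1", pretrained=False)
    out = model(torch.randn(2, 1, 224, 224))  # single-channel input now accepted

    # Path-free variant (timm-style): scan for 3-input-channel convs
    model2 = resnet18(weights=None)
    assert find_and_adapt_input_convs(model2, pretrained=False) == 1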
wavedl/models/base.py
CHANGED
@@ -15,6 +15,54 @@ import torch
 import torch.nn as nn
 
 
+# =============================================================================
+# TYPE ALIASES
+# =============================================================================
+
+# Spatial shape type aliases for model input dimensions
+SpatialShape1D = tuple[int]
+SpatialShape2D = tuple[int, int]
+SpatialShape3D = tuple[int, int, int]
+SpatialShape = SpatialShape1D | SpatialShape2D | SpatialShape3D
+
+
+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+
+def compute_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Compute valid num_groups for GroupNorm that divides num_channels evenly.
+
+    GroupNorm requires num_channels to be divisible by num_groups. This function
+    finds the largest valid divisor up to preferred_groups.
+
+    Args:
+        num_channels: Number of channels to normalize (must be positive)
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that satisfies num_channels % num_groups == 0
+
+    Example:
+        >>> compute_num_groups(64)  # Returns 32
+        >>> compute_num_groups(48)  # Returns 16 (48 % 32 != 0)
+        >>> compute_num_groups(7)  # Returns 1 (prime number)
+    """
+    # Try preferred groups first, then common divisors
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor (always returns at least 1)
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    return 1  # Always valid
+
+
 class BaseModel(nn.Module, ABC):
     """
     Abstract base class for all regression models.
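The new compute_num_groups guarantees a valid divisor, so GroupNorm construction is safe for arbitrary channel counts. A minimal sketch:

    import torch.nn as nn
    from wavedl.models.base import compute_num_groups

    norm48 = nn.GroupNorm(compute_num_groups(48), 48)  # 16 groups, since 48 % 32 != 0
    norm7 = nn.GroupNorm(compute_num_groups(7), 7)     # degrades gracefully to 1 group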
wavedl/models/caformer.py
CHANGED
@@ -33,7 +33,7 @@ Author: Ductho Le (ductho.le@outlook.com)
 import torch
 import torch.nn as nn
 
-from wavedl.models.…
+from wavedl.models._pretrained_utils import build_regression_head
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
wavedl/models/cnn.py
CHANGED
@@ -24,14 +24,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layers(
     dim: int,
 ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )
 
-    @staticmethod
-    def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
-        """
-        Compute valid num_groups for GroupNorm that divides num_channels.
-
-        Finds the largest divisor of num_channels that is <= target_groups,
-        or falls back to 1 if no suitable divisor exists.
-
-        Args:
-            num_channels: Number of channels (must be positive)
-            target_groups: Desired number of groups (default: 4)
-
-        Returns:
-            Valid num_groups that satisfies num_channels % num_groups == 0
-        """
-        # Try target_groups down to 1, return first valid divisor
-        for g in range(min(target_groups, num_channels), 0, -1):
-            if num_channels % g == 0:
-                return g
-        return 1  # Fallback (always valid)
-
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
         Returns:
             Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
         """
-        num_groups = …
+        num_groups = compute_num_groups(out_channels, preferred_groups=4)
 
         layers = [
            self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
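CNN now delegates to the shared helper with preferred_groups=4. One nuance, if I read both implementations correctly: the removed static method searched every integer from target_groups down, while the shared helper tries only [preferred, 16, 8, 4, 2, 1] before its fallback loop, so a 6-channel block previously got 3 groups and now gets 2. A sketch:

    from wavedl.models.base import compute_num_groups

    assert compute_num_groups(64, preferred_groups=4) == 4
    assert compute_num_groups(6, preferred_groups=4) == 2  # old method would return 3
    assert compute_num_groups(7, preferred_groups=4) == 1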
wavedl/models/convnext.py
CHANGED
@@ -28,14 +28,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layer(dim: int) -> type[nn.Module]:
     """Get dimension-appropriate Conv class."""
     if dim == 1:
@@ -468,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-        …
-        …
-        …
-        …
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=pretrained
         )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features[0][0].weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()
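The removed inline code shows the adaptation trick the helper now centralizes: average the pretrained RGB kernels over the input-channel dimension. An equivalent standalone sketch (the stem shape matches torchvision's ConvNeXt-Tiny):

    import torch
    import torch.nn as nn

    old_conv = nn.Conv2d(3, 96, kernel_size=4, stride=4)
    new_conv = nn.Conv2d(1, 96, kernel_size=4, stride=4, bias=old_conv.bias is not None)
    with torch.no_grad():
        # Collapse RGB filters to grayscale: (96, 3, 4, 4) -> (96, 1, 4, 4)
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)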
wavedl/models/convnext_v2.py
CHANGED
@@ -31,20 +31,17 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.…
+from wavedl.models._pretrained_utils import (
     LayerNormNd,
     build_regression_head,
     get_conv_layer,
     get_grn_layer,
     get_pool_layer,
 )
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
 __all__ = [
     "ConvNeXtV2Base",
     "ConvNeXtV2BaseLarge",
@@ -469,24 +466,11 @@ class ConvNeXtV2TinyPretrained(BaseModel):
 
     def _adapt_input_channels(self):
         """Adapt first conv layer for single-channel input."""
-        …
-        new_conv = nn.Conv2d(
-            1,
-            old_conv.out_channels,
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
-        )
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
 
-        …
-        …
-        …
-            new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
-            if old_conv.bias is not None:
-                new_conv.bias.copy_(old_conv.bias)
-
-        self.backbone.features[0][0] = new_conv
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=self.pretrained
+        )
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except classifier."""