wavedl-1.3.1-py3-none-any.whl → wavedl-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +28 -26
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +0 -1
- wavedl/models/base.py +0 -1
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1113 -1116
- {wavedl-1.3.1.dist-info → wavedl-1.4.0.dist-info}/METADATA +111 -93
- wavedl-1.4.0.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.0.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.0.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.0.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
@@ -33,7 +33,7 @@ from pathlib import Path
 def detect_gpus() -> int:
     """Auto-detect available GPUs using nvidia-smi."""
     if shutil.which("nvidia-smi") is None:
-        print("Warning: nvidia-smi not found, defaulting to 1
+        print("Warning: nvidia-smi not found, defaulting to NUM_GPUS=1")
         return 1

     try:
@@ -50,7 +50,7 @@ def detect_gpus() -> int:
     except (subprocess.CalledProcessError, FileNotFoundError):
         pass

-    print("Warning:
+    print("Warning: No GPUs detected, defaulting to NUM_GPUS=1")
     return 1


@@ -61,10 +61,15 @@ def setup_hpc_environment() -> None:
     offline logging configurations.
     """
     # Use SLURM_TMPDIR if available, otherwise system temp
-    tmpdir = os.environ.get(
+    tmpdir = os.environ.get(
+        "SLURM_TMPDIR", os.environ.get("TMPDIR", tempfile.gettempdir())
+    )

     # Configure directories for systems with restricted home directories
     os.environ.setdefault("MPLCONFIGDIR", f"{tmpdir}/matplotlib")
+    os.environ.setdefault(
+        "FONTCONFIG_PATH", os.environ.get("FONTCONFIG_PATH", "/etc/fonts")
+    )
     os.environ.setdefault("XDG_CACHE_HOME", f"{tmpdir}/.cache")

     # Ensure matplotlib config dir exists
@@ -147,11 +152,11 @@ Environment Variables:
 def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
     """Print post-training summary and instructions."""
     print()
-    print("=" *
+    print("=" * 40)

     if exit_code == 0:
         print("✅ Training completed successfully!")
-        print("=" *
+        print("=" * 40)

         if wandb_mode == "offline":
             print()
@@ -162,15 +167,15 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
         print(" This will upload your training logs to wandb.ai")
     else:
         print(f"❌ Training failed with exit code: {exit_code}")
-        print("=" *
+        print("=" * 40)
         print()
         print("Common issues:")
         print(" - Missing data file (check --data_path)")
         print(" - Insufficient GPU memory (reduce --batch_size)")
-        print(" - Invalid model name (run:
+        print(" - Invalid model name (run: python train.py --list_models)")
     print()

-    print("=" *
+    print("=" * 40)
     print()


@@ -182,17 +187,27 @@ def main() -> int:
     # Setup HPC environment
     setup_hpc_environment()

+    # Check if wavedl package is importable
+    try:
+        import wavedl  # noqa: F401
+    except ImportError:
+        print("Error: wavedl package not found. Run: pip install -e .", file=sys.stderr)
+        return 1
+
     # Auto-detect GPUs if not specified
-
+    if args.num_gpus is not None:
+        num_gpus = args.num_gpus
+        print(f"Using NUM_GPUS={num_gpus} (set via command line)")
+    else:
+        num_gpus = detect_gpus()

     # Build accelerate launch command
     cmd = [
-
-        "
-        "accelerate.commands.launch",
+        "accelerate",
+        "launch",
         f"--num_processes={num_gpus}",
         f"--num_machines={args.num_machines}",
-
+        "--machine_rank=0",
         f"--mixed_precision={args.mixed_precision}",
         f"--dynamo_backend={args.dynamo_backend}",
         "-m",
@@ -208,19 +223,6 @@ def main() -> int:
             Path(arg.split("=", 1)[1]).mkdir(parents=True, exist_ok=True)
             break

-    # Print launch configuration
-    print()
-    print("=" * 50)
-    print("🚀 WaveDL HPC Training Launcher")
-    print("=" * 50)
-    print(f" GPUs: {num_gpus}")
-    print(f" Machines: {args.num_machines}")
-    print(f" Mixed Precision: {args.mixed_precision}")
-    print(f" Dynamo Backend: {args.dynamo_backend}")
-    print(f" WandB Mode: {os.environ.get('WANDB_MODE', 'offline')}")
-    print("=" * 50)
-    print()
-
     # Launch training
     try:
        result = subprocess.run(cmd, check=False)
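
Note on the hunks above: the diff elides the body of detect_gpus()'s try block (old lines 39-49), so the exact nvidia-smi invocation wavedl uses is not visible here. A minimal, self-contained sketch of the counting pattern such a function typically wraps, with the query flag being an assumption rather than a quote from this package:

import shutil
import subprocess


def count_gpus_via_nvidia_smi() -> int:
    # Mirrors the fallback structure visible above: no nvidia-smi on PATH,
    # or a failed/empty query, falls back to a single GPU.
    if shutil.which("nvidia-smi") is None:
        return 1
    try:
        out = subprocess.run(
            ["nvidia-smi", "--list-gpus"],  # assumed flag; prints one line per GPU
            capture_output=True,
            text=True,
            check=True,
        )
        gpus = [line for line in out.stdout.splitlines() if line.strip()]
        if gpus:
            return len(gpus)
    except (subprocess.CalledProcessError, FileNotFoundError):
        pass
    return 1

Separately, the truncated removed lines in the launch-command hunk suggest the old code invoked the launcher as a module (accelerate.commands.launch); the new code calls the accelerate launch console script, which resolves to the same Hugging Face Accelerate entry point.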
wavedl/models/__init__.py
CHANGED
@@ -5,6 +5,12 @@ Model Registry and Factory Pattern for Deep Learning Architectures
 This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.

+**Dimensionality Coverage**:
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
+- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, DenseNet
+
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY

@@ -31,7 +37,6 @@ Adding New Models:
     ...

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 # Import registry first (no dependencies)
@@ -43,6 +48,8 @@ from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
+from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,
     build_model,
@@ -50,18 +57,22 @@ from .registry import (
     list_models,
     register_model,
 )
+from .regnet import RegNetY1_6GF, RegNetY3_2GF, RegNetY8GF, RegNetY400MF, RegNetY800MF
 from .resnet import ResNet18, ResNet34, ResNet50
-from .
+from .resnet3d import MC3_18, ResNet3D18
+from .swin import SwinBase, SwinSmall, SwinTiny
+from .tcn import TCN, TCNLarge, TCNSmall
+from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny


-# Export public API
+# Export public API (sorted alphabetically per RUF022)
+# See module docstring for dimensionality support details
 __all__ = [
-    # Models
     "CNN",
-
+    "MC3_18",
     "MODEL_REGISTRY",
-
+    "TCN",
     "BaseModel",
     "ConvNeXtBase_",
     "ConvNeXtSmall",
@@ -71,10 +82,25 @@ __all__ = [
     "EfficientNetB0",
     "EfficientNetB1",
     "EfficientNetB2",
+    "EfficientNetV2L",
+    "EfficientNetV2M",
+    "EfficientNetV2S",
+    "MobileNetV3Large",
+    "MobileNetV3Small",
+    "RegNetY1_6GF",
+    "RegNetY3_2GF",
+    "RegNetY8GF",
+    "RegNetY400MF",
+    "RegNetY800MF",
+    "ResNet3D18",
    "ResNet18",
     "ResNet34",
     "ResNet50",
-    "
+    "SwinBase",
+    "SwinSmall",
+    "SwinTiny",
+    "TCNLarge",
+    "TCNSmall",
     "UNetRegression",
     "ViTBase_",
     "ViTSmall",
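
To make the expanded registry concrete, a usage sketch built from names visible in this diff. The assumption (not shown in this hunk) is that build_model(name, **kwargs) forwards its keyword arguments to the selected model's constructor, as the constructors later in this diff suggest:

import torch

from wavedl.models import build_model, list_models

print(list_models())  # should now include the efficientnet_v2_* names registered below

# "efficientnet_v2_s" is registered in efficientnetv2.py later in this diff.
model = build_model("efficientnet_v2_s", in_shape=(500, 500), out_size=3)
out = model(torch.randn(2, 1, 500, 500))  # -> shape (2, 3)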
wavedl/models/_template.py
CHANGED
wavedl/models/base.py
CHANGED
@@ -6,7 +6,6 @@ Defines the interface contract that all models must implement for compatibility
 with the training pipeline. Provides common utilities and enforces consistency.

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from abc import ABC, abstractmethod
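
The Version-line removal above is cosmetic, but the interface this docstring describes can be anchored: judging from the EfficientNetV2 wrapper later in this diff, subclasses call super().__init__(in_shape, out_size), implement forward(), and may provide a get_default_config() classmethod. A minimal conforming subclass under those inferred (not documented) assumptions:

from typing import Any

import torch
import torch.nn as nn

from wavedl.models.base import BaseModel


class TinyRegressor(BaseModel):
    # Hypothetical example model; the contract below is inferred from
    # efficientnetv2.py in this diff, not from a documented API.
    def __init__(self, in_shape: tuple[int, ...], out_size: int, **kwargs):
        super().__init__(in_shape, out_size)
        self.net = nn.Sequential(nn.Flatten(), nn.LazyLinear(out_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        return {}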
wavedl/models/cnn.py
CHANGED
wavedl/models/convnext.py
CHANGED
@@ -15,8 +15,11 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
 - convnext_small: Medium (~50M params for 2D)
 - convnext_base: Standard (~89M params for 2D)

+References:
+    Liu, Z., et al. (2022). A ConvNet for the 2020s.
+    CVPR 2022. https://arxiv.org/abs/2201.03545
+
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
wavedl/models/densenet.py
CHANGED
@@ -14,8 +14,11 @@ Features: feature reuse, efficient gradient flow, compact model.
 - densenet121: Standard (121 layers, ~8M params for 2D)
 - densenet169: Deeper (169 layers, ~14M params for 2D)

+References:
+    Huang, G., et al. (2017). Densely Connected Convolutional Networks.
+    CVPR 2017 (Best Paper). https://arxiv.org/abs/1608.06993
+
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
wavedl/models/efficientnet.py
CHANGED
@@ -6,14 +6,18 @@ Wrapper around torchvision's EfficientNet with a regression head.
 Provides optional ImageNet pretrained weights for transfer learning.

 **Variants**:
-- efficientnet_b0: Smallest, fastest (
-- efficientnet_b1: Light (7.
-- efficientnet_b2: Balanced (
+- efficientnet_b0: Smallest, fastest (~4.7M params)
+- efficientnet_b1: Light (~7.2M params)
+- efficientnet_b2: Balanced (~8.4M params)

-**Note**: EfficientNet is 2D-only. For 1D
+**Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
+
+References:
+    Tan, M., & Le, Q. (2019). EfficientNet: Rethinking Model Scaling
+    for Convolutional Neural Networks. ICML 2019.
+    https://arxiv.org/abs/1905.11946

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
wavedl/models/efficientnetv2.py
ADDED
@@ -0,0 +1,292 @@
+"""
+EfficientNetV2: Faster Training and Better Accuracy
+====================================================
+
+Next-generation EfficientNet with improved training efficiency and performance.
+EfficientNetV2 replaces early depthwise convolutions with fused MBConv blocks,
+enabling 2-4× faster training while achieving better accuracy.
+
+**Key Improvements over EfficientNet**:
+- Fused-MBConv in early stages (faster on accelerators)
+- Progressive learning support (start small, grow)
+- Better NAS-optimized architecture
+
+**Variants**:
+- efficientnet_v2_s: Small (21.5M params) - Recommended default
+- efficientnet_v2_m: Medium (54.1M params) - Higher accuracy
+- efficientnet_v2_l: Large (118.5M params) - Maximum accuracy
+
+**Note**: EfficientNetV2 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
+
+References:
+    Tan, M., & Le, Q. (2021). EfficientNetV2: Smaller Models and Faster Training.
+    ICML 2021. https://arxiv.org/abs/2104.00298
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+try:
+    from torchvision.models import (
+        EfficientNet_V2_L_Weights,
+        EfficientNet_V2_M_Weights,
+        EfficientNet_V2_S_Weights,
+        efficientnet_v2_l,
+        efficientnet_v2_m,
+        efficientnet_v2_s,
+    )
+
+    EFFICIENTNETV2_AVAILABLE = True
+except ImportError:
+    EFFICIENTNETV2_AVAILABLE = False
+
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+class EfficientNetV2Base(BaseModel):
+    """
+    Base EfficientNetV2 class for regression tasks.
+
+    Wraps torchvision EfficientNetV2 with:
+    - Optional pretrained weights (ImageNet-1K)
+    - Automatic input channel adaptation (grayscale → 3ch)
+    - Custom multi-layer regression head
+
+    Compared to EfficientNet (V1):
+    - 2-4× faster training on GPU/TPU
+    - Better accuracy at similar parameter counts
+    - More efficient at higher resolutions
+
+    Note: This is 2D-only. Input shape must be (H, W).
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_fn,
+        weights_class,
+        pretrained: bool = True,
+        dropout_rate: float = 0.3,
+        freeze_backbone: bool = False,
+        regression_hidden: int = 512,
+        **kwargs,
+    ):
+        """
+        Initialize EfficientNetV2 for regression.
+
+        Args:
+            in_shape: (H, W) input image dimensions
+            out_size: Number of regression output targets
+            model_fn: torchvision model constructor
+            weights_class: Pretrained weights enum class
+            pretrained: Use ImageNet pretrained weights (default: True)
+            dropout_rate: Dropout rate in regression head (default: 0.3)
+            freeze_backbone: Freeze backbone for fine-tuning (default: False)
+            regression_hidden: Hidden units in regression head (default: 512)
+        """
+        super().__init__(in_shape, out_size)
+
+        if not EFFICIENTNETV2_AVAILABLE:
+            raise ImportError(
+                "torchvision >= 0.13 is required for EfficientNetV2. "
+                "Install with: pip install torchvision>=0.13"
+            )
+
+        if len(in_shape) != 2:
+            raise ValueError(
+                f"EfficientNetV2 requires 2D input (H, W), got {len(in_shape)}D. "
+                "For 1D data, use TCN. For 3D data, use ResNet3D."
+            )
+
+        self.pretrained = pretrained
+        self.dropout_rate = dropout_rate
+        self.freeze_backbone = freeze_backbone
+        self.regression_hidden = regression_hidden
+
+        # Load pretrained backbone
+        weights = weights_class.IMAGENET1K_V1 if pretrained else None
+        self.backbone = model_fn(weights=weights)
+
+        # Get classifier input features (before the final classification layer)
+        in_features = self.backbone.classifier[1].in_features
+
+        # Replace classifier with regression head
+        # EfficientNetV2 benefits from a deeper regression head
+        self.backbone.classifier = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(in_features, regression_hidden),
+            nn.SiLU(inplace=True),  # SiLU (Swish) matches EfficientNet's activation
+            nn.Dropout(dropout_rate * 0.5),
+            nn.Linear(regression_hidden, regression_hidden // 2),
+            nn.SiLU(inplace=True),
+            nn.Linear(regression_hidden // 2, out_size),
+        )
+
+        # Optionally freeze backbone for fine-tuning
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """Freeze all backbone parameters except the classifier."""
+        for name, param in self.backbone.named_parameters():
+            if "classifier" not in name:
+                param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Args:
+            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+
+        Returns:
+            Output tensor of shape (B, out_size)
+        """
+        # Expand single channel to 3 channels for pretrained weights compatibility
+        if x.size(1) == 1:
+            x = x.expand(-1, 3, -1, -1)
+
+        return self.backbone(x)
+
+    @classmethod
+    def get_default_config(cls) -> dict[str, Any]:
+        """Return default configuration for EfficientNetV2."""
+        return {
+            "pretrained": True,
+            "dropout_rate": 0.3,
+            "freeze_backbone": False,
+            "regression_hidden": 512,
+        }
+
+
+# =============================================================================
+# REGISTERED MODEL VARIANTS
+# =============================================================================
+
+
+@register_model("efficientnet_v2_s")
+class EfficientNetV2S(EfficientNetV2Base):
+    """
+    EfficientNetV2-S: Small variant, recommended default.
+
+    ~21.5M parameters. Best balance of speed and accuracy for most tasks.
+    2× faster training than EfficientNet-B4 with better accuracy.
+
+    Recommended for:
+    - Default choice for 2D wave data
+    - Moderate compute budgets
+    - When training speed matters
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2S(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_s,
+            weights_class=EfficientNet_V2_S_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_S({pt}, in={self.in_shape}, out={self.out_size})"
+
+
+@register_model("efficientnet_v2_m")
+class EfficientNetV2M(EfficientNetV2Base):
+    """
+    EfficientNetV2-M: Medium variant for higher accuracy.
+
+    ~54.1M parameters. Use when accuracy is more important than speed.
+
+    Recommended for:
+    - Large datasets (>50k samples)
+    - Complex wave patterns
+    - When compute is not a bottleneck
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2M(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_m,
+            weights_class=EfficientNet_V2_M_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_M({pt}, in={self.in_shape}, out={self.out_size})"
+
+
+@register_model("efficientnet_v2_l")
+class EfficientNetV2L(EfficientNetV2Base):
+    """
+    EfficientNetV2-L: Large variant for maximum accuracy.
+
+    ~118.5M parameters. Use only with large datasets and sufficient compute.
+
+    Recommended for:
+    - Very large datasets (>100k samples)
+    - When maximum accuracy is critical
+    - HPC environments with ample GPU memory
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2L(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_l,
+            weights_class=EfficientNet_V2_L_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_L({pt}, in={self.in_shape}, out={self.out_size})"