wavedl-1.3.1-py3-none-any.whl → wavedl-1.4.1-py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +48 -28
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +28 -41
- wavedl/models/base.py +49 -2
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1144 -1116
- wavedl/utils/config.py +88 -2
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/METADATA +136 -98
- wavedl-1.4.1.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
@@ -33,7 +33,7 @@ from pathlib import Path
 def detect_gpus() -> int:
     """Auto-detect available GPUs using nvidia-smi."""
     if shutil.which("nvidia-smi") is None:
-        print("Warning: nvidia-smi not found, defaulting to 1
+        print("Warning: nvidia-smi not found, defaulting to NUM_GPUS=1")
         return 1

     try:
@@ -50,7 +50,7 @@ def detect_gpus() -> int:
     except (subprocess.CalledProcessError, FileNotFoundError):
         pass

-    print("Warning:
+    print("Warning: No GPUs detected, defaulting to NUM_GPUS=1")
     return 1


@@ -61,10 +61,15 @@ def setup_hpc_environment() -> None:
     offline logging configurations.
     """
     # Use SLURM_TMPDIR if available, otherwise system temp
-    tmpdir = os.environ.get(
+    tmpdir = os.environ.get(
+        "SLURM_TMPDIR", os.environ.get("TMPDIR", tempfile.gettempdir())
+    )

     # Configure directories for systems with restricted home directories
     os.environ.setdefault("MPLCONFIGDIR", f"{tmpdir}/matplotlib")
+    os.environ.setdefault(
+        "FONTCONFIG_PATH", os.environ.get("FONTCONFIG_PATH", "/etc/fonts")
+    )
     os.environ.setdefault("XDG_CACHE_HOME", f"{tmpdir}/.cache")

     # Ensure matplotlib config dir exists
@@ -125,6 +130,18 @@ Environment Variables:
         default=0,
         help="Rank of this machine in multi-node setup (default: 0)",
     )
+    parser.add_argument(
+        "--main_process_ip",
+        type=str,
+        default=None,
+        help="IP address of the main process for multi-node training",
+    )
+    parser.add_argument(
+        "--main_process_port",
+        type=int,
+        default=None,
+        help="Port for multi-node communication (default: accelerate auto-selects)",
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -147,11 +164,11 @@ Environment Variables:
 def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
     """Print post-training summary and instructions."""
     print()
-    print("=" *
+    print("=" * 40)

     if exit_code == 0:
         print("✅ Training completed successfully!")
-        print("=" *
+        print("=" * 40)

         if wandb_mode == "offline":
             print()
@@ -162,15 +179,15 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
             print(" This will upload your training logs to wandb.ai")
     else:
         print(f"❌ Training failed with exit code: {exit_code}")
-        print("=" *
+        print("=" * 40)
         print()
         print("Common issues:")
         print(" - Missing data file (check --data_path)")
        print(" - Insufficient GPU memory (reduce --batch_size)")
-        print(" - Invalid model name (run:
+        print(" - Invalid model name (run: python train.py --list_models)")
        print()

-    print("=" *
+    print("=" * 40)
    print()

@@ -182,22 +199,38 @@ def main() -> int:
     # Setup HPC environment
     setup_hpc_environment()

+    # Check if wavedl package is importable
+    try:
+        import wavedl  # noqa: F401
+    except ImportError:
+        print("Error: wavedl package not found. Run: pip install -e .", file=sys.stderr)
+        return 1
+
     # Auto-detect GPUs if not specified
-
+    if args.num_gpus is not None:
+        num_gpus = args.num_gpus
+        print(f"Using NUM_GPUS={num_gpus} (set via command line)")
+    else:
+        num_gpus = detect_gpus()

     # Build accelerate launch command
     cmd = [
-
-        "
-        "accelerate.commands.launch",
+        "accelerate",
+        "launch",
         f"--num_processes={num_gpus}",
         f"--num_machines={args.num_machines}",
         f"--machine_rank={args.machine_rank}",
         f"--mixed_precision={args.mixed_precision}",
         f"--dynamo_backend={args.dynamo_backend}",
-
-
-
+    ]
+
+    # Add multi-node networking args if specified (required for some clusters)
+    if args.main_process_ip:
+        cmd.append(f"--main_process_ip={args.main_process_ip}")
+    if args.main_process_port:
+        cmd.append(f"--main_process_port={args.main_process_port}")
+
+    cmd += ["-m", "wavedl.train"] + train_args

     # Create output directory if specified
     for i, arg in enumerate(train_args):
@@ -208,19 +241,6 @@
             Path(arg.split("=", 1)[1]).mkdir(parents=True, exist_ok=True)
             break

-    # Print launch configuration
-    print()
-    print("=" * 50)
-    print("🚀 WaveDL HPC Training Launcher")
-    print("=" * 50)
-    print(f" GPUs: {num_gpus}")
-    print(f" Machines: {args.num_machines}")
-    print(f" Mixed Precision: {args.mixed_precision}")
-    print(f" Dynamo Backend: {args.dynamo_backend}")
-    print(f" WandB Mode: {os.environ.get('WANDB_MODE', 'offline')}")
-    print("=" * 50)
-    print()
-
     # Launch training
     try:
         result = subprocess.run(cmd, check=False)
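Note: for orientation, here is a sketch of the command this launcher would assemble for rank 0 of a two-node job, mirroring the cmd-building logic above. The GPU count, host IP, port, and the passthrough training flag value are made-up illustration values:

    # Illustrative only: final `cmd` for machine_rank 0 of a 2-node run
    cmd = [
        "accelerate", "launch",
        "--num_processes=4",            # from detect_gpus() or --num_gpus
        "--num_machines=2",
        "--machine_rank=0",
        "--mixed_precision=bf16",
        "--dynamo_backend=no",
        "--main_process_ip=10.0.0.1",   # new in 1.4.1
        "--main_process_port=29500",    # new in 1.4.1
        "-m", "wavedl.train",
        "--data_path=train.npz",        # example passthrough training arg
    ]

Each node runs the same launcher with its own --machine_rank; accelerate performs the rendezvous through the shared IP and port.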
wavedl/models/__init__.py
CHANGED
@@ -5,6 +5,12 @@ Model Registry and Factory Pattern for Deep Learning Architectures
 This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.

+**Dimensionality Coverage**:
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
+- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, DenseNet
+
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY

@@ -31,7 +37,6 @@ Adding New Models:
     ...

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 # Import registry first (no dependencies)
@@ -43,6 +48,8 @@ from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
+from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,
     build_model,
@@ -50,18 +57,22 @@ from .registry import (
     list_models,
     register_model,
 )
+from .regnet import RegNetY1_6GF, RegNetY3_2GF, RegNetY8GF, RegNetY400MF, RegNetY800MF
 from .resnet import ResNet18, ResNet34, ResNet50
-from .
+from .resnet3d import MC3_18, ResNet3D18
+from .swin import SwinBase, SwinSmall, SwinTiny
+from .tcn import TCN, TCNLarge, TCNSmall
+from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny


-# Export public API
+# Export public API (sorted alphabetically per RUF022)
+# See module docstring for dimensionality support details
 __all__ = [
-    # Models
     "CNN",
-
+    "MC3_18",
     "MODEL_REGISTRY",
-
+    "TCN",
     "BaseModel",
     "ConvNeXtBase_",
     "ConvNeXtSmall",
@@ -71,10 +82,25 @@ __all__ = [
     "EfficientNetB0",
     "EfficientNetB1",
     "EfficientNetB2",
+    "EfficientNetV2L",
+    "EfficientNetV2M",
+    "EfficientNetV2S",
+    "MobileNetV3Large",
+    "MobileNetV3Small",
+    "RegNetY1_6GF",
+    "RegNetY3_2GF",
+    "RegNetY8GF",
+    "RegNetY400MF",
+    "RegNetY800MF",
+    "ResNet3D18",
     "ResNet18",
     "ResNet34",
     "ResNet50",
-    "
+    "SwinBase",
+    "SwinSmall",
+    "SwinTiny",
+    "TCNLarge",
+    "TCNSmall",
     "UNetRegression",
     "ViTBase_",
     "ViTSmall",
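Note: a short usage sketch of the expanded registry. The imports match the module docstring above; the registered name string "tcn" and the assumption that MODEL_REGISTRY maps names to classes are inferred, since registry.py's internals are not shown in this diff:

    from wavedl.models import MODEL_REGISTRY, list_models

    print(list_models())                  # now includes the TCN/Swin/RegNet/ResNet3D additions
    ModelCls = MODEL_REGISTRY["tcn"]      # assumed: registry maps name -> class
    model = ModelCls(in_shape=(1024,), out_size=3)  # 1D waveform input, per the coverage list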
wavedl/models/_template.py
CHANGED
@@ -1,24 +1,26 @@
 """
-Model Template for
-
+Model Template for Custom Architectures
+========================================

-Copy this file and modify to add
+Copy this file and modify to add custom model architectures to WaveDL.
 The model will be automatically registered and available via --model flag.

-
-1. Copy this file to
-2. Rename the class and update @register_model("
-3. Implement
-4.
-
-
+Quick Start:
+    1. Copy this file to your project: cp _template.py my_model.py
+    2. Rename the class and update @register_model("my_model")
+    3. Implement your architecture in __init__ and forward
+    4. Train: wavedl-train --import my_model --model my_model --data_path data.npz
+
+Requirements (your model MUST):
+    1. Inherit from BaseModel
+    2. Accept (in_shape, out_size, **kwargs) in __init__
+    3. Return tensor of shape (batch, out_size) from forward()
+
+See README.md "Adding Custom Models" section for more details.

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

-from typing import Any
-
 import torch
 import torch.nn as nn

@@ -26,7 +28,7 @@ from wavedl.models.base import BaseModel


 # Uncomment the decorator to register this model
-# @register_model("
+# @register_model("my_model")
 class TemplateModel(BaseModel):
     """
     Template Model Architecture.
@@ -35,14 +37,16 @@ class TemplateModel(BaseModel):
     The first line will appear in --list_models output.

     Args:
-        in_shape: Input spatial dimensions (
-
+        in_shape: Input spatial dimensions (auto-detected from data)
+            - 1D: (L,) for signals
+            - 2D: (H, W) for images
+            - 3D: (D, H, W) for volumes
+        out_size: Number of regression targets (auto-detected from data)
         hidden_dim: Size of hidden layers (default: 256)
-        num_layers: Number of convolutional layers (default: 4)
         dropout: Dropout rate (default: 0.1)

     Input Shape:
-        (B, 1,
+        (B, 1, *in_shape) - e.g., (B, 1, 64, 64) for 2D

     Output Shape:
         (B, out_size) - Regression predictions
@@ -50,10 +54,9 @@ class TemplateModel(BaseModel):

     def __init__(
         self,
-        in_shape: tuple
+        in_shape: tuple,
         out_size: int,
         hidden_dim: int = 256,
-        num_layers: int = 4,
         dropout: float = 0.1,
         **kwargs,  # Accept extra kwargs for flexibility
     ):
@@ -62,14 +65,13 @@ class TemplateModel(BaseModel):

         # Store hyperparameters as attributes (optional but recommended)
         self.hidden_dim = hidden_dim
-        self.num_layers = num_layers
         self.dropout_rate = dropout

         # =================================================================
         # BUILD YOUR ARCHITECTURE HERE
         # =================================================================

-        # Example: Simple CNN encoder
+        # Example: Simple CNN encoder (assumes 2D input with 1 channel)
         self.encoder = nn.Sequential(
             # Layer 1
             nn.Conv2d(1, 32, kernel_size=3, padding=1),
@@ -107,10 +109,10 @@ class TemplateModel(BaseModel):
         """
         Forward pass of the model.

-        REQUIRED: Must accept (B, C,
+        REQUIRED: Must accept (B, C, *spatial) and return (B, out_size)

         Args:
-            x: Input tensor of shape (B, 1,
+            x: Input tensor of shape (B, 1, *in_shape)

         Returns:
             Output tensor of shape (B, out_size)
@@ -123,35 +125,20 @@ class TemplateModel(BaseModel):

         return output

-    @classmethod
-    def get_default_config(cls) -> dict[str, Any]:
-        """
-        Return default hyperparameters for this model.
-
-        OPTIONAL: Override to provide model-specific defaults.
-        These can be used for documentation or config files.
-        """
-        return {
-            "hidden_dim": 256,
-            "num_layers": 4,
-            "dropout": 0.1,
-        }
-

 # =============================================================================
 # USAGE EXAMPLE
 # =============================================================================
 if __name__ == "__main__":
     # Quick test of the model
-    model = TemplateModel(in_shape=(
+    model = TemplateModel(in_shape=(64, 64), out_size=5)

     # Print model summary
     print(f"Model: {model.__class__.__name__}")
     print(f"Parameters: {model.count_parameters():,}")
-    print(f"Default config: {model.get_default_config()}")

     # Test forward pass
-    dummy_input = torch.randn(2, 1,
+    dummy_input = torch.randn(2, 1, 64, 64)
     output = model(dummy_input)
     print(f"Input shape: {dummy_input.shape}")
     print(f"Output shape: {output.shape}")
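Note: following the template's three requirements, a minimal custom model could look like the sketch below. The registry name "tiny_mlp" is hypothetical, and the super().__init__(in_shape, out_size) call assumes BaseModel stores these attributes (as validate_input_shape in base.py implies):

    import torch.nn as nn

    from wavedl.models.base import BaseModel
    from wavedl.models.registry import register_model


    @register_model("tiny_mlp")  # hypothetical registry name
    class TinyMLP(BaseModel):
        """Tiny MLP baseline: flatten the input, then two linear layers."""

        def __init__(self, in_shape: tuple, out_size: int, hidden_dim: int = 128, **kwargs):
            super().__init__(in_shape, out_size)  # assumed BaseModel signature
            flat = 1
            for d in in_shape:
                flat *= d  # product of spatial dims (C=1 assumed, per the template)
            self.net = nn.Sequential(
                nn.Flatten(),
                nn.Linear(flat, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_size),
            )

        def forward(self, x):
            return self.net(x)  # (B, out_size), as the contract requires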
wavedl/models/base.py
CHANGED
@@ -6,7 +6,6 @@ Defines the interface contract that all models must implement for compatibility
 with the training pipeline. Provides common utilities and enforces consistency.

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from abc import ABC, abstractmethod
@@ -76,13 +75,61 @@ class BaseModel(nn.Module, ABC):
         Forward pass of the model.

         Args:
-            x: Input tensor of shape (B, C,
+            x: Input tensor of shape (B, C, *spatial_dims)
+                - 1D: (B, C, L)
+                - 2D: (B, C, H, W)
+                - 3D: (B, C, D, H, W)

         Returns:
             Output tensor of shape (B, out_size)
         """
         pass

+    def validate_input_shape(self, x: torch.Tensor) -> None:
+        """
+        Validate input tensor shape against model's expected shape.
+
+        Call this at the start of forward() for explicit shape contract enforcement.
+        Provides clear, actionable error messages instead of cryptic Conv layer errors.
+
+        Args:
+            x: Input tensor to validate
+
+        Raises:
+            ValueError: If shape doesn't match expected dimensions
+
+        Example:
+            def forward(self, x):
+                self.validate_input_shape(x)  # Optional but recommended
+                return self.model(x)
+        """
+        expected_ndim = len(self.in_shape) + 2  # +2 for (batch, channel)
+
+        if x.ndim != expected_ndim:
+            dim_names = {
+                3: "1D (B, C, L)",
+                4: "2D (B, C, H, W)",
+                5: "3D (B, C, D, H, W)",
+            }
+            expected_name = dim_names.get(expected_ndim, f"{expected_ndim}D")
+            actual_name = dim_names.get(x.ndim, f"{x.ndim}D")
+            raise ValueError(
+                f"Input shape mismatch: model expects {expected_name} input, "
+                f"got {actual_name} with shape {tuple(x.shape)}.\n"
+                f"Expected in_shape: {self.in_shape} -> input should be (B, C, {', '.join(map(str, self.in_shape))})\n"
+                f"Hint: Check your data preprocessing - you may need to add/remove dimensions."
+            )
+
+        # Validate spatial dimensions match
+        spatial_dims = tuple(x.shape[2:])  # Skip batch and channel
+        if spatial_dims != tuple(self.in_shape):
+            raise ValueError(
+                f"Spatial dimension mismatch: model expects {self.in_shape}, "
+                f"got {spatial_dims}.\n"
+                f"Full input shape: {tuple(x.shape)} (B={x.shape[0]}, C={x.shape[1]})\n"
+                f"Hint: Ensure your data dimensions match the model's in_shape."
+            )
+
     def count_parameters(self, trainable_only: bool = True) -> int:
         """
         Count the number of parameters in the model.
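Note: to illustrate the new helper, a model built for 2D input that receives a 5D volume batch now fails fast with a readable message. The constructor call below follows the (in_shape, out_size) contract from the template; exact model kwargs may differ:

    import torch

    from wavedl.models import ResNet18

    model = ResNet18(in_shape=(64, 64), out_size=3)  # expects (B, C, 64, 64)
    bad = torch.randn(8, 1, 16, 64, 64)              # 5D batch: (B, C, D, H, W)
    model.validate_input_shape(bad)
    # ValueError: Input shape mismatch: model expects 2D (B, C, H, W) input,
    # got 3D (B, C, D, H, W) with shape (8, 1, 16, 64, 64). ...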
wavedl/models/cnn.py
CHANGED
wavedl/models/convnext.py
CHANGED
@@ -15,8 +15,11 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
 - convnext_small: Medium (~50M params for 2D)
 - convnext_base: Standard (~89M params for 2D)

+References:
+    Liu, Z., et al. (2022). A ConvNet for the 2020s.
+    CVPR 2022. https://arxiv.org/abs/2201.03545
+
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
wavedl/models/densenet.py
CHANGED
@@ -14,8 +14,11 @@ Features: feature reuse, efficient gradient flow, compact model.
 - densenet121: Standard (121 layers, ~8M params for 2D)
 - densenet169: Deeper (169 layers, ~14M params for 2D)

+References:
+    Huang, G., et al. (2017). Densely Connected Convolutional Networks.
+    CVPR 2017 (Best Paper). https://arxiv.org/abs/1608.06993
+
 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
wavedl/models/efficientnet.py
CHANGED
@@ -6,14 +6,18 @@ Wrapper around torchvision's EfficientNet with a regression head.
 Provides optional ImageNet pretrained weights for transfer learning.

 **Variants**:
-- efficientnet_b0: Smallest, fastest (
-- efficientnet_b1: Light (7.
-- efficientnet_b2: Balanced (
+- efficientnet_b0: Smallest, fastest (~4.7M params)
+- efficientnet_b1: Light (~7.2M params)
+- efficientnet_b2: Balanced (~8.4M params)

-**Note**: EfficientNet is 2D-only. For 1D
+**Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
+
+References:
+    Tan, M., & Le, Q. (2019). EfficientNet: Rethinking Model Scaling
+    for Convolutional Neural Networks. ICML 2019.
+    https://arxiv.org/abs/1905.11946

 Author: Ductho Le (ductho.le@outlook.com)
-Version: 1.0.0
 """

 from typing import Any
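Note: a hedged construction sketch for the variants above; the (in_shape, out_size) constructor contract is assumed from the model template, and the pretrained-weights kwarg is omitted because its name is not shown in this diff:

    import torch

    from wavedl.models import EfficientNetB0

    model = EfficientNetB0(in_shape=(224, 224), out_size=4)  # 2D-only, per the note above
    out = model(torch.randn(2, 1, 224, 224))                 # -> shape (2, 4)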
|