wavedl 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +48 -28
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +28 -41
- wavedl/models/base.py +49 -2
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1144 -1116
- wavedl/utils/config.py +88 -2
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/METADATA +136 -98
- wavedl-1.4.1.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/top_level.txt +0 -0
wavedl/utils/config.py
CHANGED
|
@@ -116,7 +116,13 @@ def merge_config_with_args(
|
|
|
116
116
|
"""
|
|
117
117
|
# Get parser defaults to detect which args were explicitly set by user
|
|
118
118
|
if parser is not None:
|
|
119
|
-
|
|
119
|
+
# Safe extraction: iterate actions instead of parse_args([])
|
|
120
|
+
# This avoids failures if required arguments are added later
|
|
121
|
+
defaults = {
|
|
122
|
+
action.dest: action.default
|
|
123
|
+
for action in parser._actions
|
|
124
|
+
if action.dest != "help"
|
|
125
|
+
}
|
|
120
126
|
else:
|
|
121
127
|
# Fallback: reconstruct defaults from known patterns
|
|
122
128
|
# This works because argparse stores actual values, and we compare
|
|
@@ -141,6 +147,9 @@ def merge_config_with_args(
|
|
|
141
147
|
setattr(args, key, value)
|
|
142
148
|
elif not ignore_unknown:
|
|
143
149
|
logging.warning(f"Unknown config key: {key}")
|
|
150
|
+
else:
|
|
151
|
+
# Even in ignore_unknown mode, log for discoverability
|
|
152
|
+
logging.debug(f"Config key '{key}' ignored: not a valid argument")
|
|
144
153
|
|
|
145
154
|
return args
|
|
146
155
|
|
|
@@ -188,12 +197,15 @@ def save_config(
|
|
|
188
197
|
return str(output_path)
|
|
189
198
|
|
|
190
199
|
|
|
191
|
-
def validate_config(
|
|
200
|
+
def validate_config(
|
|
201
|
+
config: dict[str, Any], known_keys: list[str] | None = None
|
|
202
|
+
) -> list[str]:
|
|
192
203
|
"""
|
|
193
204
|
Validate configuration values against known options.
|
|
194
205
|
|
|
195
206
|
Args:
|
|
196
207
|
config: Configuration dictionary
|
|
208
|
+
known_keys: Optional list of valid keys (if None, uses defaults from parser args)
|
|
197
209
|
|
|
198
210
|
Returns:
|
|
199
211
|
List of warning messages (empty if valid)
|
|
@@ -229,9 +241,83 @@ def validate_config(config: dict[str, Any]) -> list[str]:
|
|
|
229
241
|
for key, (min_val, max_val, msg) in numeric_checks.items():
|
|
230
242
|
if key in config:
|
|
231
243
|
val = config[key]
|
|
244
|
+
# Type check: ensure value is numeric before comparison
|
|
245
|
+
if not isinstance(val, (int, float)):
|
|
246
|
+
warnings.append(
|
|
247
|
+
f"Invalid type for '{key}': expected number, got {type(val).__name__} ({val!r})"
|
|
248
|
+
)
|
|
249
|
+
continue
|
|
232
250
|
if not (min_val <= val <= max_val):
|
|
233
251
|
warnings.append(f"{msg}: got {val}")
|
|
234
252
|
|
|
253
|
+
# Check for unknown/unrecognized keys (helps catch typos)
|
|
254
|
+
# Default known keys based on common training arguments
|
|
255
|
+
default_known_keys = {
|
|
256
|
+
# Model
|
|
257
|
+
"model",
|
|
258
|
+
"import_modules",
|
|
259
|
+
# Hyperparameters
|
|
260
|
+
"batch_size",
|
|
261
|
+
"lr",
|
|
262
|
+
"epochs",
|
|
263
|
+
"patience",
|
|
264
|
+
"weight_decay",
|
|
265
|
+
"grad_clip",
|
|
266
|
+
# Loss
|
|
267
|
+
"loss",
|
|
268
|
+
"huber_delta",
|
|
269
|
+
"loss_weights",
|
|
270
|
+
# Optimizer
|
|
271
|
+
"optimizer",
|
|
272
|
+
"momentum",
|
|
273
|
+
"nesterov",
|
|
274
|
+
"betas",
|
|
275
|
+
# Scheduler
|
|
276
|
+
"scheduler",
|
|
277
|
+
"scheduler_patience",
|
|
278
|
+
"min_lr",
|
|
279
|
+
"scheduler_factor",
|
|
280
|
+
"warmup_epochs",
|
|
281
|
+
"step_size",
|
|
282
|
+
"milestones",
|
|
283
|
+
# Data
|
|
284
|
+
"data_path",
|
|
285
|
+
"workers",
|
|
286
|
+
"seed",
|
|
287
|
+
"single_channel",
|
|
288
|
+
# Cross-validation
|
|
289
|
+
"cv",
|
|
290
|
+
"cv_stratify",
|
|
291
|
+
"cv_bins",
|
|
292
|
+
# Checkpointing
|
|
293
|
+
"resume",
|
|
294
|
+
"save_every",
|
|
295
|
+
"output_dir",
|
|
296
|
+
"fresh",
|
|
297
|
+
# Performance
|
|
298
|
+
"compile",
|
|
299
|
+
"precision",
|
|
300
|
+
"mixed_precision",
|
|
301
|
+
# Logging
|
|
302
|
+
"wandb",
|
|
303
|
+
"wandb_watch",
|
|
304
|
+
"project_name",
|
|
305
|
+
"run_name",
|
|
306
|
+
# Config
|
|
307
|
+
"config",
|
|
308
|
+
"list_models",
|
|
309
|
+
# Metadata (internal)
|
|
310
|
+
"_metadata",
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
check_keys = set(known_keys) if known_keys else default_known_keys
|
|
314
|
+
|
|
315
|
+
for key in config:
|
|
316
|
+
if key not in check_keys:
|
|
317
|
+
warnings.append(
|
|
318
|
+
f"Unknown config key: '{key}' - check for typos or see wavedl-train --help"
|
|
319
|
+
)
|
|
320
|
+
|
|
235
321
|
return warnings
|
|
236
322
|
|
|
237
323
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.3.1
|
|
3
|
+
Version: 1.4.1
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -30,31 +30,18 @@ Requires-Dist: scikit-learn>=1.2.0
|
|
|
30
30
|
Requires-Dist: pandas>=2.0.0
|
|
31
31
|
Requires-Dist: matplotlib>=3.7.0
|
|
32
32
|
Requires-Dist: tqdm>=4.65.0
|
|
33
|
-
Requires-Dist: wandb>=0.15.0
|
|
34
33
|
Requires-Dist: pyyaml>=6.0.0
|
|
35
34
|
Requires-Dist: h5py>=3.8.0
|
|
36
35
|
Requires-Dist: safetensors>=0.3.0
|
|
37
|
-
|
|
38
|
-
Requires-Dist:
|
|
39
|
-
Requires-Dist:
|
|
40
|
-
Requires-Dist:
|
|
41
|
-
Requires-Dist:
|
|
42
|
-
|
|
43
|
-
Requires-Dist:
|
|
44
|
-
Requires-Dist:
|
|
45
|
-
|
|
46
|
-
Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
|
|
47
|
-
Provides-Extra: hpo
|
|
48
|
-
Requires-Dist: optuna>=3.0.0; extra == "hpo"
|
|
49
|
-
Provides-Extra: all
|
|
50
|
-
Requires-Dist: pytest>=7.0.0; extra == "all"
|
|
51
|
-
Requires-Dist: pytest-xdist>=3.5.0; extra == "all"
|
|
52
|
-
Requires-Dist: ruff>=0.8.0; extra == "all"
|
|
53
|
-
Requires-Dist: pre-commit>=3.5.0; extra == "all"
|
|
54
|
-
Requires-Dist: onnx>=1.14.0; extra == "all"
|
|
55
|
-
Requires-Dist: onnxruntime>=1.15.0; extra == "all"
|
|
56
|
-
Requires-Dist: triton; sys_platform == "linux" and extra == "all"
|
|
57
|
-
Requires-Dist: optuna>=3.0.0; extra == "all"
|
|
36
|
+
Requires-Dist: wandb>=0.15.0
|
|
37
|
+
Requires-Dist: optuna>=3.0.0
|
|
38
|
+
Requires-Dist: onnx>=1.14.0
|
|
39
|
+
Requires-Dist: onnxruntime>=1.15.0
|
|
40
|
+
Requires-Dist: pytest>=7.0.0
|
|
41
|
+
Requires-Dist: pytest-xdist>=3.5.0
|
|
42
|
+
Requires-Dist: ruff>=0.8.0
|
|
43
|
+
Requires-Dist: pre-commit>=3.5.0
|
|
44
|
+
Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
58
45
|
|
|
59
46
|
<div align="center">
|
|
60
47
|
|
|
@@ -210,20 +197,20 @@ Deploy models anywhere:
|
|
|
210
197
|
|
|
211
198
|
### Installation
|
|
212
199
|
|
|
200
|
+
#### From PyPI (recommended for all users)
|
|
201
|
+
|
|
213
202
|
```bash
|
|
214
|
-
# Install from PyPI (recommended)
|
|
215
203
|
pip install wavedl
|
|
216
|
-
|
|
217
|
-
# Or install with all extras (ONNX export, HPO, dev tools)
|
|
218
|
-
pip install wavedl[all]
|
|
219
204
|
```
|
|
220
205
|
|
|
206
|
+
This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
|
|
207
|
+
|
|
221
208
|
#### From Source (for development)
|
|
222
209
|
|
|
223
210
|
```bash
|
|
224
211
|
git clone https://github.com/ductho-le/WaveDL.git
|
|
225
212
|
cd WaveDL
|
|
226
|
-
pip install -e
|
|
213
|
+
pip install -e .
|
|
227
214
|
```
|
|
228
215
|
|
|
229
216
|
> [!NOTE]
|
|
@@ -359,41 +346,47 @@ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU,
|
|
|
359
346
|
```
|
|
360
347
|
WaveDL/
|
|
361
348
|
├── src/
|
|
362
|
-
│ └── wavedl/
|
|
363
|
-
│ ├── __init__.py
|
|
364
|
-
│ ├── train.py
|
|
365
|
-
│ ├── test.py
|
|
366
|
-
│ ├── hpo.py
|
|
367
|
-
│ ├── hpc.py
|
|
349
|
+
│ └── wavedl/ # Main package (namespaced)
|
|
350
|
+
│ ├── __init__.py # Package init with __version__
|
|
351
|
+
│ ├── train.py # Training entry point
|
|
352
|
+
│ ├── test.py # Testing & inference script
|
|
353
|
+
│ ├── hpo.py # Hyperparameter optimization
|
|
354
|
+
│ ├── hpc.py # HPC distributed training launcher
|
|
368
355
|
│ │
|
|
369
|
-
│ ├── models/
|
|
370
|
-
│ │ ├── registry.py
|
|
371
|
-
│ │ ├── base.py
|
|
372
|
-
│ │ ├── cnn.py
|
|
373
|
-
│ │ ├── resnet.py
|
|
374
|
-
│ │ ├──
|
|
375
|
-
│ │ ├──
|
|
376
|
-
│ │ ├──
|
|
377
|
-
│ │ ├──
|
|
378
|
-
│ │
|
|
356
|
+
│ ├── models/ # Model architectures (38 variants)
|
|
357
|
+
│ │ ├── registry.py # Model factory (@register_model)
|
|
358
|
+
│ │ ├── base.py # Abstract base class
|
|
359
|
+
│ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
|
|
360
|
+
│ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
|
|
361
|
+
│ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
|
|
362
|
+
│ │ ├── tcn.py # TCN (1D only)
|
|
363
|
+
│ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
|
|
364
|
+
│ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
|
|
365
|
+
│ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
|
|
366
|
+
│ │ ├── regnet.py # RegNetY variants (2D)
|
|
367
|
+
│ │ ├── swin.py # Swin Transformer (2D)
|
|
368
|
+
│ │ ├── vit.py # Vision Transformer (1D/2D)
|
|
369
|
+
│ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
|
|
370
|
+
│ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
|
|
371
|
+
│ │ └── unet.py # U-Net Regression
|
|
379
372
|
│ │
|
|
380
|
-
│ └── utils/
|
|
381
|
-
│ ├── data.py
|
|
382
|
-
│ ├── metrics.py
|
|
383
|
-
│ ├── distributed.py
|
|
384
|
-
│ ├── losses.py
|
|
385
|
-
│ ├── optimizers.py
|
|
386
|
-
│ ├── schedulers.py
|
|
387
|
-
│ └── config.py
|
|
373
|
+
│ └── utils/ # Utilities
|
|
374
|
+
│ ├── data.py # Memory-mapped data pipeline
|
|
375
|
+
│ ├── metrics.py # R², Pearson, visualization
|
|
376
|
+
│ ├── distributed.py # DDP synchronization
|
|
377
|
+
│ ├── losses.py # Loss function factory
|
|
378
|
+
│ ├── optimizers.py # Optimizer factory
|
|
379
|
+
│ ├── schedulers.py # LR scheduler factory
|
|
380
|
+
│ └── config.py # YAML configuration support
|
|
388
381
|
│
|
|
389
|
-
├── configs/
|
|
390
|
-
├── examples/
|
|
391
|
-
├── notebooks/
|
|
392
|
-
├── unit_tests/
|
|
382
|
+
├── configs/ # YAML config templates
|
|
383
|
+
├── examples/ # Ready-to-run examples
|
|
384
|
+
├── notebooks/ # Jupyter notebooks
|
|
385
|
+
├── unit_tests/ # Pytest test suite (704 tests)
|
|
393
386
|
│
|
|
394
|
-
├── pyproject.toml
|
|
395
|
-
├── CHANGELOG.md
|
|
396
|
-
└── CITATION.cff
|
|
387
|
+
├── pyproject.toml # Package config, dependencies
|
|
388
|
+
├── CHANGELOG.md # Version history
|
|
389
|
+
└── CITATION.cff # Citation metadata
|
|
397
390
|
```
|
|
398
391
|
---
|
|
399
392
|
|
|
@@ -412,33 +405,63 @@ WaveDL/
|
|
|
412
405
|
> ```
|
|
413
406
|
|
|
414
407
|
<details>
|
|
415
|
-
<summary><b>Available Models</b> —
|
|
416
|
-
|
|
417
|
-
| Model |
|
|
418
|
-
|
|
419
|
-
|
|
|
420
|
-
| `
|
|
421
|
-
|
|
|
422
|
-
| `
|
|
423
|
-
| `
|
|
424
|
-
| `
|
|
425
|
-
| `
|
|
426
|
-
| `
|
|
427
|
-
|
|
|
428
|
-
| `
|
|
429
|
-
| `
|
|
430
|
-
|
|
|
431
|
-
| `
|
|
432
|
-
| `
|
|
433
|
-
| `
|
|
434
|
-
|
|
|
435
|
-
| `
|
|
436
|
-
| `
|
|
437
|
-
| `
|
|
438
|
-
|
|
|
439
|
-
| `
|
|
440
|
-
|
|
441
|
-
|
|
408
|
+
<summary><b>Available Models</b> — 38 architectures</summary>
|
|
409
|
+
|
|
410
|
+
| Model | Params | Dim |
|
|
411
|
+
|-------|--------|-----|
|
|
412
|
+
| **CNN** — Convolutional Neural Network |||
|
|
413
|
+
| `cnn` | 1.7M | 1D/2D/3D |
|
|
414
|
+
| **ResNet** — Residual Network |||
|
|
415
|
+
| `resnet18` | 11.4M | 1D/2D/3D |
|
|
416
|
+
| `resnet34` | 21.5M | 1D/2D/3D |
|
|
417
|
+
| `resnet50` | 24.6M | 1D/2D/3D |
|
|
418
|
+
| `resnet18_pretrained` ⭐ | 11.4M | 2D |
|
|
419
|
+
| `resnet50_pretrained` ⭐ | 24.6M | 2D |
|
|
420
|
+
| **ResNet3D** — 3D Residual Network |||
|
|
421
|
+
| `resnet3d_18` | 33.6M | 3D |
|
|
422
|
+
| `mc3_18` — Mixed Convolution 3D | 11.9M | 3D |
|
|
423
|
+
| **TCN** — Temporal Convolutional Network |||
|
|
424
|
+
| `tcn_small` | 1.0M | 1D |
|
|
425
|
+
| `tcn` | 7.0M | 1D |
|
|
426
|
+
| `tcn_large` | 10.2M | 1D |
|
|
427
|
+
| **EfficientNet** — Efficient Neural Network |||
|
|
428
|
+
| `efficientnet_b0` ⭐ | 4.7M | 2D |
|
|
429
|
+
| `efficientnet_b1` ⭐ | 7.2M | 2D |
|
|
430
|
+
| `efficientnet_b2` ⭐ | 8.4M | 2D |
|
|
431
|
+
| **EfficientNetV2** — Efficient Neural Network V2 |||
|
|
432
|
+
| `efficientnet_v2_s` ⭐ | 21.0M | 2D |
|
|
433
|
+
| `efficientnet_v2_m` ⭐ | 53.6M | 2D |
|
|
434
|
+
| `efficientnet_v2_l` ⭐ | 118.0M | 2D |
|
|
435
|
+
| **MobileNetV3** — Mobile Neural Network V3 |||
|
|
436
|
+
| `mobilenet_v3_small` ⭐ | 1.1M | 2D |
|
|
437
|
+
| `mobilenet_v3_large` ⭐ | 3.2M | 2D |
|
|
438
|
+
| **RegNet** — Regularized Network |||
|
|
439
|
+
| `regnet_y_400mf` ⭐ | 4.0M | 2D |
|
|
440
|
+
| `regnet_y_800mf` ⭐ | 5.8M | 2D |
|
|
441
|
+
| `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
|
|
442
|
+
| `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
|
|
443
|
+
| `regnet_y_8gf` ⭐ | 37.9M | 2D |
|
|
444
|
+
| **Swin** — Shifted Window Transformer |||
|
|
445
|
+
| `swin_t` ⭐ | 28.0M | 2D |
|
|
446
|
+
| `swin_s` ⭐ | 49.4M | 2D |
|
|
447
|
+
| `swin_b` ⭐ | 87.4M | 2D |
|
|
448
|
+
| **ConvNeXt** — Convolutional Next |||
|
|
449
|
+
| `convnext_tiny` | 28.2M | 1D/2D/3D |
|
|
450
|
+
| `convnext_small` | 49.8M | 1D/2D/3D |
|
|
451
|
+
| `convnext_base` | 88.1M | 1D/2D/3D |
|
|
452
|
+
| `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
|
|
453
|
+
| **DenseNet** — Densely Connected Network |||
|
|
454
|
+
| `densenet121` | 7.5M | 1D/2D/3D |
|
|
455
|
+
| `densenet169` | 13.3M | 1D/2D/3D |
|
|
456
|
+
| `densenet121_pretrained` ⭐ | 7.5M | 2D |
|
|
457
|
+
| **ViT** — Vision Transformer |||
|
|
458
|
+
| `vit_tiny` | 5.5M | 1D/2D |
|
|
459
|
+
| `vit_small` | 21.6M | 1D/2D |
|
|
460
|
+
| `vit_base` | 85.6M | 1D/2D |
|
|
461
|
+
| **U-Net** — U-shaped Network |||
|
|
462
|
+
| `unet_regression` | 31.1M | 1D/2D/3D |
|
|
463
|
+
|
|
464
|
+
> ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
|
|
442
465
|
|
|
443
466
|
</details>
|
|
444
467
|
|
|
@@ -479,12 +502,32 @@ WaveDL/
|
|
|
479
502
|
|
|
480
503
|
| Argument | Default | Description |
|
|
481
504
|
|----------|---------|-------------|
|
|
482
|
-
| `--compile` | `False` | Enable `torch.compile` |
|
|
505
|
+
| `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
|
|
483
506
|
| `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
|
|
507
|
+
| `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
|
|
484
508
|
| `--wandb` | `False` | Enable W&B logging |
|
|
509
|
+
| `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
|
|
485
510
|
| `--project_name` | `DL-Training` | W&B project name |
|
|
486
511
|
| `--run_name` | `None` | W&B run name (auto-generated if not set) |
|
|
487
512
|
|
|
513
|
+
**Automatic GPU Optimizations:**
|
|
514
|
+
|
|
515
|
+
WaveDL automatically enables performance optimizations for modern GPUs:
|
|
516
|
+
|
|
517
|
+
| Optimization | Effect | GPU Support |
|
|
518
|
+
|--------------|--------|-------------|
|
|
519
|
+
| **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
|
|
520
|
+
| **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
|
|
521
|
+
| **Worker scaling** | Up to 16 workers per GPU | All systems |
|
|
522
|
+
|
|
523
|
+
> [!NOTE]
|
|
524
|
+
> These optimizations are **backward compatible** — they have no effect on older GPUs (V100, T4, GTX) or CPU-only systems. No configuration needed.
|
|
525
|
+
|
|
526
|
+
**HPC Best Practices:**
|
|
527
|
+
- Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
|
|
528
|
+
- Use `--compile` for training runs > 50 epochs
|
|
529
|
+
- Increase `--workers` manually if auto-detection is suboptimal
|
|
530
|
+
|
|
488
531
|
</details>
|
|
489
532
|
|
|
490
533
|
<details>
|
|
@@ -653,12 +696,7 @@ seed: 2025
|
|
|
653
696
|
|
|
654
697
|
Automatically find the best training configuration using [Optuna](https://optuna.org/).
|
|
655
698
|
|
|
656
|
-
**
|
|
657
|
-
```bash
|
|
658
|
-
pip install -e ".[hpo]"
|
|
659
|
-
```
|
|
660
|
-
|
|
661
|
-
**Step 2: Run HPO**
|
|
699
|
+
**Run HPO:**
|
|
662
700
|
|
|
663
701
|
You specify which models to search and how many trials to run:
|
|
664
702
|
```bash
|
|
@@ -672,7 +710,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
|
|
|
672
710
|
python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
|
|
673
711
|
```
|
|
674
712
|
|
|
675
|
-
**
|
|
713
|
+
**Train with best parameters**
|
|
676
714
|
|
|
677
715
|
After HPO completes, it prints the optimal command:
|
|
678
716
|
```bash
|
|
@@ -715,7 +753,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
715
753
|
| `--output` | `hpo_results.json` | Output file |
|
|
716
754
|
|
|
717
755
|
> [!TIP]
|
|
718
|
-
> See [Available Models](#available-models) for all
|
|
756
|
+
> See [Available Models](#available-models) for all 38 architectures you can search.
|
|
719
757
|
|
|
720
758
|
</details>
|
|
721
759
|
|
|
@@ -819,7 +857,7 @@ import numpy as np
|
|
|
819
857
|
X = np.random.randn(1000, 256, 256).astype(np.float32)
|
|
820
858
|
y = np.random.randn(1000, 5).astype(np.float32)
|
|
821
859
|
|
|
822
|
-
np.savez('test_data.npz', input_train=X, output_train=y)
|
|
860
|
+
np.savez('test_data.npz', input_test=X, output_test=y)
|
|
823
861
|
```
|
|
824
862
|
|
|
825
863
|
</details>
|
|
@@ -831,7 +869,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
|
|
|
831
869
|
import numpy as np
|
|
832
870
|
|
|
833
871
|
data = np.load('train_data.npz')
|
|
834
|
-
assert data['input_train'].ndim
|
|
872
|
+
assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...) "
|
|
835
873
|
assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
|
|
836
874
|
assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
|
|
837
875
|
|
|
@@ -986,7 +1024,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
|
|
|
986
1024
|
| Resource | Description |
|
|
987
1025
|
|----------|-------------|
|
|
988
1026
|
| Technical Paper | In-depth framework description *(coming soon)* |
|
|
989
|
-
| [`_template.py`](models/_template.py) | Template for
|
|
1027
|
+
| [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
|
|
990
1028
|
|
|
991
1029
|
---
|
|
992
1030
|
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
wavedl/__init__.py,sha256=2LU5rtHKoYgXBAZ4zGNtFcHjrTtmmYskXnaURHEwkNc,1177
|
|
2
|
+
wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
|
|
3
|
+
wavedl/hpo.py,sha256=aZoa_Oto_anZpIhz-YM6kN8KxQXTolUvDEyg3NXwBrY,11542
|
|
4
|
+
wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
|
|
5
|
+
wavedl/train.py,sha256=e0tX7_j2gkuYpPjZJqGoDV8arAe4bc4YVRMyrg-RcRY,46402
|
|
6
|
+
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
|
+
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
|
+
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
9
|
+
wavedl/models/cnn.py,sha256=rn2Xmup0w_ll6wuAnYclSeIVazoSUrUGPY-9XnhA1gE,8341
|
|
10
|
+
wavedl/models/convnext.py,sha256=5zELY0ztMB6FxJB9uBurloT7JBdxLXezmrNRzLQjrI0,12846
|
|
11
|
+
wavedl/models/densenet.py,sha256=LzNbQOvtcJJ4SVf-XvIlXGNUgVS2SXl-MMPbr8lcYrA,12995
|
|
12
|
+
wavedl/models/efficientnet.py,sha256=0DHBgEGaOucevtmO1KPUTb5bCdJRg-Gzfpu9EuaylGQ,7456
|
|
13
|
+
wavedl/models/efficientnetv2.py,sha256=rP8y1ZAWyNyi0PXGPXg-4HjgzoELZ-CjMFgr8WnSXeg,10244
|
|
14
|
+
wavedl/models/mobilenetv3.py,sha256=h3f6TiNSyHRH9Qidce7dCGTbdEWYfYF5kbU-TFoTg0U,9490
|
|
15
|
+
wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
|
|
16
|
+
wavedl/models/regnet.py,sha256=Yf9gAoDLv0j4uEuoKC822gizHNh59LCbvFCMP11Q1C0,13116
|
|
17
|
+
wavedl/models/resnet.py,sha256=8DNGIrH5pK8pjEE9BSyBqIc_pkFS_qaYggx-stjTF5k,16961
|
|
18
|
+
wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,8746
|
|
19
|
+
wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
|
|
20
|
+
wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
|
|
21
|
+
wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
|
|
22
|
+
wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
|
|
23
|
+
wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
|
|
24
|
+
wavedl/utils/config.py,sha256=fMoucikIQHn85mVhGMa7TnXTuFDcEEPjfXk2EjbkJR0,10591
|
|
25
|
+
wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
|
|
26
|
+
wavedl/utils/data.py,sha256=9LrB9MC6jRZzbRSc9xiGzJWoh8FahwP_68REqBAT3Os,44131
|
|
27
|
+
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
28
|
+
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
29
|
+
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
|
+
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
|
+
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
+
wavedl-1.4.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.1.dist-info/METADATA,sha256=FEafy9hY2su6bB8iS8VNZceLpTs9E7nhVaejsOEHTUM,40245
|
|
34
|
+
wavedl-1.4.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.1.dist-info/RECORD,,
|
wavedl-1.3.1.dist-info/RECORD
DELETED
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=5EO4WDuyQksw2UQnnojmuA6asc7_Ew9qtLCF-dxo_qo,1177
|
|
2
|
-
wavedl/hpc.py,sha256=OaiGo0Q_ylu6tCEZSnMZ9ohk3nWcqbnwNMXrbZgikF0,7325
|
|
3
|
-
wavedl/hpo.py,sha256=aZoa_Oto_anZpIhz-YM6kN8KxQXTolUvDEyg3NXwBrY,11542
|
|
4
|
-
wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
|
|
5
|
-
wavedl/train.py,sha256=dO64C2ktW6on1wbYVdZSPk6w-ZzZbzOGym-2xi-gk_g,43868
|
|
6
|
-
wavedl/models/__init__.py,sha256=AbsFkRNlsiWv4sJ-kLPdwjA2FS_cSp_TB3CV8884uUE,2219
|
|
7
|
-
wavedl/models/_template.py,sha256=O7SfL3Ef7eDXGmcOXPD0c82o_t3K4ybgJwpSEDsZNEg,4837
|
|
8
|
-
wavedl/models/base.py,sha256=cql0wv8i1sMaVttXOSdBBTPfa2s2sLH5LyAsfKJdXX8,5304
|
|
9
|
-
wavedl/models/cnn.py,sha256=2FFQetQaCJqeeku6glXbOQ3KJw5VvSTu9-u9cpygVk8,8356
|
|
10
|
-
wavedl/models/convnext.py,sha256=zh-x5NFcZrcRv3bi55p-VKWHLYe-v1nvPcMp9xPizLk,12747
|
|
11
|
-
wavedl/models/densenet.py,sha256=q9qrgnacMQ1GDGGPks0jx-C3DRjacnTV8BQ-iw6BTFY,12864
|
|
12
|
-
wavedl/models/efficientnet.py,sha256=irxab-yt3z89tMTf1x6odR2IqgpMrMM44Wiu3n6-IEs,7285
|
|
13
|
-
wavedl/models/registry.py,sha256=p5Eof3T6cwHggcEM-xzeBoKMbpuNyRmOJIvqMhzHvJA,2995
|
|
14
|
-
wavedl/models/resnet.py,sha256=sT4S_Rx56dqLN5zEPbBKeJet1dvr49IWhnBSjiVfcQs,16777
|
|
15
|
-
wavedl/models/unet.py,sha256=i3DFpeJmvdzNiBSqi4ecjLbC9RZXXbNJ_ZNMr2c3I6I,10019
|
|
16
|
-
wavedl/models/vit.py,sha256=iyJ8FQ1DOAgBhaVIUGGQEP2L37wZ28JeHKkJ1tmgj9w,10898
|
|
17
|
-
wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
|
|
18
|
-
wavedl/utils/config.py,sha256=E0_m5aQ1OdwEwzZysSwc5v905P4g3SDprObFAeVIj9g,8107
|
|
19
|
-
wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
|
|
20
|
-
wavedl/utils/data.py,sha256=9LrB9MC6jRZzbRSc9xiGzJWoh8FahwP_68REqBAT3Os,44131
|
|
21
|
-
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
22
|
-
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
23
|
-
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
24
|
-
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
25
|
-
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
26
|
-
wavedl-1.3.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
27
|
-
wavedl-1.3.1.dist-info/METADATA,sha256=JrPtQBD_sXt_8lUlqIYzSqe2KBYgQLCeGAaUXy_hmhA,38922
|
|
28
|
-
wavedl-1.3.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
29
|
-
wavedl-1.3.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
30
|
-
wavedl-1.3.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
31
|
-
wavedl-1.3.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|