wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/{hpc.py → launcher.py} +135 -61
  4. wavedl/models/__init__.py +28 -0
  5. wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
  6. wavedl/models/base.py +48 -0
  7. wavedl/models/caformer.py +1 -1
  8. wavedl/models/cnn.py +2 -27
  9. wavedl/models/convnext.py +5 -18
  10. wavedl/models/convnext_v2.py +6 -22
  11. wavedl/models/densenet.py +5 -18
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +6 -39
  15. wavedl/models/mamba.py +44 -24
  16. wavedl/models/maxvit.py +51 -48
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +14 -56
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +1 -5
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +3 -3
  26. wavedl/train.py +1427 -1430
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/losses.py +216 -216
  30. wavedl/utils/optimizers.py +216 -216
  31. wavedl/utils/schedulers.py +251 -251
  32. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
  33. wavedl-1.6.2.dist-info/RECORD +46 -0
  34. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
  35. wavedl-1.6.0.dist-info/RECORD +0 -44
  36. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
  37. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
  38. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.6.0
3
+ Version: 1.6.2
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -38,6 +38,7 @@ Requires-Dist: wandb>=0.15.0
38
38
  Requires-Dist: optuna>=3.0.0
39
39
  Requires-Dist: onnx>=1.14.0
40
40
  Requires-Dist: onnxruntime>=1.15.0
41
+ Requires-Dist: onnxscript>=0.1.0
41
42
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
42
43
  Provides-Extra: dev
43
44
  Requires-Dist: pytest>=7.0.0; extra == "dev"
@@ -118,7 +119,7 @@ Train on datasets larger than RAM:
118
119
 
119
120
  **🧠 Models? We've Got Options**
120
121
 
121
- 57 architectures, ready to go:
122
+ 69 architectures, ready to go:
122
123
  - CNNs, ResNets, ViTs, EfficientNets...
123
124
  - All adapted for regression
124
125
  - [Add your own](#adding-custom-models) in one line
@@ -224,66 +225,74 @@ pip install -e .
224
225
  > [!TIP]
225
226
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
226
227
 
227
- #### Option 1: Using wavedl-hpc (Recommended for HPC)
228
-
229
- The `wavedl-hpc` command automatically configures the environment for HPC systems:
228
+ ### Training
230
229
 
231
230
  ```bash
232
- # Basic training (auto-detects available GPUs)
233
- wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
231
+ # Basic training (auto-detects GPUs and environment)
232
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
234
233
 
235
234
  # Detailed configuration
236
- wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
235
+ wavedl-train --model <model_name> --data_path <train_data> --batch_size <number> \
237
236
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
238
237
 
239
- # Specify GPU count explicitly
240
- wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
241
- ```
242
-
243
- #### Option 2: Direct Accelerate Launch
244
-
245
- ```bash
246
- # Local - auto-detects GPUs
247
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
238
+ # Multi-GPU is automatic (uses all available GPUs)
239
+ # Override with --num_gpus if needed
240
+ wavedl-train --model cnn --data_path train.npz --num_gpus 4 --output_dir results
248
241
 
249
242
  # Resume training (automatic - just re-run with same output_dir)
250
- # Manual resume from specific checkpoint:
251
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>
243
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
252
244
 
253
245
  # Force fresh start (ignores existing checkpoints)
254
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
246
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
255
247
 
256
248
  # List available models
257
249
  wavedl-train --list_models
258
250
  ```
259
251
 
260
- > [!TIP]
261
- > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
252
+ > [!NOTE]
253
+ > `wavedl-train` automatically detects your environment:
254
+ > - **HPC clusters** (SLURM, PBS, etc.): Uses local caching, offline WandB
255
+ > - **Local machines**: Uses standard cache locations (~/.cache)
262
256
  >
263
- > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
257
+ > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint.
258
+
259
+ <details>
260
+ <summary><b>Advanced: Direct Accelerate Launch</b></summary>
261
+
262
+ For fine-grained control over distributed training, you can use `accelerate launch` directly:
263
+
264
+ ```bash
265
+ # Custom accelerate configuration
266
+ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder>
267
+
268
+ # Multi-node training
269
+ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --model cnn --data_path train.npz
270
+ ```
271
+
272
+ </details>
264
273
 
265
274
  ### Testing & Inference
266
275
 
267
- After training, use `wavedl.test` to evaluate your model on test data:
276
+ After training, use `wavedl-test` to evaluate your model on test data:
268
277
 
269
278
  ```bash
270
279
  # Basic inference
271
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data>
280
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
272
281
 
273
282
  # With visualization, CSV export, and multiple file formats
274
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
283
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
275
284
  --plot --plot_format png pdf --save_predictions --output_dir <output_folder>
276
285
 
277
286
  # With custom parameter names
278
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
287
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
279
288
  --param_names '$p_1$' '$p_2$' '$p_3$' --plot
280
289
 
281
290
  # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
282
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
291
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
283
292
  --export onnx --export_path <output_file.onnx>
284
293
 
285
294
  # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
286
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
295
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
287
296
  --input_channels 1
288
297
  ```
289
298
 
@@ -294,7 +303,7 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
294
303
  - **Format** (with `--plot_format`): Supported formats: `png` (default), `pdf` (vector), `svg` (vector), `eps` (LaTeX), `tiff`, `jpg`, `ps`
295
304
 
296
305
  > [!NOTE]
297
- > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
306
+ > `wavedl-test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
298
307
 
299
308
  ### Adding Custom Models
300
309
 
@@ -338,7 +347,7 @@ class MyModel(BaseModel):
338
347
  **Step 2: Train**
339
348
 
340
349
  ```bash
341
- wavedl-hpc --import my_model.py --model my_model --data_path train.npz
350
+ wavedl-train --import my_model.py --model my_model --data_path train.npz
342
351
  ```
343
352
 
344
353
  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -354,12 +363,12 @@ WaveDL/
354
363
  ├── src/
355
364
  │ └── wavedl/ # Main package (namespaced)
356
365
  │ ├── __init__.py # Package init with __version__
357
- │ ├── train.py # Training entry point
366
+ │ ├── train.py # Training script
358
367
  │ ├── test.py # Testing & inference script
359
368
  │ ├── hpo.py # Hyperparameter optimization
360
- │ ├── hpc.py # HPC distributed training launcher
369
+ │ ├── launcher.py # Training launcher (wavedl-train)
361
370
  │ │
362
- │ ├── models/ # Model Zoo (57 architectures)
371
+ │ ├── models/ # Model Zoo (69 architectures)
363
372
  │ │ ├── registry.py # Model factory (@register_model)
364
373
  │ │ ├── base.py # Abstract base class
365
374
  │ │ └── ... # See "Available Models" section
@@ -388,22 +397,14 @@ WaveDL/
388
397
  ## ⚙️ Configuration
389
398
 
390
399
  > [!NOTE]
391
- > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
392
- >
393
- > **Examples:**
394
- > ```bash
395
- > # Using wavedl-hpc
396
- > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
397
- >
398
- > # Using accelerate launch directly
399
- > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
400
- > ```
400
+ > All configuration options below work with `wavedl-train`. The wrapper script passes all arguments directly to `train.py`.
401
401
 
402
402
  <details>
403
- <summary><b>Available Models</b> — 57 architectures</summary>
403
+ <summary><b>Available Models</b> — 69 architectures</summary>
404
404
 
405
405
  | Model | Backbone Params | Dim |
406
406
  |-------|-----------------|-----|
407
+ | **── Classic CNNs ──** |||
407
408
  | **CNN** — Convolutional Neural Network |||
408
409
  | `cnn` | 1.6M | 1D/2D/3D |
409
410
  | **ResNet** — Residual Network |||
@@ -412,13 +413,14 @@ WaveDL/
412
413
  | `resnet50` | 23.5M | 1D/2D/3D |
413
414
  | `resnet18_pretrained` ⭐ | 11.2M | 2D |
414
415
  | `resnet50_pretrained` ⭐ | 23.5M | 2D |
415
- | **ResNet3D** — 3D Residual Network |||
416
- | `resnet3d_18` | 33.2M | 3D |
417
- | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
418
- | **TCN** Temporal Convolutional Network |||
419
- | `tcn_small` | 0.9M | 1D |
420
- | `tcn` | 6.9M | 1D |
421
- | `tcn_large` | 10.0M | 1D |
416
+ | **DenseNet** — Densely Connected Network |||
417
+ | `densenet121` | 7.0M | 1D/2D/3D |
418
+ | `densenet169` | 12.5M | 1D/2D/3D |
419
+ | `densenet121_pretrained` ⭐ | 7.0M | 2D |
420
+ | **── Efficient/Mobile CNNs ──** |||
421
+ | **MobileNetV3** — Mobile Neural Network V3 |||
422
+ | `mobilenet_v3_small` ⭐ | 0.9M | 2D |
423
+ | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
422
424
  | **EfficientNet** — Efficient Neural Network |||
423
425
  | `efficientnet_b0` ⭐ | 4.0M | 2D |
424
426
  | `efficientnet_b1` ⭐ | 6.5M | 2D |
@@ -427,47 +429,41 @@ WaveDL/
427
429
  | `efficientnet_v2_s` ⭐ | 20.2M | 2D |
428
430
  | `efficientnet_v2_m` ⭐ | 52.9M | 2D |
429
431
  | `efficientnet_v2_l` ⭐ | 117.2M | 2D |
430
- | **MobileNetV3** — Mobile Neural Network V3 |||
431
- | `mobilenet_v3_small` ⭐ | 0.9M | 2D |
432
- | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
433
432
  | **RegNet** — Regularized Network |||
434
433
  | `regnet_y_400mf` ⭐ | 3.9M | 2D |
435
434
  | `regnet_y_800mf` ⭐ | 5.7M | 2D |
436
435
  | `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
437
436
  | `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
438
437
  | `regnet_y_8gf` ⭐ | 37.4M | 2D |
439
- | **Swin** Shifted Window Transformer |||
440
- | `swin_t` ⭐ | 27.5M | 2D |
441
- | `swin_s` ⭐ | 48.8M | 2D |
442
- | `swin_b` ⭐ | 86.7M | 2D |
438
+ | **── Modern CNNs ──** |||
443
439
  | **ConvNeXt** — Convolutional Next |||
444
440
  | `convnext_tiny` | 27.8M | 1D/2D/3D |
445
441
  | `convnext_small` | 49.5M | 1D/2D/3D |
446
442
  | `convnext_base` | 87.6M | 1D/2D/3D |
447
443
  | `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
448
- | **DenseNet** — Densely Connected Network |||
449
- | `densenet121` | 7.0M | 1D/2D/3D |
450
- | `densenet169` | 12.5M | 1D/2D/3D |
451
- | `densenet121_pretrained` ⭐ | 7.0M | 2D |
452
- | **ViT** — Vision Transformer |||
453
- | `vit_tiny` | 5.4M | 1D/2D |
454
- | `vit_small` | 21.4M | 1D/2D |
455
- | `vit_base` | 85.3M | 1D/2D |
456
444
  | **ConvNeXt V2** — ConvNeXt with GRN |||
457
445
  | `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
458
446
  | `convnext_v2_small` | 49.6M | 1D/2D/3D |
459
447
  | `convnext_v2_base` | 87.7M | 1D/2D/3D |
460
448
  | `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
461
- | **Mamba** — State Space Model |||
462
- | `mamba_1d` | 3.4M | 1D |
463
- | **Vision Mamba (ViM)** 2D Mamba |||
464
- | `vim_tiny` | 6.6M | 2D |
465
- | `vim_small` | 51.1M | 2D |
466
- | `vim_base` | 201.4M | 2D |
449
+ | **UniRepLKNet** — Large-Kernel ConvNet |||
450
+ | `unireplknet_tiny` | 30.8M | 1D/2D/3D |
451
+ | `unireplknet_small` | 56.0M | 1D/2D/3D |
452
+ | `unireplknet_base` | 97.6M | 1D/2D/3D |
453
+ | **── Vision Transformers ──** |||
454
+ | **ViT** — Vision Transformer |||
455
+ | `vit_tiny` | 5.4M | 1D/2D |
456
+ | `vit_small` | 21.4M | 1D/2D |
457
+ | `vit_base` | 85.3M | 1D/2D |
458
+ | **Swin** — Shifted Window Transformer |||
459
+ | `swin_t` ⭐ | 27.5M | 2D |
460
+ | `swin_s` ⭐ | 48.8M | 2D |
461
+ | `swin_b` ⭐ | 86.7M | 2D |
467
462
  | **MaxViT** — Multi-Axis ViT |||
468
463
  | `maxvit_tiny` ⭐ | 30.1M | 2D |
469
464
  | `maxvit_small` ⭐ | 67.6M | 2D |
470
465
  | `maxvit_base` ⭐ | 119.1M | 2D |
466
+ | **── Hybrid CNN-Transformer ──** |||
471
467
  | **FastViT** — Fast Hybrid CNN-ViT |||
472
468
  | `fastvit_t8` ⭐ | 4.0M | 2D |
473
469
  | `fastvit_t12` ⭐ | 6.8M | 2D |
@@ -478,6 +474,31 @@ WaveDL/
478
474
  | `caformer_s36` ⭐ | 39.2M | 2D |
479
475
  | `caformer_m36` ⭐ | 56.9M | 2D |
480
476
  | `poolformer_s12` ⭐ | 11.9M | 2D |
477
+ | **EfficientViT** — Memory-Efficient ViT |||
478
+ | `efficientvit_m0` ⭐ | 2.2M | 2D |
479
+ | `efficientvit_m1` ⭐ | 2.6M | 2D |
480
+ | `efficientvit_m2` ⭐ | 3.8M | 2D |
481
+ | `efficientvit_b0` ⭐ | 2.1M | 2D |
482
+ | `efficientvit_b1` ⭐ | 7.5M | 2D |
483
+ | `efficientvit_b2` ⭐ | 21.8M | 2D |
484
+ | `efficientvit_b3` ⭐ | 46.1M | 2D |
485
+ | `efficientvit_l1` ⭐ | 49.5M | 2D |
486
+ | `efficientvit_l2` ⭐ | 60.5M | 2D |
487
+ | **── State Space Models ──** |||
488
+ | **Mamba** — State Space Model |||
489
+ | `mamba_1d` | 3.4M | 1D |
490
+ | **Vision Mamba (ViM)** — 2D Mamba |||
491
+ | `vim_tiny` | 6.6M | 2D |
492
+ | `vim_small` | 51.1M | 2D |
493
+ | `vim_base` | 201.4M | 2D |
494
+ | **── Specialized Architectures ──** |||
495
+ | **TCN** — Temporal Convolutional Network |||
496
+ | `tcn_small` | 0.9M | 1D |
497
+ | `tcn` | 6.9M | 1D |
498
+ | `tcn_large` | 10.0M | 1D |
499
+ | **ResNet3D** — 3D Residual Network |||
500
+ | `resnet3d_18` | 33.2M | 3D |
501
+ | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
481
502
  | **U-Net** — U-shaped Network |||
482
503
  | `unet_regression` | 31.0M | 1D/2D/3D |
483
504
 
@@ -497,34 +518,52 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
497
518
  from torchvision import models as m
498
519
  from torchvision.models import video as v
499
520
 
500
- # === TorchVision Models ===
501
- weights = {
502
- 'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
503
- 'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
504
- 'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
505
- 'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
506
- 'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
507
- 'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
508
- 'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
509
- 'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
510
- 'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
511
- }
512
- for name, w in weights.items():
513
- getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
521
+ # === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
522
+ models = [
523
+ ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
524
+ ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
525
+ ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
526
+ ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
527
+ ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
528
+ ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
529
+ ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
530
+ ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
531
+ ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
532
+ ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
533
+ ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
534
+ ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
535
+ ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
536
+ ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
537
+ ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
538
+ ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
539
+ ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
540
+ ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
541
+ ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
542
+ ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
543
+ ]
544
+ for name, w in models:
545
+ getattr(m, name)(weights=w); print(f'✓ {name}')
514
546
 
515
547
  # 3D video models
516
- v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
517
- v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
548
+ v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
549
+ v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
518
550
 
519
551
  # === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
520
552
  import timm
521
553
 
522
554
  timm_models = [
523
- 'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
524
- 'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
525
- 'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
526
- 'poolformer_s12.sail_in1k',
527
- 'convnextv2_tiny.fcmae_ft_in1k',
555
+ # MaxViT (no suffix - timm resolves to default)
556
+ 'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
557
+ # FastViT (no suffix)
558
+ 'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
559
+ # CAFormer/PoolFormer (no suffix)
560
+ 'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
561
+ # ConvNeXt V2 (no suffix)
562
+ 'convnextv2_tiny',
563
+ # EfficientViT (no suffix)
564
+ 'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
565
+ 'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
566
+ 'efficientvit_l1', 'efficientvit_l2',
528
567
  ]
529
568
  for name in timm_models:
530
569
  timm.create_model(name, pretrained=True); print(f'✓ {name}')
@@ -602,7 +641,7 @@ WaveDL automatically enables performance optimizations for modern GPUs:
602
641
  </details>
603
642
 
604
643
  <details>
605
- <summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
644
+ <summary><b>Distributed Training Arguments</b></summary>
606
645
 
607
646
  | Argument | Default | Description |
608
647
  |----------|---------|-------------|
@@ -634,10 +673,10 @@ WaveDL automatically enables performance optimizations for modern GPUs:
634
673
  **Example:**
635
674
  ```bash
636
675
  # Use Huber loss for noisy NDE data
637
- accelerate launch -m wavedl.train --model cnn --loss huber --huber_delta 0.5
676
+ wavedl-train --model cnn --loss huber --huber_delta 0.5
638
677
 
639
678
  # Weighted MSE: prioritize thickness (first target)
640
- accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
679
+ wavedl-train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
641
680
  ```
642
681
 
643
682
  </details>
@@ -657,10 +696,10 @@ accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights
657
696
  **Example:**
658
697
  ```bash
659
698
  # SGD with Nesterov momentum (often better generalization)
660
- accelerate launch -m wavedl.train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
699
+ wavedl-train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
661
700
 
662
701
  # RAdam for more stable training
663
- accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
702
+ wavedl-train --model cnn --optimizer radam --lr 1e-3
664
703
  ```
665
704
 
666
705
  </details>
@@ -682,13 +721,13 @@ accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
682
721
  **Example:**
683
722
  ```bash
684
723
  # Cosine annealing for 1000 epochs
685
- accelerate launch -m wavedl.train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
724
+ wavedl-train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
686
725
 
687
726
  # OneCycleLR for super-convergence
688
- accelerate launch -m wavedl.train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
727
+ wavedl-train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
689
728
 
690
729
  # MultiStep with custom milestones
691
- accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones "100,200,300"
730
+ wavedl-train --model cnn --scheduler multistep --milestones "100,200,300"
692
731
  ```
693
732
 
694
733
  </details>
@@ -699,16 +738,14 @@ accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones
699
738
  For robust model evaluation, simply add the `--cv` flag:
700
739
 
701
740
  ```bash
702
- # 5-fold cross-validation (works with both methods!)
703
- wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
704
- # OR
705
- accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
741
+ # 5-fold cross-validation
742
+ wavedl-train --model cnn --cv 5 --data_path train_data.npz
706
743
 
707
744
  # Stratified CV (recommended for unbalanced data)
708
- wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
745
+ wavedl-train --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
709
746
 
710
747
  # Full configuration
711
- wavedl-hpc --model cnn --cv 5 --cv_stratify \
748
+ wavedl-train --model cnn --cv 5 --cv_stratify \
712
749
  --loss huber --optimizer adamw --scheduler cosine \
713
750
  --output_dir ./cv_results
714
751
  ```
@@ -733,10 +770,10 @@ Use YAML files for reproducible experiments. CLI arguments can override any conf
733
770
 
734
771
  ```bash
735
772
  # Use a config file
736
- accelerate launch -m wavedl.train --config configs/config.yaml --data_path train.npz
773
+ wavedl-train --config configs/config.yaml --data_path train.npz
737
774
 
738
775
  # Override specific values from config
739
- accelerate launch -m wavedl.train --config configs/config.yaml --lr 5e-4 --epochs 500
776
+ wavedl-train --config configs/config.yaml --lr 5e-4 --epochs 500
740
777
  ```
741
778
 
742
779
  **Example config (`configs/config.yaml`):**
@@ -889,7 +926,7 @@ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
889
926
 
890
927
  After HPO completes, it prints the optimal command:
891
928
  ```bash
892
- accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
929
+ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
893
930
  ```
894
931
 
895
932
  ---
@@ -1082,12 +1119,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
1082
1119
 
1083
1120
  ```bash
1084
1121
  # Run inference on the example data
1085
- python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1122
+ wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1086
1123
  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1087
1124
  --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
1088
1125
 
1089
1126
  # Export to ONNX (already included as model.onnx)
1090
- python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1127
+ wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1091
1128
  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1092
1129
  --export onnx --export_path ./examples/elasticity_prediction/model.onnx
1093
1130
  ```
@@ -0,0 +1,46 @@
1
+ wavedl/__init__.py,sha256=hFGU_j86Beexkcrn_V3fotGQ4ncwLGvz2lCOejEJ-f0,1177
2
+ wavedl/hpo.py,sha256=nEiy-2O_5EhxF5hU8X5TviSAiXfVrTQx0-VE6baW7JQ,14633
3
+ wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
4
+ wavedl/test.py,sha256=1UUy9phCqrr3h_lN6mGJ7Sj73skDg4KyLk2Yuq9DiKU,38797
5
+ wavedl/train.py,sha256=vBufy6gHShawgj8O6dvVER9TPhORa1s7L6pQtTe-N5M,57824
6
+ wavedl/models/__init__.py,sha256=8OiT2seq1qBiUzKaSkmh_VOLJlLTT9Cn-mjhMHKGFpI,5203
7
+ wavedl/models/_pretrained_utils.py,sha256=VPdU1DwJB93ZBf_GFIgb8-6BbAt18Phs4yorwlhLw70,12404
8
+ wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
9
+ wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
10
+ wavedl/models/caformer.py,sha256=ufPM-HzQ-qUZcXgnOulurY6jBUlMUzokC01whtPeVMg,7922
11
+ wavedl/models/cnn.py,sha256=1-sNBDZHc5DySbduf5tkV1Ha25R6irksjVqfOiFbI3M,7465
12
+ wavedl/models/convnext.py,sha256=fdXieXUuHyULjicw9Nno2SK2Tm5bDabUtdiGuEpuAF4,15711
13
+ wavedl/models/convnext_v2.py,sha256=1ELKBPWIlUm3uybLX1KN5cgwjBPEUzZDoXL8qUzF9YY,14920
14
+ wavedl/models/densenet.py,sha256=V_caGd0wsG_Q3Q38I4MEgYmU0v4j8mDyvv7Rn3Bk7Ac,12667
15
+ wavedl/models/efficientnet.py,sha256=HWfhqSX57lC5Xug5TrQ3r-uFqkksoIKjmQ5Zr5njkEA,8264
16
+ wavedl/models/efficientnetv2.py,sha256=hVSnVId8T1rjqaKlckLqWFwvo2J-qASX7o9lMbXbP-s,10947
17
+ wavedl/models/efficientvit.py,sha256=KqFoZq9YHBMnTue6aMdPKgBOMczeBPryY_F6ip0hoEI,11630
18
+ wavedl/models/fastvit.py,sha256=S0SF0iC-9ZJrP-9YUTLPhMJMV-W9r2--V3hVAmSSVKI,7083
19
+ wavedl/models/mamba.py,sha256=ENmOQjtoX8btS1tDvOYEG_M3GFn1P2vWsDWcsQPSPJ0,17189
20
+ wavedl/models/maxvit.py,sha256=I6TFGrLRcyMU-nU7u5VhOaXZWWdwmNJwHsMqbJh_g_o,7548
21
+ wavedl/models/mobilenetv3.py,sha256=LZxCg599kGP6-XI_l3PpT8jzh4oTAdWH3Y7GH097o28,10242
22
+ wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
23
+ wavedl/models/regnet.py,sha256=6Yjo2wZzdjK8VpOMagbCrHqmsfRmGkuiURmc-MesYvA,13777
24
+ wavedl/models/resnet.py,sha256=3i4zfE15qF4cd0qbTKX-Wdy2Kd0f4mLcdd316FAcVCo,16720
25
+ wavedl/models/resnet3d.py,sha256=edxLW4P4OBpZ5z9kMnWYV6qJ1GTkiqpwnW3-IqrPyqE,8510
26
+ wavedl/models/swin.py,sha256=39Gwn5hNEw3-tndc8qFFzV-VZ7pJMMKey2oZONAZ8MU,14980
27
+ wavedl/models/tcn.py,sha256=XzojpuMFG4lu_0oQHbQnkLAb7AnW-D7_6KoBlQDPLnQ,12367
28
+ wavedl/models/unet.py,sha256=oi7eBONSe0ALpJKsYda3jRGwu-LuSiFgNdURebnGGt0,7712
29
+ wavedl/models/unireplknet.py,sha256=jCy22m6mkApkLf3EzimMIqXy4xFs5WPUkaoz_KVWpqc,15205
30
+ wavedl/models/vit.py,sha256=5DXshtBdN2jYlH8MxWGTlIxP5lgbmfsdLSNchOvTaYk,14911
31
+ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
32
+ wavedl/utils/config.py,sha256=MXkaVc1_zo8sDro8mjtK1MV65t2z8b1Z6fviwSorNiY,10534
33
+ wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
34
+ wavedl/utils/cross_validation.py,sha256=HfInyZ8gUROc_AyihYKzzUE0vnoPt_mFvAI2OPK4P54,17945
35
+ wavedl/utils/data.py,sha256=5ph2Pi8PKvuaSoJaXbFIL9WsX8pTN0A6P8FdmxvXdv4,63469
36
+ wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
37
+ wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
38
+ wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
39
+ wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
40
+ wavedl/utils/schedulers.py,sha256=K6YCiyiMM9rb0cCRXTp89noXeXcAyUEiePr27O5Cozs,7408
41
+ wavedl-1.6.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
42
+ wavedl-1.6.2.dist-info/METADATA,sha256=2mTyuip32AneUURV3K8oAjZQ2rA_13AB16R-VyRN5s8,47659
43
+ wavedl-1.6.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
44
+ wavedl-1.6.2.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
45
+ wavedl-1.6.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
46
+ wavedl-1.6.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  [console_scripts]
2
- wavedl-hpc = wavedl.hpc:main
2
+ wavedl-hpc = wavedl.launcher:main
3
3
  wavedl-hpo = wavedl.hpo:main
4
4
  wavedl-test = wavedl.test:main
5
- wavedl-train = wavedl.train:main
5
+ wavedl-train = wavedl.launcher:main
@@ -1,44 +0,0 @@
1
- wavedl/__init__.py,sha256=aVEVBCcciyAZkpkVZjaY2BrgP5Pbx96x38_RRvv4H2Q,1177
2
- wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
3
- wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
4
- wavedl/test.py,sha256=1UUy9phCqrr3h_lN6mGJ7Sj73skDg4KyLk2Yuq9DiKU,38797
5
- wavedl/train.py,sha256=xfA5fuug0bk-20o2MHpAXoWpGFmciSpWsE9C5RERpf8,59433
6
- wavedl/models/__init__.py,sha256=uBoH7JRZIYF2TxiZbdTw8x_I9fz_ZRaSnQPRG7HDyug,4462
7
- wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
8
- wavedl/models/_timm_utils.py,sha256=yb_6ZiklFmNG3ETw3kw8BzGfo6DCdgizb_B7duLQEFs,8051
9
- wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
10
- wavedl/models/caformer.py,sha256=H8T_UbO1gq0PZFMgWYaWq5qg_5sFf42coQ829ab7n3o,7916
11
- wavedl/models/cnn.py,sha256=rn2Xmup0w_ll6wuAnYclSeIVazoSUrUGPY-9XnhA1gE,8341
12
- wavedl/models/convnext.py,sha256=R72w6Vep-SIvzIYlAdQz38Gk8Zmg4wU1WyQ_ZFNdOk0,16116
13
- wavedl/models/convnext_v2.py,sha256=qj8SewFxOJ-JZiUJjzBDGmSw1wxEX7XnMBwf_yckhvI,15434
14
- wavedl/models/densenet.py,sha256=oVNKJPzoET43KJxJBhDnLkbJOjFBDWe_f_TqpgBetlY,13050
15
- wavedl/models/efficientnet.py,sha256=HWfhqSX57lC5Xug5TrQ3r-uFqkksoIKjmQ5Zr5njkEA,8264
16
- wavedl/models/efficientnetv2.py,sha256=mSJaHJwtQbtfsOFEuOCoQwUY2vh4CXgISqnobbABD_U,11262
17
- wavedl/models/fastvit.py,sha256=PrrNEN_q5uFHRcbY4LrzM2MwU3Y_C1cOqdv_oErRlm8,8539
18
- wavedl/models/mamba.py,sha256=ZavdpOLYZOIuCgyy2tFPCk0jiAtW7_mRKu8O9kqH3nY,15819
19
- wavedl/models/maxvit.py,sha256=yHPbFyEppEweSg4TwMbcrZQmJYHrpKtciTslfa_KhwY,7459
20
- wavedl/models/mobilenetv3.py,sha256=nj-OYXSfxLp_HkoMF2qzvaa8wwhmpNslWlpyknN-VKk,10537
21
- wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
22
- wavedl/models/regnet.py,sha256=kZz9IVxPW_q0ZIFsMbD7H2DuW60h4pfdZmOTypvAkbg,14183
23
- wavedl/models/resnet.py,sha256=W27hx_g8_Jt6kmzRILZ4uYuhL4_c0Jro_yOLJ2ijm6g,18082
24
- wavedl/models/resnet3d.py,sha256=I2_4k2kEXfgSxpkocD2J0cLN2RRoPezrDzDyd_o5bDs,8768
25
- wavedl/models/swin.py,sha256=G_C7xQM2RIuEzrOrD2m_4VINUhmJNsntcu1WnKwHK68,15423
26
- wavedl/models/tcn.py,sha256=VZOzTnGbDyXZeULPU9VnGcN-4WcRbgAff7fKbGUVqrA,13214
27
- wavedl/models/unet.py,sha256=L5qPmSKRrybwSldXIuUCPdpY1KSkokbWsQIl1ZHABhg,7799
28
- wavedl/models/vit.py,sha256=nE2IWtSeMVxyKJreI7jyfS-ZqNG5g2AB7KBHKjLHKyc,14878
29
- wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
30
- wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
31
- wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
32
- wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
33
- wavedl/utils/data.py,sha256=5ph2Pi8PKvuaSoJaXbFIL9WsX8pTN0A6P8FdmxvXdv4,63469
34
- wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
35
- wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
36
- wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
37
- wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
38
- wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
39
- wavedl-1.6.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
40
- wavedl-1.6.0.dist-info/METADATA,sha256=rYu2eVqVaFndhEngDzM0yr-U1MAlcH2zBjELaMY9xmU,46707
41
- wavedl-1.6.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
42
- wavedl-1.6.0.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
43
- wavedl-1.6.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
44
- wavedl-1.6.0.dist-info/RECORD,,
File without changes