wavedl 1.6.2.tar.gz → 1.7.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.6.2/src/wavedl.egg-info → wavedl-1.7.0}/PKG-INFO +37 -18
- {wavedl-1.6.2 → wavedl-1.7.0}/README.md +36 -17
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/__init__.py +1 -1
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/hpo.py +115 -9
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/__init__.py +22 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_pretrained_utils.py +72 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_template.py +7 -6
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/cnn.py +20 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/convnext.py +3 -70
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/convnext_v2.py +1 -18
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/mamba.py +126 -38
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/resnet3d.py +23 -5
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/unireplknet.py +1 -18
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/vit.py +18 -8
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/test.py +13 -23
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/train.py +494 -28
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/__init__.py +49 -9
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/config.py +6 -8
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/cross_validation.py +17 -4
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/data.py +176 -180
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/metrics.py +26 -5
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.2 → wavedl-1.7.0/src/wavedl.egg-info}/PKG-INFO +37 -18
- {wavedl-1.6.2 → wavedl-1.7.0}/LICENSE +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/pyproject.toml +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/setup.cfg +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/launcher.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/base.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/caformer.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientvit.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/fastvit.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/maxvit.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/constraints.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.6.2/src/wavedl.egg-info → wavedl-1.7.0}/PKG-INFO +37 -18

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.6.2
+Version: 1.7.0
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -214,11 +214,11 @@ This installs everything you need: training, inference, HPO, ONNX export.
 ```bash
 git clone https://github.com/ductho-le/WaveDL.git
 cd WaveDL
-pip install -e .
+pip install -e ".[dev]"
 ```

 > [!NOTE]
-> Python 3.11+ required. For
+> Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

 ### Quick Start

@@ -273,8 +273,6 @@ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --mode

 ### Testing & Inference

-After training, use `wavedl-test` to evaluate your model on test data:
-
 ```bash
 # Basic inference
 wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
@@ -909,18 +907,24 @@ Automatically find the best training configuration using [Optuna](https://optuna
 **Run HPO:**

 ```bash
-# Basic HPO (auto-detects GPUs
-wavedl-hpo --data_path train.npz --
+# Basic HPO (50 trials, auto-detects GPUs)
+wavedl-hpo --data_path train.npz --n_trials 50
+
+# Quick search (minimal search space, fastest)
+wavedl-hpo --data_path train.npz --n_trials 30 --quick

-#
-wavedl-hpo --data_path train.npz --
+# Medium search (balanced between quick and full)
+wavedl-hpo --data_path train.npz --n_trials 50 --medium

-#
-wavedl-hpo --data_path train.npz --models cnn
+# Full search with specific models
+wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+# In-process mode (enables pruning, faster, single-GPU)
+wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
 ```

 > [!TIP]
-> **
+> **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).

 **Train with best parameters**

@@ -942,10 +946,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | Learning rate | 1e-5 → 1e-2 | (always searched) |
 | Batch size | 16, 32, 64, 128 | (always searched) |

-**
-
-
-
+**Search Presets:**
+
+| Mode | Models | Optimizers | Schedulers | Use Case |
+|------|--------|------------|------------|----------|
+| Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+| `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+| `--quick` | cnn | adamw | plateau | Fast validation |
+
+**Execution Modes:**
+
+| Mode | Flag | Pruning | GPU Memory | Best For |
+|------|------|---------|------------|----------|
+| Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+| In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+> [!TIP]
+> Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.

 ---

@@ -956,7 +973,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--data_path` | (required) | Training data file |
 | `--models` | 3 defaults | Models to search (specify any number) |
 | `--n_trials` | `50` | Number of trials to run |
-| `--quick` | `False` |
+| `--quick` | `False` | Quick mode: minimal search space |
+| `--medium` | `False` | Medium mode: balanced search space |
+| `--inprocess` | `False` | Run trials in-process (enables pruning) |
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
@@ -965,7 +984,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--output` | `hpo_results.json` | Output file |


-> See [Available Models](#available-models) for all
+> See [Available Models](#available-models) for all 69 architectures you can search.

 </details>

````
{wavedl-1.6.2 → wavedl-1.7.0}/README.md +36 -17

````diff
@@ -166,11 +166,11 @@ This installs everything you need: training, inference, HPO, ONNX export.
 ```bash
 git clone https://github.com/ductho-le/WaveDL.git
 cd WaveDL
-pip install -e .
+pip install -e ".[dev]"
 ```

 > [!NOTE]
-> Python 3.11+ required. For
+> Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

 ### Quick Start

@@ -225,8 +225,6 @@ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --mode

 ### Testing & Inference

-After training, use `wavedl-test` to evaluate your model on test data:
-
 ```bash
 # Basic inference
 wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
@@ -861,18 +859,24 @@ Automatically find the best training configuration using [Optuna](https://optuna
 **Run HPO:**

 ```bash
-# Basic HPO (auto-detects GPUs
-wavedl-hpo --data_path train.npz --
+# Basic HPO (50 trials, auto-detects GPUs)
+wavedl-hpo --data_path train.npz --n_trials 50
+
+# Quick search (minimal search space, fastest)
+wavedl-hpo --data_path train.npz --n_trials 30 --quick

-#
-wavedl-hpo --data_path train.npz --
+# Medium search (balanced between quick and full)
+wavedl-hpo --data_path train.npz --n_trials 50 --medium

-#
-wavedl-hpo --data_path train.npz --models cnn
+# Full search with specific models
+wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+# In-process mode (enables pruning, faster, single-GPU)
+wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
 ```

 > [!TIP]
-> **
+> **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).

 **Train with best parameters**

@@ -894,10 +898,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | Learning rate | 1e-5 → 1e-2 | (always searched) |
 | Batch size | 16, 32, 64, 128 | (always searched) |

-**
-
-
-
+**Search Presets:**
+
+| Mode | Models | Optimizers | Schedulers | Use Case |
+|------|--------|------------|------------|----------|
+| Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+| `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+| `--quick` | cnn | adamw | plateau | Fast validation |
+
+**Execution Modes:**
+
+| Mode | Flag | Pruning | GPU Memory | Best For |
+|------|------|---------|------------|----------|
+| Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+| In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+> [!TIP]
+> Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.

 ---

@@ -908,7 +925,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--data_path` | (required) | Training data file |
 | `--models` | 3 defaults | Models to search (specify any number) |
 | `--n_trials` | `50` | Number of trials to run |
-| `--quick` | `False` |
+| `--quick` | `False` | Quick mode: minimal search space |
+| `--medium` | `False` | Medium mode: balanced search space |
+| `--inprocess` | `False` | Run trials in-process (enables pruning) |
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
@@ -917,7 +936,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 | `--output` | `hpo_results.json` | Output file |


-> See [Available Models](#available-models) for all
+> See [Available Models](#available-models) for all 69 architectures you can search.

 </details>

````
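The GPU Detection tip above says the default subprocess mode runs one trial per GPU in parallel. As an illustration only (this is not wavedl's internal scheduler), one common way to get that behaviour is to pin each subprocess to a device via `CUDA_VISIBLE_DEVICES`, keyed on the trial number:

```python
# Illustrative sketch: round-robin GPU pinning for subprocess trials.
# Not taken from wavedl.hpo; shown only to make the "one trial per GPU" tip concrete.
import os
import subprocess

import torch


def launch_trial(trial_number: int, cmd: list[str]) -> subprocess.CompletedProcess:
    env = os.environ.copy()
    n_gpus = max(torch.cuda.device_count(), 1)
    env["CUDA_VISIBLE_DEVICES"] = str(trial_number % n_gpus)  # pin this trial to one GPU
    return subprocess.run(cmd, env=env, capture_output=True, text=True)
```

With `--inprocess` there is no such isolation: all trials share the parent process's GPU context, which is why that mode is recommended for single-GPU runs.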
{wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/hpo.py +115 -9

````diff
@@ -10,12 +10,28 @@ Usage:
     # Quick search (fewer parameters)
     wavedl-hpo --data_path train.npz --n_trials 30 --quick

+    # Medium search (balanced)
+    wavedl-hpo --data_path train.npz --n_trials 50 --medium
+
     # Full search with specific models
     wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0

     # Parallel trials on multiple GPUs
     wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4

+    # In-process mode (enables pruning, faster, single-GPU)
+    wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
+
+Execution Modes:
+    --inprocess: Runs trials in the same Python process. Enables pruning
+                 (MedianPruner) for early stopping of unpromising trials.
+                 Faster due to no subprocess overhead, but trials share
+                 GPU memory (no isolation between trials).
+
+    Default (subprocess): Launches each trial as a separate process.
+                          Provides GPU memory isolation but prevents pruning
+                          (subprocess can't report intermediate results).
+
 Author: Ductho Le (ductho.le@outlook.com)
 """

@@ -41,10 +57,12 @@ except ImportError:

 DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
 QUICK_MODELS = ["cnn"]
+MEDIUM_MODELS = ["cnn", "resnet18"]

 # All 6 optimizers
 DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
 QUICK_OPTIMIZERS = ["adamw"]
+MEDIUM_OPTIMIZERS = ["adamw", "adam", "sgd"]

 # All 8 schedulers
 DEFAULT_SCHEDULERS = [
@@ -58,10 +76,12 @@ DEFAULT_SCHEDULERS = [
     "linear_warmup",
 ]
 QUICK_SCHEDULERS = ["plateau"]
+MEDIUM_SCHEDULERS = ["plateau", "cosine", "onecycle"]

 # All 6 losses
 DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
 QUICK_LOSSES = ["mse"]
+MEDIUM_LOSSES = ["mse", "mae", "huber"]


 # =============================================================================
@@ -70,16 +90,28 @@ QUICK_LOSSES = ["mse"]


 def create_objective(args):
-    """Create Optuna objective function with configurable search space.
+    """Create Optuna objective function with configurable search space.
+
+    Supports two execution modes:
+    - Subprocess (default): Launches wavedl.train via subprocess. Provides GPU
+      memory isolation but prevents pruning (MedianPruner has no effect).
+    - In-process (--inprocess): Calls train_single_trial() directly. Enables
+      pruning and reduces overhead, but trials share GPU memory.
+    """

     def objective(trial):
-        # Select search space based on mode
+        # Select search space based on mode (quick < medium < full)
         # CLI arguments always take precedence over defaults
         if args.quick:
             models = args.models or QUICK_MODELS
             optimizers = args.optimizers or QUICK_OPTIMIZERS
             schedulers = args.schedulers or QUICK_SCHEDULERS
             losses = args.losses or QUICK_LOSSES
+        elif args.medium:
+            models = args.models or MEDIUM_MODELS
+            optimizers = args.optimizers or MEDIUM_OPTIMIZERS
+            schedulers = args.schedulers or MEDIUM_SCHEDULERS
+            losses = args.losses or MEDIUM_LOSSES
         else:
             models = args.models or DEFAULT_MODELS
             optimizers = args.optimizers or DEFAULT_OPTIMIZERS
@@ -101,13 +133,59 @@ def create_objective(args):
         if loss == "huber":
             huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
         else:
-            huber_delta =
+            huber_delta = 1.0  # default

         if optimizer == "sgd":
             momentum = trial.suggest_float("momentum", 0.8, 0.99)
         else:
-            momentum =
+            momentum = 0.9  # default
+
+        # ==================================================================
+        # IN-PROCESS MODE: Direct function call with pruning support
+        # ==================================================================
+        if args.inprocess:
+            from wavedl.train import train_single_trial
+
+            try:
+                result = train_single_trial(
+                    data_path=args.data_path,
+                    model_name=model,
+                    lr=lr,
+                    batch_size=batch_size,
+                    epochs=args.max_epochs,
+                    patience=patience,
+                    optimizer_name=optimizer,
+                    scheduler_name=scheduler,
+                    loss_name=loss,
+                    weight_decay=weight_decay,
+                    seed=args.seed,
+                    huber_delta=huber_delta,
+                    momentum=momentum,
+                    trial=trial,  # Enable pruning via trial.report/should_prune
+                    verbose=False,
+                )
+
+                if result["pruned"]:
+                    print(
+                        f"Trial {trial.number}: Pruned at epoch {result['epochs_trained']}"
+                    )
+                    raise optuna.TrialPruned()
+
+                val_loss = result["best_val_loss"]
+                print(
+                    f"Trial {trial.number}: val_loss={val_loss:.6f} ({result['epochs_trained']} epochs)"
+                )
+                return val_loss

+            except optuna.TrialPruned:
+                raise  # Re-raise for Optuna to handle
+            except Exception as e:
+                print(f"Trial {trial.number}: Error - {e}")
+                return float("inf")
+
+        # ==================================================================
+        # SUBPROCESS MODE (default): GPU memory isolation, no pruning
+        # ==================================================================
         # Build command
         cmd = [
             sys.executable,
@@ -138,9 +216,9 @@ def create_objective(args):
         ]

         # Add conditional args
-        if
+        if loss == "huber":
             cmd.extend(["--huber_delta", str(huber_delta)])
-        if
+        if optimizer == "sgd":
             cmd.extend(["--momentum", str(momentum)])

         # Use temporary directory for trial output
@@ -285,7 +363,17 @@ Examples:
     parser.add_argument(
         "--quick",
         action="store_true",
-        help="Quick mode: search fewer parameters",
+        help="Quick mode: search fewer parameters (fastest, least thorough)",
+    )
+    parser.add_argument(
+        "--medium",
+        action="store_true",
+        help="Medium mode: balanced parameter search (between --quick and full)",
+    )
+    parser.add_argument(
+        "--inprocess",
+        action="store_true",
+        help="Run trials in-process (enables pruning, faster, but no GPU memory isolation)",
     )
     parser.add_argument(
         "--timeout",
@@ -384,14 +472,32 @@ Examples:
     print("=" * 60)
     print(f"Data: {args.data_path}")
     print(f"Trials: {args.n_trials}")
-
+    # Determine mode name for display
+    if args.quick:
+        mode_name = "Quick"
+    elif args.medium:
+        mode_name = "Medium"
+    else:
+        mode_name = "Full"
+
+    print(
+        f"Mode: {mode_name}"
+        + (" (in-process, pruning enabled)" if args.inprocess else " (subprocess)")
+    )
     print(f"Parallel jobs: {args.n_jobs}")
     print("=" * 60)

+    # Use MedianPruner only for in-process mode (subprocess trials can't report)
+    if args.inprocess:
+        pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
+    else:
+        # NopPruner for subprocess mode - pruning has no effect there
+        pruner = optuna.pruners.NopPruner()
+
     study = optuna.create_study(
         study_name=args.study_name,
         direction="minimize",
-        pruner=
+        pruner=pruner,
     )

     # Run optimization
````
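In the in-process path, pruning only works if `train_single_trial()` reports intermediate validation losses back to the trial. Below is a minimal sketch of that reporting contract using the standard Optuna API (`trial.report`, `trial.should_prune`, `optuna.TrialPruned`); the result keys mirror the diff above (`pruned`, `best_val_loss`, `epochs_trained`), while `validate_one_epoch` is a stand-in for the real training loop in `wavedl.train`.

```python
# Sketch of the per-epoch pruning contract the --inprocess path relies on.
# Only the Optuna calls are real API; the training/validation step is a stub.
import optuna


def validate_one_epoch(epoch: int) -> float:
    """Stand-in for one epoch of training + validation."""
    return 1.0 / (epoch + 1)


def train_single_trial_sketch(epochs: int, trial: optuna.Trial | None = None) -> dict:
    best_val_loss = float("inf")
    for epoch in range(epochs):
        val_loss = validate_one_epoch(epoch)
        best_val_loss = min(best_val_loss, val_loss)

        if trial is not None:
            trial.report(val_loss, step=epoch)  # intermediate value for the pruner
            if trial.should_prune():
                return {"pruned": True, "best_val_loss": best_val_loss,
                        "epochs_trained": epoch + 1}

    return {"pruned": False, "best_val_loss": best_val_loss, "epochs_trained": epochs}
```

A study created with `MedianPruner`, as in the `--inprocess` branch above, then stops trials whose reported losses are worse than the running median of previous trials at the same step.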
{wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/__init__.py +22 -0

````diff
@@ -77,6 +77,15 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny


+# Optional RATENet (unpublished, may be gitignored)
+try:
+    from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
+
+    _HAS_RATENET = True
+except ImportError:
+    _HAS_RATENET = False
+
+
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
@@ -111,6 +120,7 @@ __all__ = [
     "MC3_18",
     "MODEL_REGISTRY",
     "TCN",
+    # Classes (uppercase first, alphabetically)
     "BaseModel",
     "ConvNeXtBase_",
     "ConvNeXtSmall",
@@ -152,6 +162,7 @@ __all__ = [
     "VimBase",
     "VimSmall",
     "VimTiny",
+    # Functions (lowercase, alphabetically)
     "build_model",
     "get_model",
     "list_models",
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
             "UniRepLKNetTiny",
         ]
     )
+
+# Add RATENet models to __all__ if available (unpublished)
+if _HAS_RATENET:
+    __all__.extend(
+        [
+            "RATENet",
+            "RATENetLite",
+            "RATENetTiny",
+            "RATENetV2",
+        ]
+    )
````
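Because the RATENet import is wrapped in try/except, downstream code should not assume those names exist. A small sketch of guarding on availability through the package's `__all__`, which the diff extends only when the optional module imports cleanly:

```python
# Sketch: guard optional (unpublished) models instead of importing them directly.
import wavedl.models as wdm

if "RATENet" in wdm.__all__:
    from wavedl.models import RATENet  # bundled in this environment
    print("RATENet available")
else:
    print("RATENet not bundled; use a published architecture instead")
```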
{wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_pretrained_utils.py +72 -0

````diff
@@ -166,6 +166,78 @@ class LayerNormNd(nn.Module):
         return x


+# =============================================================================
+# STOCHASTIC DEPTH (DropPath)
+# =============================================================================
+
+
+class DropPath(nn.Module):
+    """
+    Stochastic Depth (drop path) regularization for residual networks.
+
+    Randomly drops entire residual branches during training. Used in modern
+    architectures like ConvNeXt, Swin Transformer, UniRepLKNet.
+
+    Args:
+        drop_prob: Probability of dropping the path (default: 0.0)
+
+    Reference:
+        Huang, G., et al. (2016). Deep Networks with Stochastic Depth.
+        https://arxiv.org/abs/1603.09382
+    """
+
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+
+        keep_prob = 1 - self.drop_prob
+        # Shape: (batch_size, 1, 1, ...) for broadcasting
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # Binarize
+        return x.div(keep_prob) * random_tensor
+
+
+# =============================================================================
+# BACKBONE FREEZING UTILITIES
+# =============================================================================
+
+
+def freeze_backbone(
+    model: nn.Module,
+    exclude_patterns: list[str] | None = None,
+) -> int:
+    """
+    Freeze backbone parameters, keeping specified layers trainable.
+
+    Args:
+        model: The model whose parameters to freeze
+        exclude_patterns: List of patterns to exclude from freezing.
+            Parameters with names containing any of these patterns stay trainable.
+            Default: ["classifier", "head", "fc"]
+
+    Returns:
+        Number of parameters frozen
+
+    Example:
+        >>> freeze_backbone(model.backbone, exclude_patterns=["fc", "classifier"])
+    """
+    if exclude_patterns is None:
+        exclude_patterns = ["classifier", "head", "fc"]
+
+    frozen_count = 0
+    for name, param in model.named_parameters():
+        if not any(pattern in name for pattern in exclude_patterns):
+            param.requires_grad = False
+            frozen_count += param.numel()
+
+    return frozen_count
+
+
 # =============================================================================
 # REGRESSION HEAD BUILDERS
 # =============================================================================
````
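Both new utilities are drop-in building blocks. A short usage sketch based on the signatures above, assuming `DropPath` and `freeze_backbone` remain importable from `_pretrained_utils`; the toy modules here are illustrative, not wavedl models.

```python
# Usage sketch for the new utilities (illustrative toy modules, not wavedl models).
import torch
import torch.nn as nn

from wavedl.models._pretrained_utils import DropPath, freeze_backbone


class ToyBlock(nn.Module):
    """Toy residual block: DropPath randomly skips the branch during training."""

    def __init__(self, channels: int, drop_prob: float = 0.1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.drop_path(self.conv(x))


class ToyRegressor(nn.Module):
    """Only parameter freezing is exercised below, so no forward() is defined."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), ToyBlock(8))
        self.head = nn.Linear(8, 3)  # "head" matches the default exclude patterns


model = ToyRegressor()
out = model.backbone(torch.randn(2, 1, 16, 16))  # DropPath is active only in train mode

# Freeze everything except parameters whose names contain "classifier"/"head"/"fc".
n_frozen = freeze_backbone(model)
print(f"Frozen parameters: {n_frozen}")
print(f"Head still trainable: {model.head.weight.requires_grad}")
```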
{wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_template.py +7 -6

````diff
@@ -31,22 +31,23 @@ from wavedl.models.base import BaseModel
 # @register_model("my_model")
 class TemplateModel(BaseModel):
     """
-    Template Model Architecture.
+    Template Model Architecture (2D only).

     Replace this docstring with your model description.
     The first line will appear in --list_models output.

+    NOTE: This template is hardcoded for 2D inputs using Conv2d/MaxPool2d.
+    For 1D/3D support, use dimension-agnostic layer factories from
+    _pretrained_utils.py (get_conv_layer, get_pool_layer, get_norm_layer).
+
     Args:
-        in_shape: Input spatial dimensions (
-            - 1D: (L,) for signals
-            - 2D: (H, W) for images
-            - 3D: (D, H, W) for volumes
+        in_shape: Input spatial dimensions as (H, W) for 2D images
         out_size: Number of regression targets (auto-detected from data)
         hidden_dim: Size of hidden layers (default: 256)
        dropout: Dropout rate (default: 0.1)

     Input Shape:
-        (B, 1,
+        (B, 1, H, W) - 2D grayscale images

     Output Shape:
         (B, out_size) - Regression predictions
````
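The new note steers template users toward dimension-agnostic layer factories instead of hardcoded `Conv2d`/`MaxPool2d`. Below is a hypothetical sketch of what such a factory looks like; the real `get_conv_layer` in `_pretrained_utils.py` may differ in name handling and signature, so treat this purely as an illustration of the pattern.

```python
# Hypothetical dimension-agnostic conv factory (not the actual wavedl helper).
import torch.nn as nn


def get_conv_layer_sketch(
    ndim: int, in_ch: int, out_ch: int, kernel_size: int = 3
) -> nn.Module:
    """Return Conv1d/Conv2d/Conv3d depending on the number of spatial dims."""
    conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[ndim]
    return conv_cls(in_ch, out_ch, kernel_size, padding=kernel_size // 2)


# len(in_shape) gives the dimensionality: (L,) -> 1D, (H, W) -> 2D, (D, H, W) -> 3D
conv = get_conv_layer_sketch(ndim=2, in_ch=1, out_ch=16)
```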
{wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/cnn.py +20 -0

````diff
@@ -159,6 +159,26 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )

+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize weights with Kaiming for conv, Xavier for linear."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="leaky_relu"
+                )
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
````