wavedl 1.6.2.tar.gz → 1.7.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {wavedl-1.6.2/src/wavedl.egg-info → wavedl-1.7.0}/PKG-INFO +37 -18
  2. {wavedl-1.6.2 → wavedl-1.7.0}/README.md +36 -17
  3. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/hpo.py +115 -9
  5. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/__init__.py +22 -0
  6. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_pretrained_utils.py +72 -0
  7. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/_template.py +7 -6
  8. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/cnn.py +20 -0
  9. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/convnext.py +3 -70
  10. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/convnext_v2.py +1 -18
  11. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/mamba.py +126 -38
  12. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/resnet3d.py +23 -5
  13. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/unireplknet.py +1 -18
  14. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/vit.py +18 -8
  15. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/test.py +13 -23
  16. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/train.py +494 -28
  17. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/__init__.py +49 -9
  18. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/config.py +6 -8
  19. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/cross_validation.py +17 -4
  20. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/data.py +176 -180
  21. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/metrics.py +26 -5
  22. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/schedulers.py +2 -2
  23. {wavedl-1.6.2 → wavedl-1.7.0/src/wavedl.egg-info}/PKG-INFO +37 -18
  24. {wavedl-1.6.2 → wavedl-1.7.0}/LICENSE +0 -0
  25. {wavedl-1.6.2 → wavedl-1.7.0}/pyproject.toml +0 -0
  26. {wavedl-1.6.2 → wavedl-1.7.0}/setup.cfg +0 -0
  27. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/launcher.py +0 -0
  28. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/base.py +0 -0
  29. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/caformer.py +0 -0
  30. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/densenet.py +0 -0
  31. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientnet.py +0 -0
  32. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientnetv2.py +0 -0
  33. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/efficientvit.py +0 -0
  34. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/fastvit.py +0 -0
  35. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/maxvit.py +0 -0
  36. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/mobilenetv3.py +0 -0
  37. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/registry.py +0 -0
  38. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/regnet.py +0 -0
  39. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/resnet.py +0 -0
  40. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/swin.py +0 -0
  41. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/tcn.py +0 -0
  42. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/models/unet.py +0 -0
  43. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/constraints.py +0 -0
  44. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/distributed.py +0 -0
  45. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/losses.py +0 -0
  46. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl/utils/optimizers.py +0 -0
  47. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/SOURCES.txt +0 -0
  48. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  49. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/entry_points.txt +0 -0
  50. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/requires.txt +0 -0
  51. {wavedl-1.6.2 → wavedl-1.7.0}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.6.2
+ Version: 1.7.0
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -214,11 +214,11 @@ This installs everything you need: training, inference, HPO, ONNX export.
  ```bash
  git clone https://github.com/ductho-le/WaveDL.git
  cd WaveDL
- pip install -e .
+ pip install -e ".[dev]"
  ```

  > [!NOTE]
- > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
+ > Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

  ### Quick Start

@@ -273,8 +273,6 @@ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --mode

  ### Testing & Inference

- After training, use `wavedl-test` to evaluate your model on test data:
-
  ```bash
  # Basic inference
  wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
@@ -909,18 +907,24 @@ Automatically find the best training configuration using [Optuna](https://optuna
  **Run HPO:**

  ```bash
- # Basic HPO (auto-detects GPUs for parallel trials)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 100
+ # Basic HPO (50 trials, auto-detects GPUs)
+ wavedl-hpo --data_path train.npz --n_trials 50
+
+ # Quick search (minimal search space, fastest)
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick

- # Search multiple models
- wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
+ # Medium search (balanced between quick and full)
+ wavedl-hpo --data_path train.npz --n_trials 50 --medium

- # Quick mode (fewer parameters, faster)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
+ # Full search with specific models
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+ # In-process mode (enables pruning, faster, single-GPU)
+ wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
  ```

  > [!TIP]
- > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+ > **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).

  **Train with best parameters**

@@ -942,10 +946,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
  | Batch size | 16, 32, 64, 128 | (always searched) |

- **Quick Mode** (`--quick`):
- - Uses minimal defaults: cnn + adamw + plateau + mse
- - Faster for testing your setup before running full search
- - You can still override any option with the flags above
+ **Search Presets:**
+
+ | Mode | Models | Optimizers | Schedulers | Use Case |
+ |------|--------|------------|------------|----------|
+ | Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+ | `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+ | `--quick` | cnn | adamw | plateau | Fast validation |
+
+ **Execution Modes:**
+
+ | Mode | Flag | Pruning | GPU Memory | Best For |
+ |------|------|---------|------------|----------|
+ | Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+ | In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+ > [!TIP]
+ > Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.

  ---
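The tables above describe the trade-off; the snippet below is a minimal, self-contained Optuna sketch (a toy objective, not WaveDL's training loop) showing the mechanism behind the Pruning column: `MedianPruner` can only act when the objective reports intermediate values from the same process, which is exactly what `--inprocess` enables and a subprocess launch cannot do.

```python
import optuna


def objective(trial):
    # Toy stand-in for one training run (illustration only, not wavedl code).
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)

    val_loss = 1.0
    for epoch in range(20):
        # Pretend one epoch of training; good learning rates improve faster.
        val_loss *= 0.90 if lr > 1e-4 else 0.99

        # Only possible in-process: report the intermediate value so the
        # pruner can stop clearly unpromising trials early.
        trial.report(val_loss, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return val_loss


study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
)
study.optimize(objective, n_trials=30)
print(study.best_params, study.best_value)
```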

@@ -956,7 +973,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--data_path` | (required) | Training data file |
  | `--models` | 3 defaults | Models to search (specify any number) |
  | `--n_trials` | `50` | Number of trials to run |
- | `--quick` | `False` | Use minimal defaults (faster) |
+ | `--quick` | `False` | Quick mode: minimal search space |
+ | `--medium` | `False` | Medium mode: balanced search space |
+ | `--inprocess` | `False` | Run trials in-process (enables pruning) |
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
@@ -965,7 +984,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--output` | `hpo_results.json` | Output file |


- > See [Available Models](#available-models) for all 38 architectures you can search.
+ > See [Available Models](#available-models) for all 69 architectures you can search.

  </details>

@@ -166,11 +166,11 @@ This installs everything you need: training, inference, HPO, ONNX export.
  ```bash
  git clone https://github.com/ductho-le/WaveDL.git
  cd WaveDL
- pip install -e .
+ pip install -e ".[dev]"
  ```

  > [!NOTE]
- > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
+ > Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

  ### Quick Start

@@ -225,8 +225,6 @@ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --mode

  ### Testing & Inference

- After training, use `wavedl-test` to evaluate your model on test data:
-
  ```bash
  # Basic inference
  wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
@@ -861,18 +859,24 @@ Automatically find the best training configuration using [Optuna](https://optuna
  **Run HPO:**

  ```bash
- # Basic HPO (auto-detects GPUs for parallel trials)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 100
+ # Basic HPO (50 trials, auto-detects GPUs)
+ wavedl-hpo --data_path train.npz --n_trials 50
+
+ # Quick search (minimal search space, fastest)
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick

- # Search multiple models
- wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
+ # Medium search (balanced between quick and full)
+ wavedl-hpo --data_path train.npz --n_trials 50 --medium

- # Quick mode (fewer parameters, faster)
- wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
+ # Full search with specific models
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
+
+ # In-process mode (enables pruning, faster, single-GPU)
+ wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
  ```

  > [!TIP]
- > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+ > **GPU Detection**: HPO auto-detects GPUs and runs one trial per GPU in parallel. Use `--inprocess` for single-GPU with pruning support (early stopping of bad trials).

  **Train with best parameters**

@@ -894,10 +898,23 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
  | Batch size | 16, 32, 64, 128 | (always searched) |

- **Quick Mode** (`--quick`):
- - Uses minimal defaults: cnn + adamw + plateau + mse
- - Faster for testing your setup before running full search
- - You can still override any option with the flags above
+ **Search Presets:**
+
+ | Mode | Models | Optimizers | Schedulers | Use Case |
+ |------|--------|------------|------------|----------|
+ | Full (default) | cnn, resnet18, resnet34 | all 6 | all 8 | Production search |
+ | `--medium` | cnn, resnet18 | adamw, adam, sgd | plateau, cosine, onecycle | Balanced exploration |
+ | `--quick` | cnn | adamw | plateau | Fast validation |
+
+ **Execution Modes:**
+
+ | Mode | Flag | Pruning | GPU Memory | Best For |
+ |------|------|---------|------------|----------|
+ | Subprocess (default) | — | ❌ No | Isolated | Multi-GPU parallel trials |
+ | In-process | `--inprocess` | ✅ Yes | Shared | Single-GPU with early stopping |
+
+ > [!TIP]
+ > Use `--inprocess` when running single-GPU trials. It enables MedianPruner to stop unpromising trials early, reducing total search time.

  ---

@@ -908,7 +925,9 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--data_path` | (required) | Training data file |
  | `--models` | 3 defaults | Models to search (specify any number) |
  | `--n_trials` | `50` | Number of trials to run |
- | `--quick` | `False` | Use minimal defaults (faster) |
+ | `--quick` | `False` | Quick mode: minimal search space |
+ | `--medium` | `False` | Medium mode: balanced search space |
+ | `--inprocess` | `False` | Run trials in-process (enables pruning) |
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
@@ -917,7 +936,7 @@ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
  | `--output` | `hpo_results.json` | Output file |


- > See [Available Models](#available-models) for all 38 architectures you can search.
+ > See [Available Models](#available-models) for all 69 architectures you can search.

  </details>

@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.6.2"
+ __version__ = "1.7.0"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

@@ -10,12 +10,28 @@ Usage:
  # Quick search (fewer parameters)
  wavedl-hpo --data_path train.npz --n_trials 30 --quick

+ # Medium search (balanced)
+ wavedl-hpo --data_path train.npz --n_trials 50 --medium
+
  # Full search with specific models
  wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0

  # Parallel trials on multiple GPUs
  wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4

+ # In-process mode (enables pruning, faster, single-GPU)
+ wavedl-hpo --data_path train.npz --n_trials 50 --inprocess
+
+ Execution Modes:
+ --inprocess: Runs trials in the same Python process. Enables pruning
+ (MedianPruner) for early stopping of unpromising trials.
+ Faster due to no subprocess overhead, but trials share
+ GPU memory (no isolation between trials).
+
+ Default (subprocess): Launches each trial as a separate process.
+ Provides GPU memory isolation but prevents pruning
+ (subprocess can't report intermediate results).
+
  Author: Ductho Le (ductho.le@outlook.com)
  """

@@ -41,10 +57,12 @@ except ImportError:
  DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
  QUICK_MODELS = ["cnn"]
+ MEDIUM_MODELS = ["cnn", "resnet18"]

  # All 6 optimizers
  DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
  QUICK_OPTIMIZERS = ["adamw"]
+ MEDIUM_OPTIMIZERS = ["adamw", "adam", "sgd"]

  # All 8 schedulers
  DEFAULT_SCHEDULERS = [
@@ -58,10 +76,12 @@ DEFAULT_SCHEDULERS = [
  "linear_warmup",
  ]
  QUICK_SCHEDULERS = ["plateau"]
+ MEDIUM_SCHEDULERS = ["plateau", "cosine", "onecycle"]

  # All 6 losses
  DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
  QUICK_LOSSES = ["mse"]
+ MEDIUM_LOSSES = ["mse", "mae", "huber"]


  # =============================================================================
@@ -70,16 +90,28 @@ QUICK_LOSSES = ["mse"]


  def create_objective(args):
- """Create Optuna objective function with configurable search space."""
+ """Create Optuna objective function with configurable search space.
+
+ Supports two execution modes:
+ - Subprocess (default): Launches wavedl.train via subprocess. Provides GPU
+ memory isolation but prevents pruning (MedianPruner has no effect).
+ - In-process (--inprocess): Calls train_single_trial() directly. Enables
+ pruning and reduces overhead, but trials share GPU memory.
+ """

  def objective(trial):
- # Select search space based on mode
+ # Select search space based on mode (quick < medium < full)
  # CLI arguments always take precedence over defaults
  if args.quick:
  models = args.models or QUICK_MODELS
  optimizers = args.optimizers or QUICK_OPTIMIZERS
  schedulers = args.schedulers or QUICK_SCHEDULERS
  losses = args.losses or QUICK_LOSSES
+ elif args.medium:
+ models = args.models or MEDIUM_MODELS
+ optimizers = args.optimizers or MEDIUM_OPTIMIZERS
+ schedulers = args.schedulers or MEDIUM_SCHEDULERS
+ losses = args.losses or MEDIUM_LOSSES
  else:
  models = args.models or DEFAULT_MODELS
  optimizers = args.optimizers or DEFAULT_OPTIMIZERS
@@ -101,13 +133,59 @@ def create_objective(args):
  if loss == "huber":
  huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
  else:
- huber_delta = None
+ huber_delta = 1.0 # default

  if optimizer == "sgd":
  momentum = trial.suggest_float("momentum", 0.8, 0.99)
  else:
- momentum = None
+ momentum = 0.9 # default
+
+ # ==================================================================
+ # IN-PROCESS MODE: Direct function call with pruning support
+ # ==================================================================
+ if args.inprocess:
+ from wavedl.train import train_single_trial
+
+ try:
+ result = train_single_trial(
+ data_path=args.data_path,
+ model_name=model,
+ lr=lr,
+ batch_size=batch_size,
+ epochs=args.max_epochs,
+ patience=patience,
+ optimizer_name=optimizer,
+ scheduler_name=scheduler,
+ loss_name=loss,
+ weight_decay=weight_decay,
+ seed=args.seed,
+ huber_delta=huber_delta,
+ momentum=momentum,
+ trial=trial, # Enable pruning via trial.report/should_prune
+ verbose=False,
+ )
+
+ if result["pruned"]:
+ print(
+ f"Trial {trial.number}: Pruned at epoch {result['epochs_trained']}"
+ )
+ raise optuna.TrialPruned()
+
+ val_loss = result["best_val_loss"]
+ print(
+ f"Trial {trial.number}: val_loss={val_loss:.6f} ({result['epochs_trained']} epochs)"
+ )
+ return val_loss

+ except optuna.TrialPruned:
+ raise # Re-raise for Optuna to handle
+ except Exception as e:
+ print(f"Trial {trial.number}: Error - {e}")
+ return float("inf")
+
+ # ==================================================================
+ # SUBPROCESS MODE (default): GPU memory isolation, no pruning
+ # ==================================================================
  # Build command
  cmd = [
  sys.executable,
@@ -138,9 +216,9 @@ def create_objective(args):
  ]

  # Add conditional args
- if huber_delta:
+ if loss == "huber":
  cmd.extend(["--huber_delta", str(huber_delta)])
- if momentum:
+ if optimizer == "sgd":
  cmd.extend(["--momentum", str(momentum)])

  # Use temporary directory for trial output
@@ -285,7 +363,17 @@ Examples:
  parser.add_argument(
  "--quick",
  action="store_true",
- help="Quick mode: search fewer parameters",
+ help="Quick mode: search fewer parameters (fastest, least thorough)",
+ )
+ parser.add_argument(
+ "--medium",
+ action="store_true",
+ help="Medium mode: balanced parameter search (between --quick and full)",
+ )
+ parser.add_argument(
+ "--inprocess",
+ action="store_true",
+ help="Run trials in-process (enables pruning, faster, but no GPU memory isolation)",
  )
  parser.add_argument(
  "--timeout",
@@ -384,14 +472,32 @@ Examples:
  print("=" * 60)
  print(f"Data: {args.data_path}")
  print(f"Trials: {args.n_trials}")
- print(f"Mode: {'Quick' if args.quick else 'Full'}")
+ # Determine mode name for display
+ if args.quick:
+ mode_name = "Quick"
+ elif args.medium:
+ mode_name = "Medium"
+ else:
+ mode_name = "Full"
+
+ print(
+ f"Mode: {mode_name}"
+ + (" (in-process, pruning enabled)" if args.inprocess else " (subprocess)")
+ )
  print(f"Parallel jobs: {args.n_jobs}")
  print("=" * 60)

+ # Use MedianPruner only for in-process mode (subprocess trials can't report)
+ if args.inprocess:
+ pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
+ else:
+ # NopPruner for subprocess mode - pruning has no effect there
+ pruner = optuna.pruners.NopPruner()
+
  study = optuna.create_study(
  study_name=args.study_name,
  direction="minimize",
- pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
+ pruner=pruner,
  )

  # Run optimization
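Once `study.optimize` finishes, the winning configuration is available through the standard Optuna accessors. A self-contained sketch of turning a study into a `wavedl-train` command follows; the objective here is a dummy, and the real flag names depend on the `suggest_*` keys used in the actual objective.

```python
import optuna


def objective(trial):
    # Dummy objective standing in for a full training run.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
    return abs(lr - 3e-4) * 100 / batch_size


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.NopPruner())
study.optimize(objective, n_trials=20)

# Format the best trial as CLI flags, mirroring the README's
# "Train with best parameters" example.
flags = " ".join(f"--{key} {value}" for key, value in study.best_params.items())
print(f"Best val_loss: {study.best_value:.6f}")
print(f"wavedl-train --data_path train.npz {flags}")
```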
@@ -77,6 +77,15 @@ from .unet import UNetRegression
  from .vit import ViTBase_, ViTSmall, ViTTiny


+ # Optional RATENet (unpublished, may be gitignored)
+ try:
+ from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
+
+ _HAS_RATENET = True
+ except ImportError:
+ _HAS_RATENET = False
+
+
  # Optional timm-based models (imported conditionally)
  try:
  from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
@@ -111,6 +120,7 @@ __all__ = [
  "MC3_18",
  "MODEL_REGISTRY",
  "TCN",
+ # Classes (uppercase first, alphabetically)
  "BaseModel",
  "ConvNeXtBase_",
  "ConvNeXtSmall",
@@ -152,6 +162,7 @@ __all__ = [
  "VimBase",
  "VimSmall",
  "VimTiny",
+ # Functions (lowercase, alphabetically)
  "build_model",
  "get_model",
  "list_models",
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
  "UniRepLKNetTiny",
  ]
  )
+
+ # Add RATENet models to __all__ if available (unpublished)
+ if _HAS_RATENET:
+ __all__.extend(
+ [
+ "RATENet",
+ "RATENetLite",
+ "RATENetTiny",
+ "RATENetV2",
+ ]
+ )
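Because RATENet is exported only when `ratenet.py` ships with the package, downstream code that wants to stay importable either way can probe for it. A minimal sketch of consuming the conditional export added above (the fallback message is illustrative):

```python
# RATENet is only re-exported when ratenet.py is present, so probe for it
# and fall back gracefully instead of hard-importing.
try:
    from wavedl.models import RATENet
except ImportError:
    RATENet = None

if RATENet is None:
    print("RATENet not available in this build; pick a registered model instead")
```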
@@ -166,6 +166,78 @@ class LayerNormNd(nn.Module):
  return x


+ # =============================================================================
+ # STOCHASTIC DEPTH (DropPath)
+ # =============================================================================
+
+
+ class DropPath(nn.Module):
+ """
+ Stochastic Depth (drop path) regularization for residual networks.
+
+ Randomly drops entire residual branches during training. Used in modern
+ architectures like ConvNeXt, Swin Transformer, UniRepLKNet.
+
+ Args:
+ drop_prob: Probability of dropping the path (default: 0.0)
+
+ Reference:
+ Huang, G., et al. (2016). Deep Networks with Stochastic Depth.
+ https://arxiv.org/abs/1603.09382
+ """
+
+ def __init__(self, drop_prob: float = 0.0):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+
+ keep_prob = 1 - self.drop_prob
+ # Shape: (batch_size, 1, 1, ...) for broadcasting
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # Binarize
+ return x.div(keep_prob) * random_tensor
+
+
+ # =============================================================================
+ # BACKBONE FREEZING UTILITIES
+ # =============================================================================
+
+
+ def freeze_backbone(
+ model: nn.Module,
+ exclude_patterns: list[str] | None = None,
+ ) -> int:
+ """
+ Freeze backbone parameters, keeping specified layers trainable.
+
+ Args:
+ model: The model whose parameters to freeze
+ exclude_patterns: List of patterns to exclude from freezing.
+ Parameters with names containing any of these patterns stay trainable.
+ Default: ["classifier", "head", "fc"]
+
+ Returns:
+ Number of parameters frozen
+
+ Example:
+ >>> freeze_backbone(model.backbone, exclude_patterns=["fc", "classifier"])
+ """
+ if exclude_patterns is None:
+ exclude_patterns = ["classifier", "head", "fc"]
+
+ frozen_count = 0
+ for name, param in model.named_parameters():
+ if not any(pattern in name for pattern in exclude_patterns):
+ param.requires_grad = False
+ frozen_count += param.numel()
+
+ return frozen_count
+
+
  # =============================================================================
  # REGRESSION HEAD BUILDERS
  # =============================================================================
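A small usage sketch for the two utilities added above. The import path is assumed from the file layout (`wavedl.models._pretrained_utils`), and the residual block is a toy, not a WaveDL model:

```python
import torch
from torch import nn

from wavedl.models._pretrained_utils import DropPath, freeze_backbone


class ToyBlock(nn.Module):
    """Toy residual block; DropPath randomly skips the branch during training."""

    def __init__(self, channels: int, drop_prob: float = 0.1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.drop_path(self.conv(x))


block = ToyBlock(channels=8).train()
out = block(torch.randn(2, 8, 16, 16))  # shape preserved: (2, 8, 16, 16)

# freeze_backbone: none of these parameter names contain "classifier",
# "head", or "fc", so with the defaults every parameter ends up frozen.
backbone = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
n_frozen = freeze_backbone(backbone)
n_trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f"frozen={n_frozen}, still trainable={n_trainable}")  # trainable is 0 here
```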
@@ -31,22 +31,23 @@ from wavedl.models.base import BaseModel
  # @register_model("my_model")
  class TemplateModel(BaseModel):
  """
- Template Model Architecture.
+ Template Model Architecture (2D only).

  Replace this docstring with your model description.
  The first line will appear in --list_models output.

+ NOTE: This template is hardcoded for 2D inputs using Conv2d/MaxPool2d.
+ For 1D/3D support, use dimension-agnostic layer factories from
+ _pretrained_utils.py (get_conv_layer, get_pool_layer, get_norm_layer).
+
  Args:
- in_shape: Input spatial dimensions (auto-detected from data)
- - 1D: (L,) for signals
- - 2D: (H, W) for images
- - 3D: (D, H, W) for volumes
+ in_shape: Input spatial dimensions as (H, W) for 2D images
  out_size: Number of regression targets (auto-detected from data)
  hidden_dim: Size of hidden layers (default: 256)
  dropout: Dropout rate (default: 0.1)

  Input Shape:
- (B, 1, *in_shape) - e.g., (B, 1, 64, 64) for 2D
+ (B, 1, H, W) - 2D grayscale images

  Output Shape:
  (B, out_size) - Regression predictions
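The NOTE above points at dimension-agnostic layer factories in `_pretrained_utils.py`; their signatures are not part of this diff, so the sketch below uses hypothetical stand-ins (simple dim → class lookups) purely to illustrate the pattern the note recommends:

```python
from torch import nn

# Hypothetical stand-ins for the factories mentioned in the note; the real
# get_conv_layer / get_pool_layer in _pretrained_utils.py may differ.
_CONVS = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
_POOLS = {1: nn.MaxPool1d, 2: nn.MaxPool2d, 3: nn.MaxPool3d}


def get_conv_layer(dim: int) -> type[nn.Module]:
    return _CONVS[dim]


def get_pool_layer(dim: int) -> type[nn.Module]:
    return _POOLS[dim]


def make_stem(in_shape: tuple[int, ...], out_channels: int = 16) -> nn.Sequential:
    """Conv + pool stem that adapts to 1D, 2D, or 3D inputs."""
    dim = len(in_shape)  # (L,) -> 1, (H, W) -> 2, (D, H, W) -> 3
    conv, pool = get_conv_layer(dim), get_pool_layer(dim)
    return nn.Sequential(conv(1, out_channels, kernel_size=3, padding=1), nn.ReLU(), pool(2))


stem_2d = make_stem((64, 64))   # builds Conv2d / MaxPool2d
stem_1d = make_stem((1024,))    # builds Conv1d / MaxPool1d
```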
@@ -159,6 +159,26 @@ class CNN(BaseModel):
  nn.Linear(64, out_size),
  )

+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self):
+ """Initialize weights with Kaiming for conv, Xavier for linear."""
+ for m in self.modules():
+ if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ nn.init.kaiming_normal_(
+ m.weight, mode="fan_out", nonlinearity="leaky_relu"
+ )
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+
  def _make_conv_block(
  self, in_channels: int, out_channels: int, dropout: float = 0.0
  ) -> nn.Sequential:
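The same scheme can be applied to an arbitrary module tree with `nn.Module.apply`; the snippet below is a standalone restatement of the init rules in `_init_weights` above (not imported from wavedl):

```python
from torch import nn


def init_weights(m: nn.Module) -> None:
    """Same rules as CNN._init_weights: Kaiming for conv, Xavier for linear."""
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)


# nn.Module.apply visits every submodule, so one call covers the whole tree.
# Layer sizes are arbitrary here; the model is only initialized, not run.
model = nn.Sequential(nn.Conv2d(1, 16, 3), nn.GroupNorm(4, 16), nn.Flatten(), nn.Linear(256, 3))
model.apply(init_weights)
```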