wavedl-1.6.0-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/hpo.py
CHANGED
@@ -1,451 +1,451 @@

The removed and added panes of this hunk are textually identical in this extraction (all 451 lines deleted and re-added unchanged), so the file is shown once:

```python
"""
WaveDL - Hyperparameter Optimization with Optuna
=================================================
Automated hyperparameter search for finding optimal training configurations.

Usage:
    # Basic HPO (50 trials)
    wavedl-hpo --data_path train.npz --n_trials 50

    # Quick search (fewer parameters)
    wavedl-hpo --data_path train.npz --n_trials 30 --quick

    # Full search with specific models
    wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0

    # Parallel trials on multiple GPUs
    wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4

Author: Ductho Le (ductho.le@outlook.com)
"""

import argparse
import json
import subprocess
import sys
import tempfile
from pathlib import Path


try:
    import optuna
    from optuna.trial import TrialState
except ImportError:
    print("Error: Optuna not installed. Run: pip install wavedl")
    sys.exit(1)


# =============================================================================
# DEFAULT SEARCH SPACES
# =============================================================================

DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
QUICK_MODELS = ["cnn"]

# All 6 optimizers
DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
QUICK_OPTIMIZERS = ["adamw"]

# All 8 schedulers
DEFAULT_SCHEDULERS = [
    "plateau",
    "cosine",
    "cosine_restarts",
    "onecycle",
    "step",
    "multistep",
    "exponential",
    "linear_warmup",
]
QUICK_SCHEDULERS = ["plateau"]

# All 6 losses
DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
QUICK_LOSSES = ["mse"]


# =============================================================================
# OBJECTIVE FUNCTION
# =============================================================================


def create_objective(args):
    """Create Optuna objective function with configurable search space."""

    def objective(trial):
        # Select search space based on mode
        # CLI arguments always take precedence over defaults
        if args.quick:
            models = args.models or QUICK_MODELS
            optimizers = args.optimizers or QUICK_OPTIMIZERS
            schedulers = args.schedulers or QUICK_SCHEDULERS
            losses = args.losses or QUICK_LOSSES
        else:
            models = args.models or DEFAULT_MODELS
            optimizers = args.optimizers or DEFAULT_OPTIMIZERS
            schedulers = args.schedulers or DEFAULT_SCHEDULERS
            losses = args.losses or DEFAULT_LOSSES

        # Suggest hyperparameters
        model = trial.suggest_categorical("model", models)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        batch_sizes = args.batch_sizes or [16, 32, 64, 128]
        batch_size = trial.suggest_categorical("batch_size", batch_sizes)
        optimizer = trial.suggest_categorical("optimizer", optimizers)
        scheduler = trial.suggest_categorical("scheduler", schedulers)
        loss = trial.suggest_categorical("loss", losses)
        weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        patience = trial.suggest_int("patience", 10, 30, step=5)

        # Conditional hyperparameters
        if loss == "huber":
            huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
        else:
            huber_delta = None

        if optimizer == "sgd":
            momentum = trial.suggest_float("momentum", 0.8, 0.99)
        else:
            momentum = None

        # Build command
        cmd = [
            sys.executable,
            "-m",
            "wavedl.train",
            "--data_path",
            str(args.data_path),
            "--model",
            model,
            "--lr",
            str(lr),
            "--batch_size",
            str(batch_size),
            "--optimizer",
            optimizer,
            "--scheduler",
            scheduler,
            "--loss",
            loss,
            "--weight_decay",
            str(weight_decay),
            "--patience",
            str(patience),
            "--epochs",
            str(args.max_epochs),
            "--seed",
            str(args.seed),
        ]

        # Add conditional args
        if huber_delta:
            cmd.extend(["--huber_delta", str(huber_delta)])
        if momentum:
            cmd.extend(["--momentum", str(momentum)])

        # Use temporary directory for trial output
        with tempfile.TemporaryDirectory() as tmpdir:
            cmd.extend(["--output_dir", tmpdir])
            history_file = Path(tmpdir) / "training_history.csv"

            # GPU isolation for parallel trials: assign each trial to a specific GPU
            # This prevents multiple trials from competing for all GPUs
            env = None
            if args.n_jobs > 1:
                import os

                # Detect available GPUs
                n_gpus = 1
                try:
                    import subprocess as sp

                    result_gpu = sp.run(
                        ["nvidia-smi", "--list-gpus"],
                        capture_output=True,
                        text=True,
                    )
                    if result_gpu.returncode == 0:
                        n_gpus = len(result_gpu.stdout.strip().split("\n"))
                except Exception:
                    pass

                # Assign trial to a specific GPU (round-robin)
                gpu_id = trial.number % n_gpus
                env = os.environ.copy()
                env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

            # Run training
            # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
            # because site-packages may be read-only and train.py creates cache directories
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=args.timeout,
                    env=env,
                )

                # Read best val_loss from training_history.csv (reliable machine-readable)
                val_loss = None
                if history_file.exists():
                    try:
                        import csv

                        with open(history_file) as f:
                            reader = csv.DictReader(f)
                            val_losses = []
                            for row in reader:
                                if "val_loss" in row:
                                    try:
                                        val_losses.append(float(row["val_loss"]))
                                    except (ValueError, TypeError):
                                        pass
                            if val_losses:
                                val_loss = min(val_losses)  # Best (minimum) val_loss
                    except Exception as e:
                        print(f"Trial {trial.number}: Error reading history: {e}")

                if val_loss is None:
                    # Fallback: parse stdout for training log format
                    # Pattern: "epoch | train_loss | val_loss | ..."
                    # Use regex to avoid false positives from unrelated lines
                    import re

                    # Match lines like: " 42 | 0.0123 | 0.0156 | ..."
                    log_pattern = re.compile(
                        r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
                    )
                    val_losses_stdout = []
                    for line in result.stdout.split("\n"):
                        match = log_pattern.match(line)
                        if match:
                            try:
                                val_losses_stdout.append(float(match.group(1)))
                            except ValueError:
                                continue
                    if val_losses_stdout:
                        val_loss = min(val_losses_stdout)

                if val_loss is None:
                    # Training failed or no loss found
                    print(f"Trial {trial.number}: Training failed (no val_loss found)")
                    if result.returncode != 0:
                        # Show last few lines of stderr for debugging
                        stderr_lines = result.stderr.strip().split("\n")[-3:]
                        for line in stderr_lines:
                            print(f"  stderr: {line}")
                    return float("inf")

                print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
                return val_loss

            except subprocess.TimeoutExpired:
                print(f"Trial {trial.number}: Timeout after {args.timeout}s")
                return float("inf")
            except Exception as e:
                print(f"Trial {trial.number}: Error - {e}")
                return float("inf")

    return objective


# =============================================================================
# MAIN FUNCTION
# =============================================================================


def main():
    parser = argparse.ArgumentParser(
        description="WaveDL Hyperparameter Optimization with Optuna",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  wavedl-hpo --data_path train.npz --n_trials 50
  wavedl-hpo --data_path train.npz --n_trials 30 --quick
  wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18
        """,
    )

    # Required
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to training data"
    )

    # HPO settings
    parser.add_argument(
        "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Quick mode: search fewer parameters",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Timeout per trial in seconds (default: 3600)",
    )

    # Search space customization
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help=f"Models to search (default: {DEFAULT_MODELS})",
    )
    parser.add_argument(
        "--optimizers",
        nargs="+",
        default=None,
        help=f"Optimizers to search (default: {DEFAULT_OPTIMIZERS})",
    )
    parser.add_argument(
        "--schedulers",
        nargs="+",
        default=None,
        help=f"Schedulers to search (default: {DEFAULT_SCHEDULERS})",
    )
    parser.add_argument(
        "--losses",
        nargs="+",
        default=None,
        help=f"Losses to search (default: {DEFAULT_LOSSES})",
    )
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=None,
        help="Batch sizes to search (default: 16 32 64 128)",
    )

    # Training settings for each trial
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=50,
        help="Max epochs per trial (default: 50, use early stopping)",
    )
    parser.add_argument(
        "--seed", type=int, default=2025, help="Random seed (default: 2025)"
    )

    # Output
    parser.add_argument(
        "--output",
        type=str,
        default="hpo_results.json",
        help="Output file for best params (default: hpo_results.json)",
    )
    parser.add_argument(
        "--study_name",
        type=str,
        default="wavedl_hpo",
        help="Optuna study name (default: wavedl_hpo)",
    )

    args = parser.parse_args()

    # Convert to absolute path (child processes may run in different cwd)
    args.data_path = str(Path(args.data_path).resolve())

    # Validate data path
    if not Path(args.data_path).exists():
        print(f"Error: Data file not found: {args.data_path}")
        sys.exit(1)

    # Auto-detect GPUs for n_jobs if not specified
    if args.n_jobs == -1:
        try:
            result_gpu = subprocess.run(
                ["nvidia-smi", "--list-gpus"],
                capture_output=True,
                text=True,
            )
            if result_gpu.returncode == 0:
                args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
            else:
                args.n_jobs = 1
        except Exception:
            args.n_jobs = 1
        print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")

    # Create study
    print("=" * 60)
    print("WaveDL Hyperparameter Optimization")
    print("=" * 60)
    print(f"Data: {args.data_path}")
    print(f"Trials: {args.n_trials}")
    print(f"Mode: {'Quick' if args.quick else 'Full'}")
    print(f"Parallel jobs: {args.n_jobs}")
    print("=" * 60)

    study = optuna.create_study(
        study_name=args.study_name,
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )

    # Run optimization
    objective = create_objective(args)
    study.optimize(
        objective,
        n_trials=args.n_trials,
        n_jobs=args.n_jobs,
        show_progress_bar=True,
    )

    # Results
    print("\n" + "=" * 60)
    print("OPTIMIZATION COMPLETE")
    print("=" * 60)

    # Filter completed trials
    completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]

    if not completed_trials:
        print("No trials completed successfully.")
        sys.exit(1)

    print(f"\nCompleted trials: {len(completed_trials)}/{args.n_trials}")
    print(f"Best trial: #{study.best_trial.number}")
    print(f"Best val_loss: {study.best_value:.6f}")

    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    # Save results
    results = {
        "best_value": study.best_value,
        "best_params": study.best_params,
        "n_trials": len(completed_trials),
        "study_name": args.study_name,
    }

    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {args.output}")

    # Print command to train with best params
    print("\n" + "=" * 60)
    print("TO TRAIN WITH BEST PARAMETERS:")
    print("=" * 60)
    cmd_parts = ["accelerate launch -m wavedl.train"]
    cmd_parts.append(f"--data_path {args.data_path}")
    for key, value in study.best_params.items():
        cmd_parts.append(f"--{key} {value}")
    print(" \\\n ".join(cmd_parts))


if __name__ == "__main__":
    main()
```
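For context on how the saved output is meant to be reused: `main()` writes `hpo_results.json` containing `best_value`, `best_params`, `n_trials`, and `study_name`, and prints an `accelerate launch` command. A minimal sketch of consuming that file programmatically, under the schema above; launching `python -m wavedl.train` in a single process (rather than via `accelerate launch`) is an illustrative assumption:

```python
import json
import subprocess
import sys

# Load the results file written by wavedl-hpo; the schema
# (best_value, best_params, n_trials, study_name) matches main() above.
with open("hpo_results.json") as f:
    results = json.load(f)

print(f"Best val_loss from HPO: {results['best_value']:.6f}")

# Rebuild the training command from best_params, mirroring the
# "TO TRAIN WITH BEST PARAMETERS" printout. Plain `python -m wavedl.train`
# here instead of `accelerate launch` is an assumption for illustration.
cmd = [sys.executable, "-m", "wavedl.train", "--data_path", "train.npz"]
for key, value in results["best_params"].items():
    cmd.extend([f"--{key}", str(value)])

subprocess.run(cmd, check=True)
```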
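One design note on the `optuna.create_study` call above: with no `storage` argument the study is in-memory, so trial history is lost if the process dies, and `study.optimize(..., n_jobs=...)` parallelizes with threads inside that single process (which works here because each trial shells out to a training subprocess). A sketch of the standard Optuna way to persist and resume a study; the `sqlite:///wavedl_hpo.db` URL is an assumption for illustration, not something hpo.py does:

```python
import optuna

# Persist trials to SQLite so an interrupted search can be resumed;
# load_if_exists reattaches to an existing study with the same name.
study = optuna.create_study(
    study_name="wavedl_hpo",
    storage="sqlite:///wavedl_hpo.db",  # hypothetical path, not used by hpo.py
    load_if_exists=True,
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
)
```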