wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
wavedl/hpo.py CHANGED
@@ -1,451 +1,451 @@
All 451 lines are removed and re-added unchanged in this hunk, so the file content appears once below:
"""
WaveDL - Hyperparameter Optimization with Optuna
=================================================
Automated hyperparameter search for finding optimal training configurations.

Usage:
    # Basic HPO (50 trials)
    wavedl-hpo --data_path train.npz --n_trials 50

    # Quick search (fewer parameters)
    wavedl-hpo --data_path train.npz --n_trials 30 --quick

    # Full search with specific models
    wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0

    # Parallel trials on multiple GPUs
    wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4

Author: Ductho Le (ductho.le@outlook.com)
"""

import argparse
import json
import subprocess
import sys
import tempfile
from pathlib import Path


try:
    import optuna
    from optuna.trial import TrialState
except ImportError:
    print("Error: Optuna not installed. Run: pip install wavedl")
    sys.exit(1)


# =============================================================================
# DEFAULT SEARCH SPACES
# =============================================================================

DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
QUICK_MODELS = ["cnn"]

# All 6 optimizers
DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
QUICK_OPTIMIZERS = ["adamw"]

# All 8 schedulers
DEFAULT_SCHEDULERS = [
    "plateau",
    "cosine",
    "cosine_restarts",
    "onecycle",
    "step",
    "multistep",
    "exponential",
    "linear_warmup",
]
QUICK_SCHEDULERS = ["plateau"]

# All 6 losses
DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
QUICK_LOSSES = ["mse"]
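
These lists are consumed verbatim by trial.suggest_categorical in the objective below, so narrowing a list narrows the search space one-for-one. A minimal, self-contained sketch of that sampling step using Optuna's ask-and-tell interface (the study and variable names here are illustrative, not part of wavedl):

    import optuna

    # Illustrative only: draw one configuration the way objective() does.
    study = optuna.create_study(direction="minimize")
    trial = study.ask()  # ask-and-tell: obtain a trial without running an objective
    model = trial.suggest_categorical("model", ["cnn", "resnet18", "resnet34"])
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)  # log-uniform over three decades
    print(trial.params)  # e.g. {'model': 'resnet18', 'lr': 0.00042}
    study.tell(trial, 0.0)  # close the trial with a dummy objective value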


# =============================================================================
# OBJECTIVE FUNCTION
# =============================================================================


def create_objective(args):
    """Create Optuna objective function with configurable search space."""

    def objective(trial):
        # Select search space based on mode
        # CLI arguments always take precedence over defaults
        if args.quick:
            models = args.models or QUICK_MODELS
            optimizers = args.optimizers or QUICK_OPTIMIZERS
            schedulers = args.schedulers or QUICK_SCHEDULERS
            losses = args.losses or QUICK_LOSSES
        else:
            models = args.models or DEFAULT_MODELS
            optimizers = args.optimizers or DEFAULT_OPTIMIZERS
            schedulers = args.schedulers or DEFAULT_SCHEDULERS
            losses = args.losses or DEFAULT_LOSSES

        # Suggest hyperparameters
        model = trial.suggest_categorical("model", models)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        batch_sizes = args.batch_sizes or [16, 32, 64, 128]
        batch_size = trial.suggest_categorical("batch_size", batch_sizes)
        optimizer = trial.suggest_categorical("optimizer", optimizers)
        scheduler = trial.suggest_categorical("scheduler", schedulers)
        loss = trial.suggest_categorical("loss", losses)
        weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        patience = trial.suggest_int("patience", 10, 30, step=5)

        # Conditional hyperparameters
        if loss == "huber":
            huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
        else:
            huber_delta = None

        if optimizer == "sgd":
            momentum = trial.suggest_float("momentum", 0.8, 0.99)
        else:
            momentum = None

        # Build command
        cmd = [
            sys.executable,
            "-m",
            "wavedl.train",
            "--data_path",
            str(args.data_path),
            "--model",
            model,
            "--lr",
            str(lr),
            "--batch_size",
            str(batch_size),
            "--optimizer",
            optimizer,
            "--scheduler",
            scheduler,
            "--loss",
            loss,
            "--weight_decay",
            str(weight_decay),
            "--patience",
            str(patience),
            "--epochs",
            str(args.max_epochs),
            "--seed",
            str(args.seed),
        ]

        # Add conditional args
        if huber_delta:
            cmd.extend(["--huber_delta", str(huber_delta)])
        if momentum:
            cmd.extend(["--momentum", str(momentum)])

        # Use temporary directory for trial output
        with tempfile.TemporaryDirectory() as tmpdir:
            cmd.extend(["--output_dir", tmpdir])
            history_file = Path(tmpdir) / "training_history.csv"

            # GPU isolation for parallel trials: assign each trial to a specific GPU
            # This prevents multiple trials from competing for all GPUs
            env = None
            if args.n_jobs > 1:
                import os

                # Detect available GPUs
                n_gpus = 1
                try:
                    import subprocess as sp

                    result_gpu = sp.run(
                        ["nvidia-smi", "--list-gpus"],
                        capture_output=True,
                        text=True,
                    )
                    if result_gpu.returncode == 0:
                        n_gpus = len(result_gpu.stdout.strip().split("\n"))
                except Exception:
                    pass

                # Assign trial to a specific GPU (round-robin)
                gpu_id = trial.number % n_gpus
                env = os.environ.copy()
                env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

            # Run training
            # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
            # because site-packages may be read-only and train.py creates cache directories
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=args.timeout,
                    env=env,
                )

                # Read best val_loss from training_history.csv (reliable machine-readable)
                val_loss = None
                if history_file.exists():
                    try:
                        import csv

                        with open(history_file) as f:
                            reader = csv.DictReader(f)
                            val_losses = []
                            for row in reader:
                                if "val_loss" in row:
                                    try:
                                        val_losses.append(float(row["val_loss"]))
                                    except (ValueError, TypeError):
                                        pass
                            if val_losses:
                                val_loss = min(val_losses)  # Best (minimum) val_loss
                    except Exception as e:
                        print(f"Trial {trial.number}: Error reading history: {e}")

                if val_loss is None:
                    # Fallback: parse stdout for training log format
                    # Pattern: "epoch | train_loss | val_loss | ..."
                    # Use regex to avoid false positives from unrelated lines
                    import re

                    # Match lines like: " 42 | 0.0123 | 0.0156 | ..."
                    log_pattern = re.compile(
                        r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
                    )
                    val_losses_stdout = []
                    for line in result.stdout.split("\n"):
                        match = log_pattern.match(line)
                        if match:
                            try:
                                val_losses_stdout.append(float(match.group(1)))
                            except ValueError:
                                continue
                    if val_losses_stdout:
                        val_loss = min(val_losses_stdout)

                if val_loss is None:
                    # Training failed or no loss found
                    print(f"Trial {trial.number}: Training failed (no val_loss found)")
                    if result.returncode != 0:
                        # Show last few lines of stderr for debugging
                        stderr_lines = result.stderr.strip().split("\n")[-3:]
                        for line in stderr_lines:
                            print(f" stderr: {line}")
                    return float("inf")

                print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
                return val_loss

            except subprocess.TimeoutExpired:
                print(f"Trial {trial.number}: Timeout after {args.timeout}s")
                return float("inf")
            except Exception as e:
                print(f"Trial {trial.number}: Error - {e}")
                return float("inf")

    return objective
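
The stdout fallback above is easy to sanity-check in isolation. A quick test of the same pattern against the log format the comment describes (the sample line is invented for illustration):

    import re

    log_pattern = re.compile(r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|")

    # A made-up line in the "epoch | train_loss | val_loss | ..." format:
    sample = "  42 | 0.0123 | 0.0156 | 0.0012"
    match = log_pattern.match(sample)
    print(match.group(1))  # -> 0.0156 (the val_loss column)

Note that [\d.]+ matches plain decimals only, so a val_loss printed in scientific notation would not match and the trial would fall through to the failure branch.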


# =============================================================================
# MAIN FUNCTION
# =============================================================================


def main():
    parser = argparse.ArgumentParser(
        description="WaveDL Hyperparameter Optimization with Optuna",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    wavedl-hpo --data_path train.npz --n_trials 50
    wavedl-hpo --data_path train.npz --n_trials 30 --quick
    wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18
        """,
    )

    # Required
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to training data"
    )

    # HPO settings
    parser.add_argument(
        "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Quick mode: search fewer parameters",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Timeout per trial in seconds (default: 3600)",
    )

    # Search space customization
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help=f"Models to search (default: {DEFAULT_MODELS})",
    )
    parser.add_argument(
        "--optimizers",
        nargs="+",
        default=None,
        help=f"Optimizers to search (default: {DEFAULT_OPTIMIZERS})",
    )
    parser.add_argument(
        "--schedulers",
        nargs="+",
        default=None,
        help=f"Schedulers to search (default: {DEFAULT_SCHEDULERS})",
    )
    parser.add_argument(
        "--losses",
        nargs="+",
        default=None,
        help=f"Losses to search (default: {DEFAULT_LOSSES})",
    )
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=None,
        help="Batch sizes to search (default: 16 32 64 128)",
    )

    # Training settings for each trial
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=50,
        help="Max epochs per trial (default: 50, use early stopping)",
    )
    parser.add_argument(
        "--seed", type=int, default=2025, help="Random seed (default: 2025)"
    )

    # Output
    parser.add_argument(
        "--output",
        type=str,
        default="hpo_results.json",
        help="Output file for best params (default: hpo_results.json)",
    )
    parser.add_argument(
        "--study_name",
        type=str,
        default="wavedl_hpo",
        help="Optuna study name (default: wavedl_hpo)",
    )

    args = parser.parse_args()

    # Convert to absolute path (child processes may run in different cwd)
    args.data_path = str(Path(args.data_path).resolve())

    # Validate data path
    if not Path(args.data_path).exists():
        print(f"Error: Data file not found: {args.data_path}")
        sys.exit(1)

    # Auto-detect GPUs for n_jobs if not specified
    if args.n_jobs == -1:
        try:
            result_gpu = subprocess.run(
                ["nvidia-smi", "--list-gpus"],
                capture_output=True,
                text=True,
            )
            if result_gpu.returncode == 0:
                args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
            else:
                args.n_jobs = 1
        except Exception:
            args.n_jobs = 1
        print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")

    # Create study
    print("=" * 60)
    print("WaveDL Hyperparameter Optimization")
    print("=" * 60)
    print(f"Data: {args.data_path}")
    print(f"Trials: {args.n_trials}")
    print(f"Mode: {'Quick' if args.quick else 'Full'}")
    print(f"Parallel jobs: {args.n_jobs}")
    print("=" * 60)

    study = optuna.create_study(
        study_name=args.study_name,
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )

    # Run optimization
    objective = create_objective(args)
    study.optimize(
        objective,
        n_trials=args.n_trials,
        n_jobs=args.n_jobs,
        show_progress_bar=True,
    )

    # Results
    print("\n" + "=" * 60)
    print("OPTIMIZATION COMPLETE")
    print("=" * 60)

    # Filter completed trials
    completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]

    if not completed_trials:
        print("No trials completed successfully.")
        sys.exit(1)

    print(f"\nCompleted trials: {len(completed_trials)}/{args.n_trials}")
    print(f"Best trial: #{study.best_trial.number}")
    print(f"Best val_loss: {study.best_value:.6f}")

    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f" {key}: {value}")

    # Save results
    results = {
        "best_value": study.best_value,
        "best_params": study.best_params,
        "n_trials": len(completed_trials),
        "study_name": args.study_name,
    }

    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {args.output}")

    # Print command to train with best params
    print("\n" + "=" * 60)
    print("TO TRAIN WITH BEST PARAMETERS:")
    print("=" * 60)
    cmd_parts = ["accelerate launch -m wavedl.train"]
    cmd_parts.append(f"--data_path {args.data_path}")
    for key, value in study.best_params.items():
        cmd_parts.append(f"--{key} {value}")
    print(" \\\n ".join(cmd_parts))


if __name__ == "__main__":
    main()
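
Because main() writes hpo_results.json with exactly the four keys shown above, the saved study can be turned back into a training command in a script rather than by copy-pasting the printed one. A minimal sketch, assuming the default output path and an illustrative data file name (train.npz):

    import json

    with open("hpo_results.json") as f:
        results = json.load(f)

    # Rebuild the flags the same way main() prints them.
    flags = [f"--{key} {value}" for key, value in results["best_params"].items()]
    cmd = " ".join(["accelerate launch -m wavedl.train", "--data_path train.npz", *flags])

    print(f"best val_loss: {results['best_value']:.6f} ({results['n_trials']} completed trials)")
    print(cmd)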