wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/utils/config.py
CHANGED
@@ -1,367 +1,367 @@
"""
WaveDL - Configuration Management
==================================

YAML configuration file support for reproducible experiments.

Features:
- Load experiment configs from YAML files
- Merge configs with CLI arguments (CLI takes precedence)
- Validate config values against known options
- Save effective config for reproducibility

Usage:
    # Load config and merge with CLI args
    config = load_config("experiment.yaml")
    args = merge_config_with_args(config, args)

    # Save effective config
    save_config(args, "output/config.yaml")

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

import argparse
import logging
from datetime import datetime
from pathlib import Path
from typing import Any

import yaml


def load_config(config_path: str) -> dict[str, Any]:
    """
    Load configuration from a YAML file.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Dictionary of configuration values

    Raises:
        FileNotFoundError: If config file doesn't exist
        yaml.YAMLError: If config file is invalid YAML

    Example:
        >>> config = load_config("configs/experiment.yaml")
        >>> print(config["model"])
        'cnn'
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    if config is None:
        config = {}

    # Handle nested configs (e.g., optimizer.lr -> optimizer_lr)
    config = _flatten_config(config)

    return config


def _flatten_config(
    config: dict[str, Any], parent_key: str = "", sep: str = "_"
) -> dict[str, Any]:
    """
    Flatten nested dictionaries for argparse compatibility.

    Recursively flattens nested dicts, preserving the full key path.

    Example:
        {'optimizer': {'lr': 1e-3}} -> {'optimizer_lr': 1e-3}
        {'optimizer': {'params': {'beta1': 0.9}}} -> {'optimizer_params_beta1': 0.9}
        {'lr': 1e-3} -> {'lr': 1e-3}
    """
    items = []
    for key, value in config.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recursively flatten, passing full accumulated key path
            items.extend(_flatten_config(value, new_key, sep).items())
        else:
            items.append((new_key, value))
    return dict(items)


def merge_config_with_args(
    config: dict[str, Any],
    args: argparse.Namespace,
    parser: argparse.ArgumentParser | None = None,
    ignore_unknown: bool = True,
) -> argparse.Namespace:
    """
    Merge YAML config with CLI arguments. CLI args take precedence.

    Args:
        config: Dictionary from load_config()
        args: Parsed argparse Namespace
        parser: Optional ArgumentParser to detect defaults (if not provided,
            uses heuristic comparison with common default values)
        ignore_unknown: If True, skip config keys not in args

    Returns:
        Updated argparse Namespace

    Note:
        CLI arguments (non-default values) always override config values.
        This allows: `--config base.yaml --lr 5e-4` to use config but override LR.
    """
    # Get parser defaults to detect which args were explicitly set by user
    if parser is not None:
        # Safe extraction: iterate actions instead of parse_args([])
        # This avoids failures if required arguments are added later
        defaults = {
            action.dest: action.default
            for action in parser._actions
            if action.dest != "help"
        }
    else:
        # Fallback: reconstruct defaults from known patterns
        # This works because argparse stores actual values, and we compare
        defaults = {}

    # Track which args were explicitly set on CLI (differ from defaults)
    cli_overrides = set()
    for key, value in vars(args).items():
        if parser is not None:
            if key in defaults and value != defaults[key]:
                cli_overrides.add(key)
        # Without parser, we can't reliably detect CLI overrides
        # So we apply all config values (legacy behavior)

    # Apply config values only where CLI didn't override
    for key, value in config.items():
        if hasattr(args, key):
            # Skip if user explicitly set this via CLI
            if key in cli_overrides:
                logging.debug(f"Config key '{key}' skipped: CLI override detected")
                continue
            setattr(args, key, value)
        elif not ignore_unknown:
            logging.warning(f"Unknown config key: {key}")
        else:
            # Even in ignore_unknown mode, log for discoverability
            logging.debug(f"Config key '{key}' ignored: not a valid argument")

    return args


def save_config(
    args: argparse.Namespace, output_path: str, exclude_keys: list[str] | None = None
) -> str:
    """
    Save effective configuration to YAML for reproducibility.

    Args:
        args: Parsed argparse Namespace
        output_path: Path to save YAML file
        exclude_keys: Keys to exclude from saved config

    Returns:
        Path to saved config file

    Example:
        >>> save_config(args, "output/effective_config.yaml")
    """
    if exclude_keys is None:
        exclude_keys = ["list_models", "fresh", "resume"]

    config = {}
    for key, value in vars(args).items():
        if key not in exclude_keys:
            # Convert Path objects to strings
            if isinstance(value, Path):
                value = str(value)
            config[key] = value

    # Add metadata
    from wavedl import __version__

    config["_metadata"] = {
        "saved_at": datetime.now().isoformat(),
        "wavedl_version": __version__,
    }

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    return str(output_path)


def validate_config(
    config: dict[str, Any], known_keys: list[str] | None = None
) -> list[str]:
    """
    Validate configuration values against known options.

    Args:
        config: Configuration dictionary
        known_keys: Optional list of valid keys (if None, uses defaults from parser args)

    Returns:
        List of warning messages (empty if valid)
    """
    warnings = []

    # Known valid options
    from wavedl.models import list_models
    from wavedl.utils import list_losses, list_optimizers, list_schedulers

    valid_options = {
        "model": list_models(),
        "loss": list_losses(),
        "optimizer": list_optimizers(),
        "scheduler": list_schedulers(),
    }

    for key, valid_values in valid_options.items():
        if key in config and config[key] not in valid_values:
            warnings.append(
                f"Invalid {key}='{config[key]}'. Valid options: {valid_values}"
            )

    # Validate numeric ranges
    numeric_checks = {
        "lr": (0, 1, "Learning rate should be between 0 and 1"),
        "epochs": (1, 100000, "Epochs should be positive"),
        "batch_size": (1, 10000, "Batch size should be positive"),
        "patience": (1, 1000, "Patience should be positive"),
        "cv": (0, 100, "CV folds should be 0-100"),
    }

    for key, (min_val, max_val, msg) in numeric_checks.items():
        if key in config:
            val = config[key]
            # Type check: ensure value is numeric before comparison
            if not isinstance(val, (int, float)):
                warnings.append(
                    f"Invalid type for '{key}': expected number, got {type(val).__name__} ({val!r})"
                )
                continue
            if not (min_val <= val <= max_val):
                warnings.append(f"{msg}: got {val}")

    # Check for unknown/unrecognized keys (helps catch typos)
    # Default known keys based on common training arguments
    default_known_keys = {
        # Model
        "model",
        "import_modules",
        # Hyperparameters
        "batch_size",
        "lr",
        "epochs",
        "patience",
        "weight_decay",
        "grad_clip",
        # Loss
        "loss",
        "huber_delta",
        "loss_weights",
        # Optimizer
        "optimizer",
        "momentum",
        "nesterov",
        "betas",
        # Scheduler
        "scheduler",
        "scheduler_patience",
        "min_lr",
        "scheduler_factor",
        "warmup_epochs",
        "step_size",
        "milestones",
        # Data
        "data_path",
        "workers",
        "seed",
        "single_channel",
        # Cross-validation
        "cv",
        "cv_stratify",
        "cv_bins",
        # Checkpointing
        "resume",
        "save_every",
        "output_dir",
        "fresh",
        # Performance
        "compile",
        "precision",
        "mixed_precision",
        # Logging
        "wandb",
        "wandb_watch",
        "project_name",
        "run_name",
        # Config
        "config",
        "list_models",
        # Physical Constraints
        "constraint",
        "bounds",
        "constraint_file",
        "constraint_weight",
        "constraint_reduction",
        "positive",
        "output_bounds",
        "output_transform",
        "output_formula",
        # Metadata (internal)
        "_metadata",
    }

    check_keys = set(known_keys) if known_keys else default_known_keys

    for key in config:
        if key not in check_keys:
            warnings.append(
                f"Unknown config key: '{key}' - check for typos or see wavedl-train --help"
            )

    return warnings


def create_default_config() -> dict[str, Any]:
    """
    Create a default configuration dictionary.

    Returns:
        Dictionary with default training configuration
    """
    return {
        # Model
        "model": "cnn",
        # Hyperparameters
        "batch_size": 128,
        "lr": 1e-3,
        "epochs": 1000,
        "patience": 20,
        "weight_decay": 1e-4,
        "grad_clip": 1.0,
        # Training components
        "loss": "mse",
        "optimizer": "adamw",
        "scheduler": "plateau",
        # Cross-validation
        "cv": 0,
        "cv_stratify": False,
        "cv_bins": 10,
        # Performance
        "precision": "bf16",
        "compile": False,
        # Output
        "seed": 2025,
        "workers": 8,
    }
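For orientation, a minimal sketch of how these helpers compose, based only on the docstrings above; the experiment.yaml contents, the toy parser, and the output paths are illustrative assumptions, not part of the package.

import argparse

from wavedl.utils.config import (
    load_config,
    merge_config_with_args,
    save_config,
    validate_config,
)

# Hypothetical experiment.yaml (flat keys; a nested block such as
# {'optimizer': {'lr': 1.0e-3}} would be flattened to 'optimizer_lr'):
#   model: cnn
#   lr: 1.0e-3
#   scheduler_patience: 5
config = load_config("experiment.yaml")

# Surface typos and out-of-range values before training starts.
for msg in validate_config(config):
    print(msg)

# Toy parser standing in for the real wavedl-train CLI (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="cnn")
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--scheduler_patience", type=int, default=10)
args = parser.parse_args(["--lr", "5e-4"])  # explicitly set on the CLI

# YAML fills in arguments left at their defaults; --lr 5e-4 wins over the YAML value.
args = merge_config_with_args(config, args, parser=parser)

# Persist the effective settings for reproducibility.
save_config(args, "output/effective_config.yaml")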