wavedl 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
1
+ """
2
+ Utility Functions and Classes
3
+ =============================
4
+
5
+ Centralized exports for all utility modules.
6
+
7
+ Author: Ductho Le (ductho.le@outlook.com)
8
+ Version: 1.0.0
9
+ """
10
+
11
+ from .config import (
12
+ create_default_config,
13
+ load_config,
14
+ merge_config_with_args,
15
+ save_config,
16
+ validate_config,
17
+ )
18
+ from .cross_validation import (
19
+ CVDataset,
20
+ run_cross_validation,
21
+ train_fold,
22
+ )
23
+ from .data import (
24
+ # Multi-format data loading
25
+ DataSource,
26
+ HDF5Source,
27
+ MATSource,
28
+ MemmapDataset,
29
+ NPZSource,
30
+ get_data_source,
31
+ load_outputs_only,
32
+ load_test_data,
33
+ load_training_data,
34
+ memmap_worker_init_fn,
35
+ prepare_data,
36
+ )
37
+ from .distributed import (
38
+ broadcast_early_stop,
39
+ broadcast_value,
40
+ sync_tensor,
41
+ )
42
+ from .losses import (
43
+ LogCoshLoss,
44
+ WeightedMSELoss,
45
+ get_loss,
46
+ list_losses,
47
+ )
48
+ from .metrics import (
49
+ COLORS,
50
+ FIGURE_DPI,
51
+ FIGURE_WIDTH_CM,
52
+ # Style constants
53
+ FIGURE_WIDTH_INCH,
54
+ FONT_SIZE_TEXT,
55
+ FONT_SIZE_TICKS,
56
+ MetricTracker,
57
+ calc_pearson,
58
+ calc_per_target_r2,
59
+ configure_matplotlib_style,
60
+ create_training_curves,
61
+ get_lr,
62
+ plot_bland_altman,
63
+ plot_correlation_heatmap,
64
+ plot_error_boxplot,
65
+ plot_error_cdf,
66
+ plot_error_histogram,
67
+ plot_prediction_vs_index,
68
+ plot_qq,
69
+ plot_relative_error,
70
+ plot_residuals,
71
+ plot_scientific_scatter,
72
+ )
73
+ from .optimizers import (
74
+ get_optimizer,
75
+ get_optimizer_with_param_groups,
76
+ list_optimizers,
77
+ )
78
+ from .schedulers import (
79
+ get_scheduler,
80
+ get_scheduler_with_warmup,
81
+ is_epoch_based,
82
+ list_schedulers,
83
+ )
84
+
85
+
86
# Public API of wavedl.utils: the exact set of names re-exported by
# `from wavedl.utils import *`. Entries are grouped by source module via
# the inline section comments below.
__all__ = [
    "COLORS",
    "FIGURE_DPI",
    "FIGURE_WIDTH_CM",
    # Style constants
    "FIGURE_WIDTH_INCH",
    "FONT_SIZE_TEXT",
    "FONT_SIZE_TICKS",
    "CVDataset",
    "DataSource",
    "HDF5Source",
    "LogCoshLoss",
    "MATSource",
    # Data
    "MemmapDataset",
    # Metrics
    "MetricTracker",
    "NPZSource",
    "WeightedMSELoss",
    # Distributed
    "broadcast_early_stop",
    "broadcast_value",
    "calc_pearson",
    "calc_per_target_r2",
    "configure_matplotlib_style",
    "create_default_config",
    "create_training_curves",
    "get_data_source",
    # Losses
    "get_loss",
    "get_lr",
    # Optimizers
    "get_optimizer",
    "get_optimizer_with_param_groups",
    # Schedulers
    "get_scheduler",
    "get_scheduler_with_warmup",
    "is_epoch_based",
    "list_losses",
    "list_optimizers",
    "list_schedulers",
    # Config
    "load_config",
    "load_outputs_only",
    "load_test_data",
    "load_training_data",
    "memmap_worker_init_fn",
    "merge_config_with_args",
    "plot_bland_altman",
    "plot_correlation_heatmap",
    "plot_error_boxplot",
    "plot_error_cdf",
    "plot_error_histogram",
    "plot_prediction_vs_index",
    "plot_qq",
    "plot_relative_error",
    "plot_residuals",
    "plot_scientific_scatter",
    "prepare_data",
    # Cross-Validation
    "run_cross_validation",
    "save_config",
    "sync_tensor",
    "train_fold",
    "validate_config",
]
wavedl/utils/config.py ADDED
@@ -0,0 +1,269 @@
1
+ """
2
+ WaveDL - Configuration Management
3
+ ==================================
4
+
5
+ YAML configuration file support for reproducible experiments.
6
+
7
+ Features:
8
+ - Load experiment configs from YAML files
9
+ - Merge configs with CLI arguments (CLI takes precedence)
10
+ - Validate config values against known options
11
+ - Save effective config for reproducibility
12
+
13
+ Usage:
14
+ # Load config and merge with CLI args
15
+ config = load_config("experiment.yaml")
16
+ args = merge_config_with_args(config, args)
17
+
18
+ # Save effective config
19
+ save_config(args, "output/config.yaml")
20
+
21
+ Author: Ductho Le (ductho.le@outlook.com)
22
+ Version: 1.0.0
23
+ """
24
+
25
+ import argparse
26
+ import logging
27
+ from datetime import datetime
28
+ from pathlib import Path
29
+ from typing import Any
30
+
31
+ import yaml
32
+
33
+
34
def load_config(config_path: str) -> dict[str, Any]:
    """
    Load configuration from a YAML file.

    Nested sections are flattened with underscores (``optimizer.lr`` ->
    ``optimizer_lr``) so the result can be merged into an argparse Namespace.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Flat dictionary of configuration values

    Raises:
        FileNotFoundError: If config file doesn't exist
        yaml.YAMLError: If config file is invalid YAML
        TypeError: If the top-level YAML document is not a mapping

    Example:
        >>> config = load_config("configs/experiment.yaml")
        >>> print(config["model"])
        'cnn'
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty file parses to None; treat it as an empty config.
    if config is None:
        config = {}

    # A top-level YAML list or scalar would previously crash inside
    # _flatten_config with an opaque AttributeError; fail fast with a
    # clear message instead.
    if not isinstance(config, dict):
        raise TypeError(
            f"Top-level YAML in {config_path} must be a mapping, "
            f"got {type(config).__name__}"
        )

    # Handle nested configs (e.g., optimizer.lr -> optimizer_lr)
    return _flatten_config(config)
68
+
69
+
70
+ def _flatten_config(
71
+ config: dict[str, Any], parent_key: str = "", sep: str = "_"
72
+ ) -> dict[str, Any]:
73
+ """
74
+ Flatten nested dictionaries for argparse compatibility.
75
+
76
+ Recursively flattens nested dicts, preserving the full key path.
77
+
78
+ Example:
79
+ {'optimizer': {'lr': 1e-3}} -> {'optimizer_lr': 1e-3}
80
+ {'optimizer': {'params': {'beta1': 0.9}}} -> {'optimizer_params_beta1': 0.9}
81
+ {'lr': 1e-3} -> {'lr': 1e-3}
82
+ """
83
+ items = []
84
+ for key, value in config.items():
85
+ new_key = f"{parent_key}{sep}{key}" if parent_key else key
86
+ if isinstance(value, dict):
87
+ # Recursively flatten, passing full accumulated key path
88
+ items.extend(_flatten_config(value, new_key, sep).items())
89
+ else:
90
+ items.append((new_key, value))
91
+ return dict(items)
92
+
93
+
94
+ def merge_config_with_args(
95
+ config: dict[str, Any],
96
+ args: argparse.Namespace,
97
+ parser: argparse.ArgumentParser | None = None,
98
+ ignore_unknown: bool = True,
99
+ ) -> argparse.Namespace:
100
+ """
101
+ Merge YAML config with CLI arguments. CLI args take precedence.
102
+
103
+ Args:
104
+ config: Dictionary from load_config()
105
+ args: Parsed argparse Namespace
106
+ parser: Optional ArgumentParser to detect defaults (if not provided,
107
+ uses heuristic comparison with common default values)
108
+ ignore_unknown: If True, skip config keys not in args
109
+
110
+ Returns:
111
+ Updated argparse Namespace
112
+
113
+ Note:
114
+ CLI arguments (non-default values) always override config values.
115
+ This allows: `--config base.yaml --lr 5e-4` to use config but override LR.
116
+ """
117
+ # Get parser defaults to detect which args were explicitly set by user
118
+ if parser is not None:
119
+ defaults = vars(parser.parse_args([]))
120
+ else:
121
+ # Fallback: reconstruct defaults from known patterns
122
+ # This works because argparse stores actual values, and we compare
123
+ defaults = {}
124
+
125
+ # Track which args were explicitly set on CLI (differ from defaults)
126
+ cli_overrides = set()
127
+ for key, value in vars(args).items():
128
+ if parser is not None:
129
+ if key in defaults and value != defaults[key]:
130
+ cli_overrides.add(key)
131
+ # Without parser, we can't reliably detect CLI overrides
132
+ # So we apply all config values (legacy behavior)
133
+
134
+ # Apply config values only where CLI didn't override
135
+ for key, value in config.items():
136
+ if hasattr(args, key):
137
+ # Skip if user explicitly set this via CLI
138
+ if key in cli_overrides:
139
+ logging.debug(f"Config key '{key}' skipped: CLI override detected")
140
+ continue
141
+ setattr(args, key, value)
142
+ elif not ignore_unknown:
143
+ logging.warning(f"Unknown config key: {key}")
144
+
145
+ return args
146
+
147
+
148
def save_config(
    args: argparse.Namespace, output_path: str, exclude_keys: list[str] | None = None
) -> str:
    """
    Save effective configuration to YAML for reproducibility.

    Args:
        args: Parsed argparse Namespace
        output_path: Path to save YAML file
        exclude_keys: Keys to exclude from saved config (defaults to
            transient CLI flags: list_models, fresh, resume)

    Returns:
        Path to saved config file

    Example:
        >>> save_config(args, "output/effective_config.yaml")
    """
    excluded = ["list_models", "fresh", "resume"] if exclude_keys is None else exclude_keys

    # Serialize every non-excluded arg; Path objects become plain strings
    # so the YAML stays portable.
    config = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
        if key not in excluded
    }

    # Add metadata
    config["_metadata"] = {
        "saved_at": datetime.now().isoformat(),
        "wavedl_version": "1.0.0",
    }

    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    return str(target)
189
+
190
+
191
def validate_config(config: dict[str, Any]) -> list[str]:
    """
    Validate configuration values against known options.

    Args:
        config: Configuration dictionary (flat keys, as from load_config)

    Returns:
        List of warning messages (empty if valid)
    """
    warnings = []

    # Known valid options — imported lazily to avoid a circular import
    # at module load time (wavedl.utils imports this module).
    from wavedl.models import list_models
    from wavedl.utils import list_losses, list_optimizers, list_schedulers

    valid_options = {
        "model": list_models(),
        "loss": list_losses(),
        "optimizer": list_optimizers(),
        "scheduler": list_schedulers(),
    }

    for key, valid_values in valid_options.items():
        if key in config and config[key] not in valid_values:
            warnings.append(
                f"Invalid {key}='{config[key]}'. Valid options: {valid_values}"
            )

    # Validate numeric ranges: key -> (min, max, message)
    numeric_checks = {
        "lr": (0, 1, "Learning rate should be between 0 and 1"),
        "epochs": (1, 100000, "Epochs should be positive"),
        "batch_size": (1, 10000, "Batch size should be positive"),
        "patience": (1, 1000, "Patience should be positive"),
        "cv": (0, 100, "CV folds should be 0-100"),
    }

    for key, (min_val, max_val, msg) in numeric_checks.items():
        if key in config:
            val = config[key]
            # A non-numeric value (e.g. "fast" from a typo'd YAML) would
            # raise TypeError on the range comparison; report it as a
            # warning instead of crashing validation.
            if not isinstance(val, (int, float)):
                warnings.append(f"{msg}: got non-numeric {val!r}")
            elif not (min_val <= val <= max_val):
                warnings.append(f"{msg}: got {val}")

    return warnings
236
+
237
+
238
def create_default_config() -> dict[str, Any]:
    """
    Create a default configuration dictionary.

    Returns:
        Dictionary with default training configuration
    """
    model = {"model": "cnn"}
    hyperparameters = {
        "batch_size": 128,
        "lr": 1e-3,
        "epochs": 1000,
        "patience": 20,
        "weight_decay": 1e-4,
        "grad_clip": 1.0,
    }
    training_components = {
        "loss": "mse",
        "optimizer": "adamw",
        "scheduler": "plateau",
    }
    cross_validation = {"cv": 0, "cv_stratify": False, "cv_bins": 10}
    performance = {"precision": "bf16", "compile": False}
    output = {"seed": 2025, "workers": 8}

    # Merge the grouped sections into one flat config dict.
    return {
        **model,
        **hyperparameters,
        **training_components,
        **cross_validation,
        **performance,
        **output,
    }