wavedl-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,509 @@
+ """
+ WaveDL - Cross-Validation Utilities
+ ====================================
+
+ Internal module for K-fold cross-validation. Called by train.py when the --cv flag is used.
+
+ This module provides:
+ - CVDataset: In-memory dataset for CV
+ - train_fold: Single fold training function
+ - run_cross_validation: Main CV orchestration
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ Version: 1.0.0
+ """
+
+ import json
+ import logging
+ import os
+ import pickle
+ from datetime import datetime
+ from typing import Any
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from sklearn.metrics import mean_absolute_error, r2_score
+ from sklearn.model_selection import KFold, StratifiedKFold
+ from sklearn.preprocessing import StandardScaler
+ from torch.utils.data import DataLoader
+
+
+ # ==============================================================================
+ # SIMPLE DATASET
+ # ==============================================================================
+ class CVDataset(torch.utils.data.Dataset):
+     """Simple in-memory dataset for cross-validation."""
+
+     def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None):
+         """
+         Initialize CV dataset with explicit channel dimension handling.
+
+         Args:
+             X: Input data with shape (N, *spatial_dims) or (N, C, *spatial_dims)
+             y: Target data (N, T)
+             expected_spatial_ndim: Expected number of spatial dimensions (1, 2, or 3).
+                 If provided, uses explicit logic instead of heuristics.
+                 If None, falls back to ndim-based inference (legacy behavior).
+
+         Channel Dimension Logic:
+             - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
+             - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
+             - If expected_spatial_ndim is None: Use legacy ndim-based inference
+         """
+         if expected_spatial_ndim is not None:
+             # Explicit mode: use expected_spatial_ndim to determine if channel exists
+             if X.ndim == expected_spatial_ndim + 1:
+                 # Shape is (N, *spatial) - needs channel dimension
+                 X = np.expand_dims(X, axis=1)
+             elif X.ndim == expected_spatial_ndim + 2:
+                 # Shape is (N, C, *spatial) - already has channel
+                 pass
+             else:
+                 raise ValueError(
+                     f"Input shape {X.shape} incompatible with expected_spatial_ndim={expected_spatial_ndim}. "
+                     f"Expected ndim={expected_spatial_ndim + 1} or {expected_spatial_ndim + 2}, got {X.ndim}."
+                 )
+         else:
+             # Legacy mode: infer from ndim (for backwards compatibility)
+             # Assumes single-channel data without explicit channel dimension
+             if X.ndim == 2:  # 1D signals: (N, L) -> (N, 1, L)
+                 X = X[:, np.newaxis, :]
+             elif X.ndim == 3:  # 2D images: (N, H, W) -> (N, 1, H, W)
+                 X = X[:, np.newaxis, :, :]
+             elif X.ndim == 4:  # 3D volumes: (N, D, H, W) -> (N, 1, D, H, W)
+                 X = X[:, np.newaxis, :, :, :]
+             # ndim >= 5 assumed to already have channel dimension
+
+         self.X = torch.tensor(X, dtype=torch.float32)
+         self.y = torch.tensor(y, dtype=torch.float32)
+
+     def __len__(self) -> int:
+         return len(self.X)
+
+     def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+         return self.X[idx], self.y[idx]
+
+
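Note: to make the channel-dimension logic of `CVDataset` concrete, here is a minimal sketch. The import path is an assumption (the diff does not show where this module lives inside the wavedl package); the shapes follow the class docstring.

```python
import numpy as np

from wavedl.cv import CVDataset  # hypothetical import path, not shown in this diff

X = np.random.rand(8, 64, 64).astype(np.float32)  # (N, H, W): 2D images, no channel dim
y = np.random.rand(8, 3).astype(np.float32)       # (N, T): 3 regression targets

ds = CVDataset(X, y, expected_spatial_ndim=2)
print(ds.X.shape)  # torch.Size([8, 1, 64, 64]) -- channel dimension inserted at axis 1
print(len(ds))     # 8

# Data that already carries a channel axis is passed through unchanged
X_ch = np.random.rand(8, 2, 64, 64).astype(np.float32)  # (N, C, H, W)
print(CVDataset(X_ch, y, expected_spatial_ndim=2).X.shape)  # torch.Size([8, 2, 64, 64])
```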
+ # ==============================================================================
+ # SINGLE FOLD TRAINING
+ # ==============================================================================
+ def train_fold(
+     fold: int,
+     model: nn.Module,
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     criterion: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     scheduler,
+     device: torch.device,
+     epochs: int,
+     patience: int,
+     scaler: StandardScaler,
+     logger: logging.Logger,
+ ) -> dict[str, Any]:
+     """
+     Train and evaluate a single CV fold.
+
+     Args:
+         fold: Fold index (0-based)
+         model: PyTorch model
+         train_loader: Training data loader
+         val_loader: Validation data loader
+         criterion: Loss function
+         optimizer: Optimizer
+         scheduler: LR scheduler
+         device: Torch device
+         epochs: Max epochs
+         patience: Early stopping patience
+         scaler: Target scaler (for physical units)
+         logger: Logger instance
+
+     Returns:
+         Dictionary with fold results and metrics
+     """
+     best_val_loss = float("inf")
+     patience_ctr = 0
+     best_state = None
+     history = []
+
+     for epoch in range(epochs):
+         # Training
+         model.train()
+         train_loss = 0.0
+         train_samples = 0
+
+         for x, y in train_loader:
+             x, y = x.to(device), y.to(device)
+
+             optimizer.zero_grad()
+             pred = model(x)
+             loss = criterion(pred, y)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+
+             train_loss += loss.item() * x.size(0)
+             train_samples += x.size(0)
+
+         avg_train_loss = train_loss / train_samples
+
+         # Validation
+         model.eval()
+         val_loss = 0.0
+         val_samples = 0
+         all_preds = []
+         all_targets = []
+
+         with torch.inference_mode():
+             for x, y in val_loader:
+                 x, y = x.to(device), y.to(device)
+                 pred = model(x)
+                 loss = criterion(pred, y)
+
+                 val_loss += loss.item() * x.size(0)
+                 val_samples += x.size(0)
+
+                 all_preds.append(pred.cpu())
+                 all_targets.append(y.cpu())
+
+         avg_val_loss = val_loss / val_samples
+
+         # Compute metrics (guard for tiny datasets)
+         y_pred = torch.cat(all_preds).numpy()
+         y_true = torch.cat(all_targets).numpy()
+         r2 = r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan")
+         mae = np.abs((y_pred - y_true) * scaler.scale_).mean()
+
+         history.append(
+             {
+                 "epoch": epoch + 1,
+                 "train_loss": avg_train_loss,
+                 "val_loss": avg_val_loss,
+                 "r2": r2,
+                 "mae": mae,
+             }
+         )
+
+         # LR scheduling
+         if hasattr(scheduler, "step"):
+             if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                 scheduler.step(avg_val_loss)
+             else:
+                 scheduler.step()
+
+         # Early stopping
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             patience_ctr = 0
+             best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
+         else:
+             patience_ctr += 1
+
+         if patience_ctr >= patience:
+             logger.info(f" Fold {fold + 1}: Early stopping at epoch {epoch + 1}")
+             break
+
+     # Restore best model and compute final metrics
+     if best_state:
+         model.load_state_dict(best_state)
+
+     model.eval()
+     all_preds = []
+     all_targets = []
+
+     with torch.inference_mode():
+         for x, y in val_loader:
+             x, y = x.to(device), y.to(device)
+             pred = model(x)
+             all_preds.append(pred.cpu())
+             all_targets.append(y.cpu())
+
+     y_pred = torch.cat(all_preds).numpy()
+     y_true = torch.cat(all_targets).numpy()
+
+     # Inverse transform for physical units
+     y_pred_phys = scaler.inverse_transform(y_pred)
+     y_true_phys = scaler.inverse_transform(y_true)
+
+     results = {
+         "fold": fold + 1,
+         "best_val_loss": best_val_loss,
+         "r2": r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan"),
+         "mae_normalized": mean_absolute_error(y_true, y_pred),
+         "mae_physical": mean_absolute_error(y_true_phys, y_pred_phys),
+         "epochs_trained": len(history),
+         "history": history,
+     }
+
+     # Per-target metrics (guard for tiny folds)
+     for i in range(y_true.shape[1]):
+         if len(y_true) >= 2:
+             results[f"r2_target_{i}"] = r2_score(y_true[:, i], y_pred[:, i])
+         else:
+             results[f"r2_target_{i}"] = float("nan")
+         results[f"mae_target_{i}"] = mean_absolute_error(
+             y_true_phys[:, i], y_pred_phys[:, i]
+         )
+
+     return results
+
+
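Note: the per-epoch `mae` in train_fold is reported in physical units without a full `inverse_transform`. For a `StandardScaler`, `inverse_transform(z) = z * scale_ + mean_`, so the mean cancels in the prediction error and scaling the normalized residual by `scale_` is sufficient. A small self-contained check of that equivalence, using synthetic data (not from the package):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
y_phys = rng.normal(loc=5.0, scale=2.0, size=(100, 3))      # targets in physical units

scaler = StandardScaler().fit(y_phys)
y_true = scaler.transform(y_phys)                            # normalized targets
y_pred = y_true + rng.normal(scale=0.1, size=y_true.shape)   # mock predictions

# Shortcut used in train_fold: scale the normalized residual per target
mae_fast = np.abs((y_pred - y_true) * scaler.scale_).mean()

# Full round trip through inverse_transform gives the same number
mae_full = np.abs(
    scaler.inverse_transform(y_pred) - scaler.inverse_transform(y_true)
).mean()

assert np.isclose(mae_fast, mae_full)
```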
+ # ==============================================================================
+ # MAIN CV ORCHESTRATION
+ # ==============================================================================
+ def run_cross_validation(
+     # Data
+     X: np.ndarray,
+     y: np.ndarray,
+     # Model
+     model_name: str,
+     in_shape: tuple[int, ...],
+     out_size: int,
+     # CV settings
+     folds: int = 5,
+     stratify: bool = False,
+     stratify_bins: int = 10,
+     # Training settings
+     batch_size: int = 128,
+     lr: float = 1e-3,
+     epochs: int = 100,
+     patience: int = 20,
+     weight_decay: float = 1e-4,
+     # Components
+     loss_name: str = "mse",
+     optimizer_name: str = "adamw",
+     scheduler_name: str = "plateau",
+     # Output
+     output_dir: str = "./cv_results",
+     workers: int = 4,
+     seed: int = 2025,
+     logger: logging.Logger | None = None,
+ ) -> dict[str, Any]:
+     """
+     Run K-fold cross-validation.
+
+     Args:
+         X: Input data
+         y: Target data
+         model_name: Model architecture name
+         in_shape: Input shape (excluding batch and channel)
+         out_size: Number of output targets
+         folds: Number of CV folds
+         stratify: Use stratified splitting
+         stratify_bins: Number of bins for stratification
+         batch_size: Batch size
+         lr: Learning rate
+         epochs: Max epochs per fold
+         patience: Early stopping patience
+         weight_decay: Weight decay
+         loss_name: Loss function name
+         optimizer_name: Optimizer name
+         scheduler_name: Scheduler name
+         output_dir: Output directory
+         workers: DataLoader workers
+         seed: Random seed
+         logger: Logger instance
+
+     Returns:
+         Summary dictionary with aggregated results
+     """
+     # Setup
+     os.makedirs(output_dir, exist_ok=True)
+
+     if logger is None:
+         logging.basicConfig(
+             level=logging.INFO,
+             format="%(asctime)s | %(levelname)s | %(message)s",
+             datefmt="%H:%M:%S",
+         )
+         logger = logging.getLogger("CV-Trainer")
+
+     # Set seeds
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
+     logger.info(f" Model: {model_name} | Device: {device}")
+     logger.info(
+         f" Loss: {loss_name} | Optimizer: {optimizer_name} | Scheduler: {scheduler_name}"
+     )
+     logger.info(f" Data shape: X={X.shape}, y={y.shape}")
+
+     # Setup cross-validation
+     if stratify:
+         # Bin targets for stratification (regression)
+         y_binned = np.digitize(
+             y[:, 0], np.percentile(y[:, 0], np.linspace(0, 100, stratify_bins + 1))
+         )
+         kfold = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)
+         splits = list(kfold.split(X, y_binned))
+     else:
+         kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)
+         splits = list(kfold.split(X))
+
+     # Import factories
+     from wavedl.models import build_model
+     from wavedl.utils import get_loss, get_optimizer, get_scheduler
+
+     # Run folds
+     fold_results = []
+
+     for fold, (train_idx, val_idx) in enumerate(splits):
+         logger.info(f"\n{'=' * 60}")
+         logger.info(f"📊 Fold {fold + 1}/{folds}")
+         logger.info(f" Train: {len(train_idx)} samples, Val: {len(val_idx)} samples")
+
+         # Split data
+         X_train, X_val = X[train_idx], X[val_idx]
+         y_train, y_val = y[train_idx], y[val_idx]
+
+         # Fit scaler on training data only
+         scaler = StandardScaler()
+         y_train_scaled = scaler.fit_transform(y_train)
+         y_val_scaled = scaler.transform(y_val)
+
+         # Create datasets and loaders with explicit spatial dimensionality
+         spatial_ndim = len(in_shape)
+         train_ds = CVDataset(
+             X_train.astype(np.float32),
+             y_train_scaled.astype(np.float32),
+             expected_spatial_ndim=spatial_ndim,
+         )
+         val_ds = CVDataset(
+             X_val.astype(np.float32),
+             y_val_scaled.astype(np.float32),
+             expected_spatial_ndim=spatial_ndim,
+         )
+
+         train_loader = DataLoader(
+             train_ds,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=workers,
+             pin_memory=True,
+         )
+         val_loader = DataLoader(
+             val_ds,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=workers,
+             pin_memory=True,
+         )
+
+         # Build model
+         model = build_model(model_name, in_shape=in_shape, out_size=out_size)
+         model = model.to(device)
+
+         # Setup training components
+         criterion = get_loss(loss_name)
+         optimizer = get_optimizer(
+             optimizer_name, model.parameters(), lr=lr, weight_decay=weight_decay
+         )
+         scheduler = get_scheduler(
+             scheduler_name,
+             optimizer,
+             epochs=epochs,
+             steps_per_epoch=len(train_loader) if scheduler_name == "onecycle" else None,
+         )
+
+         # Train fold
+         results = train_fold(
+             fold=fold,
+             model=model,
+             train_loader=train_loader,
+             val_loader=val_loader,
+             criterion=criterion,
+             optimizer=optimizer,
+             scheduler=scheduler,
+             device=device,
+             epochs=epochs,
+             patience=patience,
+             scaler=scaler,
+             logger=logger,
+         )
+
+         fold_results.append(results)
+
+         logger.info(
+             f" Fold {fold + 1} Results: R²={results['r2']:.4f}, MAE={results['mae_physical']:.4f}"
+         )
+
+         # Save fold model
+         fold_dir = os.path.join(output_dir, f"fold_{fold + 1}")
+         os.makedirs(fold_dir, exist_ok=True)
+         torch.save(model.state_dict(), os.path.join(fold_dir, "model.pth"))
+         with open(os.path.join(fold_dir, "scaler.pkl"), "wb") as f:
+             pickle.dump(scaler, f)
+
+     # ==========================================================================
+     # AGGREGATE RESULTS
+     # ==========================================================================
+     logger.info(f"\n{'=' * 60}")
+     logger.info("📈 Cross-Validation Summary")
+     logger.info("=" * 60)
+
+     r2_scores = [r["r2"] for r in fold_results]
+     mae_scores = [r["mae_physical"] for r in fold_results]
+     val_losses = [r["best_val_loss"] for r in fold_results]
+
+     summary = {
+         "config": {
+             "model": model_name,
+             "folds": folds,
+             "stratify": stratify,
+             "stratify_bins": stratify_bins,
+             "batch_size": batch_size,
+             "lr": lr,
+             "epochs": epochs,
+             "patience": patience,
+             "loss": loss_name,
+             "optimizer": optimizer_name,
+             "scheduler": scheduler_name,
+         },
+         "timestamp": datetime.now().isoformat(),
+         "folds": folds,
+         "r2_mean": float(np.mean(r2_scores)),
+         "r2_std": float(np.std(r2_scores)),
+         "mae_mean": float(np.mean(mae_scores)),
+         "mae_std": float(np.std(mae_scores)),
+         "val_loss_mean": float(np.mean(val_losses)),
+         "val_loss_std": float(np.std(val_losses)),
+         "fold_results": fold_results,
+     }
+
+     logger.info(f" R² Score: {summary['r2_mean']:.4f} ± {summary['r2_std']:.4f}")
+     logger.info(f" MAE (phys): {summary['mae_mean']:.4f} ± {summary['mae_std']:.4f}")
+     logger.info(
+         f" Val Loss: {summary['val_loss_mean']:.6f} ± {summary['val_loss_std']:.6f}"
+     )
+
+     # Per-target summary
+     for i in range(out_size):
+         r2_target = [r.get(f"r2_target_{i}", np.nan) for r in fold_results]
+         mae_target = [r.get(f"mae_target_{i}", np.nan) for r in fold_results]
+         logger.info(
+             f" Target {i}: R²={np.mean(r2_target):.4f}±{np.std(r2_target):.4f}, "
+             f"MAE={np.mean(mae_target):.4f}±{np.std(mae_target):.4f}"
+         )
+
+     # Save summary (drop per-epoch history from the JSON copy only; it is large,
+     # and mutating it in place would also blank the history in the returned summary)
+     summary_save = dict(summary)
+     summary_save["fold_results"] = [{**r, "history": None} for r in fold_results]
+     with open(os.path.join(output_dir, "cv_summary.json"), "w") as f:
+         json.dump(summary_save, f, indent=2)
+
+     # Save detailed results as CSV
+     results_df = pd.DataFrame(
+         [{k: v for k, v in r.items() if k != "history"} for r in fold_results]
+     )
+     results_df.to_csv(os.path.join(output_dir, "cv_results.csv"), index=False)
+
+     logger.info(f"\n✅ Results saved to: {output_dir}")
+
+     return summary
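Note: for reference, a minimal end-to-end usage sketch of `run_cross_validation`. The import path and the `model_name` value are placeholders: the name only has to be resolvable by `wavedl.models.build_model`, and the actual registry of model names is not part of this diff.

```python
import numpy as np

from wavedl.cv import run_cross_validation  # hypothetical import path, not shown in this diff

# Synthetic regression problem: 200 single-channel 1D signals, 2 targets
X = np.random.rand(200, 256).astype(np.float32)  # (N, L) -> CVDataset adds the channel dim
y = np.random.rand(200, 2).astype(np.float32)    # (N, T)

summary = run_cross_validation(
    X, y,
    model_name="cnn1d",   # placeholder; must exist in wavedl.models.build_model
    in_shape=(256,),      # spatial shape only (no batch/channel), so spatial_ndim = 1
    out_size=2,
    folds=5,
    epochs=10,
    patience=5,
    output_dir="./cv_results",
)

print(f"R² = {summary['r2_mean']:.3f} ± {summary['r2_std']:.3f}")
print(f"MAE = {summary['mae_mean']:.3f} ± {summary['mae_std']:.3f}")
```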