wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
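
Most of the 1.6.1 changes add new model backbones (caformer, convnext_v2, efficientvit, fastvit, mamba, maxvit, unireplknet) and shared pretrained-weight helpers (_pretrained_utils.py). These appear to be exposed through the same build_model factory that cross_validation.py (diffed below) imports; a minimal sketch, assuming "convnext_v2" is a registered model key (the exact registry names are not shown in this diff):

    from wavedl.models import build_model

    # in_shape excludes batch and channel dims; out_size is the number of regression targets
    model = build_model("convnext_v2", in_shape=(128, 128), out_size=3)
    print(sum(p.numel() for p in model.parameters()))
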
@@ -1,530 +1,530 @@
1
- """
2
- WaveDL - Cross-Validation Utilities
3
- ====================================
4
-
5
- Internal module for K-fold cross-validation. Called by train.py when --cv flag is used.
6
-
7
- This module provides:
8
- - CVDataset: In-memory dataset for CV
9
- - train_fold: Single fold training function
10
- - run_cross_validation: Main CV orchestration
11
-
12
- Author: Ductho Le (ductho.le@outlook.com)
13
- Version: 1.0.0
14
- """
15
-
16
- import json
17
- import logging
18
- import os
19
- import pickle
20
- from datetime import datetime
21
- from typing import Any
22
-
23
- import numpy as np
24
- import pandas as pd
25
- import torch
26
- import torch.nn as nn
27
- from sklearn.metrics import mean_absolute_error, r2_score
28
- from sklearn.model_selection import KFold, StratifiedKFold
29
- from sklearn.preprocessing import StandardScaler
30
- from torch.utils.data import DataLoader
31
-
32
-
33
- # ==============================================================================
34
- # SIMPLE DATASET
35
- # ==============================================================================
36
- class CVDataset(torch.utils.data.Dataset):
37
- """Simple in-memory dataset for cross-validation."""
38
-
39
- def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None):
40
- """
41
- Initialize CV dataset with explicit channel dimension handling.
42
-
43
- Args:
44
- X: Input data with shape (N, *spatial_dims) or (N, C, *spatial_dims)
45
- y: Target data (N, T)
46
- expected_spatial_ndim: Expected number of spatial dimensions (1, 2, or 3).
47
- If provided, uses explicit logic instead of heuristics.
48
- If None, falls back to ndim-based inference (legacy behavior).
49
-
50
- Channel Dimension Logic:
51
- - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
52
- - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
53
- - If expected_spatial_ndim is None: Use legacy ndim-based inference
54
- """
55
- if expected_spatial_ndim is not None:
56
- # Explicit mode: use expected_spatial_ndim to determine if channel exists
57
- if X.ndim == expected_spatial_ndim + 1:
58
- # Shape is (N, *spatial) - needs channel dimension
59
- X = np.expand_dims(X, axis=1)
60
- elif X.ndim == expected_spatial_ndim + 2:
61
- # Shape is (N, C, *spatial) - already has channel
62
- pass
63
- else:
64
- raise ValueError(
65
- f"Input shape {X.shape} incompatible with expected_spatial_ndim={expected_spatial_ndim}. "
66
- f"Expected ndim={expected_spatial_ndim + 1} or {expected_spatial_ndim + 2}, got {X.ndim}."
67
- )
68
- else:
69
- # Legacy mode: infer from ndim (for backwards compatibility)
70
- # Assumes single-channel data without explicit channel dimension
71
- if X.ndim == 2: # 1D signals: (N, L) -> (N, 1, L)
72
- X = X[:, np.newaxis, :]
73
- elif X.ndim == 3: # 2D images: (N, H, W) -> (N, 1, H, W)
74
- X = X[:, np.newaxis, :, :]
75
- elif X.ndim == 4: # 3D volumes: (N, D, H, W) -> (N, 1, D, H, W)
76
- X = X[:, np.newaxis, :, :, :]
77
- # ndim >= 5 assumed to already have channel dimension
78
-
79
- self.X = torch.tensor(X, dtype=torch.float32)
80
- self.y = torch.tensor(y, dtype=torch.float32)
81
-
82
- def __len__(self) -> int:
83
- return len(self.X)
84
-
85
- def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
86
- return self.X[idx], self.y[idx]
87
-
88
-
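
A minimal sketch of the channel-dimension handling documented above, assuming the module is importable as wavedl.utils.cross_validation (shapes are illustrative):

    import numpy as np
    from wavedl.utils.cross_validation import CVDataset

    X = np.random.rand(8, 256).astype(np.float32)   # (N, L): 1D signals without a channel dim
    y = np.random.rand(8, 3).astype(np.float32)     # (N, T): three regression targets

    ds = CVDataset(X, y, expected_spatial_ndim=1)
    print(ds.X.shape)          # torch.Size([8, 1, 256]) -- channel axis inserted
    x0, y0 = ds[0]
    print(x0.shape, y0.shape)  # torch.Size([1, 256]) torch.Size([3])
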
89
- # ==============================================================================
90
- # SINGLE FOLD TRAINING
91
- # ==============================================================================
92
- def train_fold(
93
- fold: int,
94
- model: nn.Module,
95
- train_loader: DataLoader,
96
- val_loader: DataLoader,
97
- criterion: nn.Module,
98
- optimizer: torch.optim.Optimizer,
99
- scheduler,
100
- device: torch.device,
101
- epochs: int,
102
- patience: int,
103
- scaler: StandardScaler,
104
- logger: logging.Logger,
105
- ) -> dict[str, Any]:
106
- """
107
- Train and evaluate a single CV fold.
108
-
109
- Args:
110
- fold: Fold index (0-based)
111
- model: PyTorch model
112
- train_loader: Training data loader
113
- val_loader: Validation data loader
114
- criterion: Loss function
115
- optimizer: Optimizer
116
- scheduler: LR scheduler
117
- device: Torch device
118
- epochs: Max epochs
119
- patience: Early stopping patience
120
- scaler: Target scaler (for physical units)
121
- logger: Logger instance
122
-
123
- Returns:
124
- Dictionary with fold results and metrics
125
- """
126
- best_val_loss = float("inf")
127
- patience_ctr = 0
128
- best_state = None
129
- history = []
130
-
131
- # Determine if scheduler steps per batch (OneCycleLR) or per epoch
132
- # Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
133
- from torch.optim.lr_scheduler import OneCycleLR
134
-
135
- step_per_batch = isinstance(scheduler, OneCycleLR)
136
-
137
- for epoch in range(epochs):
138
- # Training
139
- model.train()
140
- train_loss = 0.0
141
- train_samples = 0
142
-
143
- for x, y in train_loader:
144
- x, y = x.to(device), y.to(device)
145
-
146
- optimizer.zero_grad()
147
- pred = model(x)
148
- loss = criterion(pred, y)
149
- loss.backward()
150
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
151
- optimizer.step()
152
-
153
- # Per-batch LR scheduling (OneCycleLR)
154
- if step_per_batch:
155
- scheduler.step()
156
-
157
- train_loss += loss.item() * x.size(0)
158
- train_samples += x.size(0)
159
-
160
- avg_train_loss = train_loss / train_samples
161
-
162
- # Validation
163
- model.eval()
164
- val_loss = 0.0
165
- val_samples = 0
166
- all_preds = []
167
- all_targets = []
168
-
169
- with torch.inference_mode():
170
- for x, y in val_loader:
171
- x, y = x.to(device), y.to(device)
172
- pred = model(x)
173
- loss = criterion(pred, y)
174
-
175
- val_loss += loss.item() * x.size(0)
176
- val_samples += x.size(0)
177
-
178
- all_preds.append(pred.cpu())
179
- all_targets.append(y.cpu())
180
-
181
- avg_val_loss = val_loss / val_samples
182
-
183
- # Compute metrics (guard for tiny datasets)
184
- y_pred = torch.cat(all_preds).numpy()
185
- y_true = torch.cat(all_targets).numpy()
186
- r2 = r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan")
187
- mae = np.abs((y_pred - y_true) * scaler.scale_).mean()
188
-
189
- history.append(
190
- {
191
- "epoch": epoch + 1,
192
- "train_loss": avg_train_loss,
193
- "val_loss": avg_val_loss,
194
- "r2": r2,
195
- "mae": mae,
196
- }
197
- )
198
-
199
- # LR scheduling (epoch-based only, not for per-batch schedulers)
200
- if not step_per_batch and hasattr(scheduler, "step"):
201
- if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
202
- scheduler.step(avg_val_loss)
203
- else:
204
- scheduler.step()
205
-
206
- # Early stopping
207
- if avg_val_loss < best_val_loss:
208
- best_val_loss = avg_val_loss
209
- patience_ctr = 0
210
- best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
211
- else:
212
- patience_ctr += 1
213
-
214
- if patience_ctr >= patience:
215
- logger.info(f" Fold {fold + 1}: Early stopping at epoch {epoch + 1}")
216
- break
217
-
218
- # Restore best model and compute final metrics
219
- if best_state:
220
- model.load_state_dict(best_state)
221
-
222
- model.eval()
223
- all_preds = []
224
- all_targets = []
225
-
226
- with torch.inference_mode():
227
- for x, y in val_loader:
228
- x, y = x.to(device), y.to(device)
229
- pred = model(x)
230
- all_preds.append(pred.cpu())
231
- all_targets.append(y.cpu())
232
-
233
- y_pred = torch.cat(all_preds).numpy()
234
- y_true = torch.cat(all_targets).numpy()
235
-
236
- # Inverse transform for physical units
237
- y_pred_phys = scaler.inverse_transform(y_pred)
238
- y_true_phys = scaler.inverse_transform(y_true)
239
-
240
- results = {
241
- "fold": fold + 1,
242
- "best_val_loss": best_val_loss,
243
- "r2": r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan"),
244
- "mae_normalized": mean_absolute_error(y_true, y_pred),
245
- "mae_physical": mean_absolute_error(y_true_phys, y_pred_phys),
246
- "epochs_trained": len(history),
247
- "history": history,
248
- }
249
-
250
- # Per-target metrics (guard for tiny folds)
251
- for i in range(y_true.shape[1]):
252
- if len(y_true) >= 2:
253
- results[f"r2_target_{i}"] = r2_score(y_true[:, i], y_pred[:, i])
254
- else:
255
- results[f"r2_target_{i}"] = float("nan")
256
- results[f"mae_target_{i}"] = mean_absolute_error(
257
- y_true_phys[:, i], y_pred_phys[:, i]
258
- )
259
-
260
- return results
261
-
262
-
263
- # ==============================================================================
264
- # MAIN CV ORCHESTRATION
265
- # ==============================================================================
266
- def run_cross_validation(
267
- # Data
268
- X: np.ndarray,
269
- y: np.ndarray,
270
- # Model
271
- model_name: str,
272
- in_shape: tuple[int, ...],
273
- out_size: int,
274
- # CV settings
275
- folds: int = 5,
276
- stratify: bool = False,
277
- stratify_bins: int = 10,
278
- # Training settings
279
- batch_size: int = 128,
280
- lr: float = 1e-3,
281
- epochs: int = 100,
282
- patience: int = 20,
283
- weight_decay: float = 1e-4,
284
- # Components
285
- loss_name: str = "mse",
286
- optimizer_name: str = "adamw",
287
- scheduler_name: str = "plateau",
288
- # Output
289
- output_dir: str = "./cv_results",
290
- workers: int = 4,
291
- seed: int = 2025,
292
- logger: logging.Logger | None = None,
293
- ) -> dict[str, Any]:
294
- """
295
- Run K-fold cross-validation.
296
-
297
- Args:
298
- X: Input data
299
- y: Target data
300
- model_name: Model architecture name
301
- in_shape: Input shape (excluding batch and channel)
302
- out_size: Number of output targets
303
- folds: Number of CV folds
304
- stratify: Use stratified splitting
305
- stratify_bins: Number of bins for stratification
306
- batch_size: Batch size
307
- lr: Learning rate
308
- epochs: Max epochs per fold
309
- patience: Early stopping patience
310
- weight_decay: Weight decay
311
- loss_name: Loss function name
312
- optimizer_name: Optimizer name
313
- scheduler_name: Scheduler name
314
- output_dir: Output directory
315
- workers: DataLoader workers
316
- seed: Random seed
317
- logger: Logger instance
318
-
319
- Returns:
320
- Summary dictionary with aggregated results
321
- """
322
- # Setup
323
- os.makedirs(output_dir, exist_ok=True)
324
-
325
- if logger is None:
326
- logging.basicConfig(
327
- level=logging.INFO,
328
- format="%(asctime)s | %(levelname)s | %(message)s",
329
- datefmt="%H:%M:%S",
330
- )
331
- logger = logging.getLogger("CV-Trainer")
332
-
333
- # Set seeds
334
- np.random.seed(seed)
335
- torch.manual_seed(seed)
336
- if torch.cuda.is_available():
337
- torch.cuda.manual_seed_all(seed)
338
-
339
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
340
-
341
- # Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
342
- if workers < 0:
343
- cpu_count = os.cpu_count() or 4
344
- num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
345
- # Heuristic: 2-16 workers per GPU, bounded by available CPU cores
346
- workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
347
- logger.info(
348
- f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
349
- )
350
-
351
- logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
352
- logger.info(f" Model: {model_name} | Device: {device}")
353
- logger.info(
354
- f" Loss: {loss_name} | Optimizer: {optimizer_name} | Scheduler: {scheduler_name}"
355
- )
356
- logger.info(f" Data shape: X={X.shape}, y={y.shape}")
357
-
358
- # Setup cross-validation
359
- if stratify:
360
- # Bin targets for stratification (regression)
361
- y_binned = np.digitize(
362
- y[:, 0], np.percentile(y[:, 0], np.linspace(0, 100, stratify_bins + 1))
363
- )
364
- kfold = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)
365
- splits = list(kfold.split(X, y_binned))
366
- else:
367
- kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)
368
- splits = list(kfold.split(X))
369
-
370
- # Import factories
371
- from wavedl.models import build_model
372
- from wavedl.utils import get_loss, get_optimizer, get_scheduler
373
-
374
- # Run folds
375
- fold_results = []
376
-
377
- for fold, (train_idx, val_idx) in enumerate(splits):
378
- logger.info(f"\n{'=' * 60}")
379
- logger.info(f"📊 Fold {fold + 1}/{folds}")
380
- logger.info(f" Train: {len(train_idx)} samples, Val: {len(val_idx)} samples")
381
-
382
- # Split data
383
- X_train, X_val = X[train_idx], X[val_idx]
384
- y_train, y_val = y[train_idx], y[val_idx]
385
-
386
- # Fit scaler on training data only
387
- scaler = StandardScaler()
388
- y_train_scaled = scaler.fit_transform(y_train)
389
- y_val_scaled = scaler.transform(y_val)
390
-
391
- # Create datasets and loaders with explicit spatial dimensionality
392
- spatial_ndim = len(in_shape)
393
- train_ds = CVDataset(
394
- X_train.astype(np.float32),
395
- y_train_scaled.astype(np.float32),
396
- expected_spatial_ndim=spatial_ndim,
397
- )
398
- val_ds = CVDataset(
399
- X_val.astype(np.float32),
400
- y_val_scaled.astype(np.float32),
401
- expected_spatial_ndim=spatial_ndim,
402
- )
403
-
404
- train_loader = DataLoader(
405
- train_ds,
406
- batch_size=batch_size,
407
- shuffle=True,
408
- num_workers=workers,
409
- pin_memory=True,
410
- )
411
- val_loader = DataLoader(
412
- val_ds,
413
- batch_size=batch_size,
414
- shuffle=False,
415
- num_workers=workers,
416
- pin_memory=True,
417
- )
418
-
419
- # Build model
420
- model = build_model(model_name, in_shape=in_shape, out_size=out_size)
421
- model = model.to(device)
422
-
423
- # Setup training components
424
- criterion = get_loss(loss_name)
425
- optimizer = get_optimizer(
426
- optimizer_name, model.parameters(), lr=lr, weight_decay=weight_decay
427
- )
428
- scheduler = get_scheduler(
429
- scheduler_name,
430
- optimizer,
431
- epochs=epochs,
432
- steps_per_epoch=len(train_loader) if scheduler_name == "onecycle" else None,
433
- )
434
-
435
- # Train fold
436
- results = train_fold(
437
- fold=fold,
438
- model=model,
439
- train_loader=train_loader,
440
- val_loader=val_loader,
441
- criterion=criterion,
442
- optimizer=optimizer,
443
- scheduler=scheduler,
444
- device=device,
445
- epochs=epochs,
446
- patience=patience,
447
- scaler=scaler,
448
- logger=logger,
449
- )
450
-
451
- fold_results.append(results)
452
-
453
- logger.info(
454
- f" Fold {fold + 1} Results: R²={results['r2']:.4f}, MAE={results['mae_physical']:.4f}"
455
- )
456
-
457
- # Save fold model
458
- fold_dir = os.path.join(output_dir, f"fold_{fold + 1}")
459
- os.makedirs(fold_dir, exist_ok=True)
460
- torch.save(model.state_dict(), os.path.join(fold_dir, "model.pth"))
461
- with open(os.path.join(fold_dir, "scaler.pkl"), "wb") as f:
462
- pickle.dump(scaler, f)
463
-
464
- # ==============================================================================
465
- # AGGREGATE RESULTS
466
- # ==============================================================================
467
- logger.info(f"\n{'=' * 60}")
468
- logger.info("📈 Cross-Validation Summary")
469
- logger.info("=" * 60)
470
-
471
- r2_scores = [r["r2"] for r in fold_results]
472
- mae_scores = [r["mae_physical"] for r in fold_results]
473
- val_losses = [r["best_val_loss"] for r in fold_results]
474
-
475
- summary = {
476
- "config": {
477
- "model": model_name,
478
- "folds": folds,
479
- "stratify": stratify,
480
- "stratify_bins": stratify_bins,
481
- "batch_size": batch_size,
482
- "lr": lr,
483
- "epochs": epochs,
484
- "patience": patience,
485
- "loss": loss_name,
486
- "optimizer": optimizer_name,
487
- "scheduler": scheduler_name,
488
- },
489
- "timestamp": datetime.now().isoformat(),
490
- "folds": folds,
491
- "r2_mean": float(np.mean(r2_scores)),
492
- "r2_std": float(np.std(r2_scores)),
493
- "mae_mean": float(np.mean(mae_scores)),
494
- "mae_std": float(np.std(mae_scores)),
495
- "val_loss_mean": float(np.mean(val_losses)),
496
- "val_loss_std": float(np.std(val_losses)),
497
- "fold_results": fold_results,
498
- }
499
-
500
- logger.info(f" R² Score: {summary['r2_mean']:.4f} ± {summary['r2_std']:.4f}")
501
- logger.info(f" MAE (phys): {summary['mae_mean']:.4f} ± {summary['mae_std']:.4f}")
502
- logger.info(
503
- f" Val Loss: {summary['val_loss_mean']:.6f} ± {summary['val_loss_std']:.6f}"
504
- )
505
-
506
- # Per-target summary
507
- for i in range(out_size):
508
- r2_target = [r.get(f"r2_target_{i}", np.nan) for r in fold_results]
509
- mae_target = [r.get(f"mae_target_{i}", np.nan) for r in fold_results]
510
- logger.info(
511
- f" Target {i}: R²={np.mean(r2_target):.4f}±{np.std(r2_target):.4f}, "
512
- f"MAE={np.mean(mae_target):.4f}±{np.std(mae_target):.4f}"
513
- )
514
-
515
- # Save summary
516
- with open(os.path.join(output_dir, "cv_summary.json"), "w") as f:
517
- summary_save = summary.copy()
518
- for r in summary_save["fold_results"]:
519
- r["history"] = None # Too large
520
- json.dump(summary_save, f, indent=2)
521
-
522
- # Save detailed results as CSV
523
- results_df = pd.DataFrame(
524
- [{k: v for k, v in r.items() if k != "history"} for r in fold_results]
525
- )
526
- results_df.to_csv(os.path.join(output_dir, "cv_results.csv"), index=False)
527
-
528
- logger.info(f"\n✅ Results saved to: {output_dir}")
529
-
530
- return summary
1
+ """
2
+ WaveDL - Cross-Validation Utilities
3
+ ====================================
4
+
5
+ Internal module for K-fold cross-validation. Called by train.py when --cv flag is used.
6
+
7
+ This module provides:
8
+ - CVDataset: In-memory dataset for CV
9
+ - train_fold: Single fold training function
10
+ - run_cross_validation: Main CV orchestration
11
+
12
+ Author: Ductho Le (ductho.le@outlook.com)
13
+ Version: 1.0.0
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+ import pickle
20
+ from datetime import datetime
21
+ from typing import Any
22
+
23
+ import numpy as np
24
+ import pandas as pd
25
+ import torch
26
+ import torch.nn as nn
27
+ from sklearn.metrics import mean_absolute_error, r2_score
28
+ from sklearn.model_selection import KFold, StratifiedKFold
29
+ from sklearn.preprocessing import StandardScaler
30
+ from torch.utils.data import DataLoader
31
+
32
+
33
+ # ==============================================================================
34
+ # SIMPLE DATASET
35
+ # ==============================================================================
36
+ class CVDataset(torch.utils.data.Dataset):
37
+ """Simple in-memory dataset for cross-validation."""
38
+
39
+ def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None):
40
+ """
41
+ Initialize CV dataset with explicit channel dimension handling.
42
+
43
+ Args:
44
+ X: Input data with shape (N, *spatial_dims) or (N, C, *spatial_dims)
45
+ y: Target data (N, T)
46
+ expected_spatial_ndim: Expected number of spatial dimensions (1, 2, or 3).
47
+ If provided, uses explicit logic instead of heuristics.
48
+ If None, falls back to ndim-based inference (legacy behavior).
49
+
50
+ Channel Dimension Logic:
51
+ - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
52
+ - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
53
+ - If expected_spatial_ndim is None: Use legacy ndim-based inference
54
+ """
55
+ if expected_spatial_ndim is not None:
56
+ # Explicit mode: use expected_spatial_ndim to determine if channel exists
57
+ if X.ndim == expected_spatial_ndim + 1:
58
+ # Shape is (N, *spatial) - needs channel dimension
59
+ X = np.expand_dims(X, axis=1)
60
+ elif X.ndim == expected_spatial_ndim + 2:
61
+ # Shape is (N, C, *spatial) - already has channel
62
+ pass
63
+ else:
64
+ raise ValueError(
65
+ f"Input shape {X.shape} incompatible with expected_spatial_ndim={expected_spatial_ndim}. "
66
+ f"Expected ndim={expected_spatial_ndim + 1} or {expected_spatial_ndim + 2}, got {X.ndim}."
67
+ )
68
+ else:
69
+ # Legacy mode: infer from ndim (for backwards compatibility)
70
+ # Assumes single-channel data without explicit channel dimension
71
+ if X.ndim == 2: # 1D signals: (N, L) -> (N, 1, L)
72
+ X = X[:, np.newaxis, :]
73
+ elif X.ndim == 3: # 2D images: (N, H, W) -> (N, 1, H, W)
74
+ X = X[:, np.newaxis, :, :]
75
+ elif X.ndim == 4: # 3D volumes: (N, D, H, W) -> (N, 1, D, H, W)
76
+ X = X[:, np.newaxis, :, :, :]
77
+ # ndim >= 5 assumed to already have channel dimension
78
+
79
+ self.X = torch.tensor(X, dtype=torch.float32)
80
+ self.y = torch.tensor(y, dtype=torch.float32)
81
+
82
+ def __len__(self) -> int:
83
+ return len(self.X)
84
+
85
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
86
+ return self.X[idx], self.y[idx]
87
+
88
+
89
+ # ==============================================================================
90
+ # SINGLE FOLD TRAINING
91
+ # ==============================================================================
92
+ def train_fold(
93
+ fold: int,
94
+ model: nn.Module,
95
+ train_loader: DataLoader,
96
+ val_loader: DataLoader,
97
+ criterion: nn.Module,
98
+ optimizer: torch.optim.Optimizer,
99
+ scheduler,
100
+ device: torch.device,
101
+ epochs: int,
102
+ patience: int,
103
+ scaler: StandardScaler,
104
+ logger: logging.Logger,
105
+ ) -> dict[str, Any]:
106
+ """
107
+ Train and evaluate a single CV fold.
108
+
109
+ Args:
110
+ fold: Fold index (0-based)
111
+ model: PyTorch model
112
+ train_loader: Training data loader
113
+ val_loader: Validation data loader
114
+ criterion: Loss function
115
+ optimizer: Optimizer
116
+ scheduler: LR scheduler
117
+ device: Torch device
118
+ epochs: Max epochs
119
+ patience: Early stopping patience
120
+ scaler: Target scaler (for physical units)
121
+ logger: Logger instance
122
+
123
+ Returns:
124
+ Dictionary with fold results and metrics
125
+ """
126
+ best_val_loss = float("inf")
127
+ patience_ctr = 0
128
+ best_state = None
129
+ history = []
130
+
131
+ # Determine if scheduler steps per batch (OneCycleLR) or per epoch
132
+ # Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
133
+ from torch.optim.lr_scheduler import OneCycleLR
134
+
135
+ step_per_batch = isinstance(scheduler, OneCycleLR)
136
+
137
+ for epoch in range(epochs):
138
+ # Training
139
+ model.train()
140
+ train_loss = 0.0
141
+ train_samples = 0
142
+
143
+ for x, y in train_loader:
144
+ x, y = x.to(device), y.to(device)
145
+
146
+ optimizer.zero_grad()
147
+ pred = model(x)
148
+ loss = criterion(pred, y)
149
+ loss.backward()
150
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
151
+ optimizer.step()
152
+
153
+ # Per-batch LR scheduling (OneCycleLR)
154
+ if step_per_batch:
155
+ scheduler.step()
156
+
157
+ train_loss += loss.item() * x.size(0)
158
+ train_samples += x.size(0)
159
+
160
+ avg_train_loss = train_loss / train_samples
161
+
162
+ # Validation
163
+ model.eval()
164
+ val_loss = 0.0
165
+ val_samples = 0
166
+ all_preds = []
167
+ all_targets = []
168
+
169
+ with torch.inference_mode():
170
+ for x, y in val_loader:
171
+ x, y = x.to(device), y.to(device)
172
+ pred = model(x)
173
+ loss = criterion(pred, y)
174
+
175
+ val_loss += loss.item() * x.size(0)
176
+ val_samples += x.size(0)
177
+
178
+ all_preds.append(pred.cpu())
179
+ all_targets.append(y.cpu())
180
+
181
+ avg_val_loss = val_loss / val_samples
182
+
183
+ # Compute metrics (guard for tiny datasets)
184
+ y_pred = torch.cat(all_preds).numpy()
185
+ y_true = torch.cat(all_targets).numpy()
186
+ r2 = r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan")
187
+ mae = np.abs((y_pred - y_true) * scaler.scale_).mean()
188
+
189
+ history.append(
190
+ {
191
+ "epoch": epoch + 1,
192
+ "train_loss": avg_train_loss,
193
+ "val_loss": avg_val_loss,
194
+ "r2": r2,
195
+ "mae": mae,
196
+ }
197
+ )
198
+
199
+ # LR scheduling (epoch-based only, not for per-batch schedulers)
200
+ if not step_per_batch and hasattr(scheduler, "step"):
201
+ if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
202
+ scheduler.step(avg_val_loss)
203
+ else:
204
+ scheduler.step()
205
+
206
+ # Early stopping
207
+ if avg_val_loss < best_val_loss:
208
+ best_val_loss = avg_val_loss
209
+ patience_ctr = 0
210
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
211
+ else:
212
+ patience_ctr += 1
213
+
214
+ if patience_ctr >= patience:
215
+ logger.info(f" Fold {fold + 1}: Early stopping at epoch {epoch + 1}")
216
+ break
217
+
218
+ # Restore best model and compute final metrics
219
+ if best_state:
220
+ model.load_state_dict(best_state)
221
+
222
+ model.eval()
223
+ all_preds = []
224
+ all_targets = []
225
+
226
+ with torch.inference_mode():
227
+ for x, y in val_loader:
228
+ x, y = x.to(device), y.to(device)
229
+ pred = model(x)
230
+ all_preds.append(pred.cpu())
231
+ all_targets.append(y.cpu())
232
+
233
+ y_pred = torch.cat(all_preds).numpy()
234
+ y_true = torch.cat(all_targets).numpy()
235
+
236
+ # Inverse transform for physical units
237
+ y_pred_phys = scaler.inverse_transform(y_pred)
238
+ y_true_phys = scaler.inverse_transform(y_true)
239
+
240
+ results = {
241
+ "fold": fold + 1,
242
+ "best_val_loss": best_val_loss,
243
+ "r2": r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan"),
244
+ "mae_normalized": mean_absolute_error(y_true, y_pred),
245
+ "mae_physical": mean_absolute_error(y_true_phys, y_pred_phys),
246
+ "epochs_trained": len(history),
247
+ "history": history,
248
+ }
249
+
250
+ # Per-target metrics (guard for tiny folds)
251
+ for i in range(y_true.shape[1]):
252
+ if len(y_true) >= 2:
253
+ results[f"r2_target_{i}"] = r2_score(y_true[:, i], y_pred[:, i])
254
+ else:
255
+ results[f"r2_target_{i}"] = float("nan")
256
+ results[f"mae_target_{i}"] = mean_absolute_error(
257
+ y_true_phys[:, i], y_pred_phys[:, i]
258
+ )
259
+
260
+ return results
261
+
262
+
263
+ # ==============================================================================
264
+ # MAIN CV ORCHESTRATION
265
+ # ==============================================================================
266
+ def run_cross_validation(
267
+ # Data
268
+ X: np.ndarray,
269
+ y: np.ndarray,
270
+ # Model
271
+ model_name: str,
272
+ in_shape: tuple[int, ...],
273
+ out_size: int,
274
+ # CV settings
275
+ folds: int = 5,
276
+ stratify: bool = False,
277
+ stratify_bins: int = 10,
278
+ # Training settings
279
+ batch_size: int = 128,
280
+ lr: float = 1e-3,
281
+ epochs: int = 100,
282
+ patience: int = 20,
283
+ weight_decay: float = 1e-4,
284
+ # Components
285
+ loss_name: str = "mse",
286
+ optimizer_name: str = "adamw",
287
+ scheduler_name: str = "plateau",
288
+ # Output
289
+ output_dir: str = "./cv_results",
290
+ workers: int = 4,
291
+ seed: int = 2025,
292
+ logger: logging.Logger | None = None,
293
+ ) -> dict[str, Any]:
294
+ """
295
+ Run K-fold cross-validation.
296
+
297
+ Args:
298
+ X: Input data
299
+ y: Target data
300
+ model_name: Model architecture name
301
+ in_shape: Input shape (excluding batch and channel)
302
+ out_size: Number of output targets
303
+ folds: Number of CV folds
304
+ stratify: Use stratified splitting
305
+ stratify_bins: Number of bins for stratification
306
+ batch_size: Batch size
307
+ lr: Learning rate
308
+ epochs: Max epochs per fold
309
+ patience: Early stopping patience
310
+ weight_decay: Weight decay
311
+ loss_name: Loss function name
312
+ optimizer_name: Optimizer name
313
+ scheduler_name: Scheduler name
314
+ output_dir: Output directory
315
+ workers: DataLoader workers
316
+ seed: Random seed
317
+ logger: Logger instance
318
+
319
+ Returns:
320
+ Summary dictionary with aggregated results
321
+ """
322
+ # Setup
323
+ os.makedirs(output_dir, exist_ok=True)
324
+
325
+ if logger is None:
326
+ logging.basicConfig(
327
+ level=logging.INFO,
328
+ format="%(asctime)s | %(levelname)s | %(message)s",
329
+ datefmt="%H:%M:%S",
330
+ )
331
+ logger = logging.getLogger("CV-Trainer")
332
+
333
+ # Set seeds
334
+ np.random.seed(seed)
335
+ torch.manual_seed(seed)
336
+ if torch.cuda.is_available():
337
+ torch.cuda.manual_seed_all(seed)
338
+
339
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
340
+
341
+ # Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
342
+ if workers < 0:
343
+ cpu_count = os.cpu_count() or 4
344
+ num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
345
+ # Heuristic: 2-16 workers per GPU, bounded by available CPU cores
346
+ workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
347
+ logger.info(
348
+ f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
349
+ )
350
+
351
+ logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
352
+ logger.info(f" Model: {model_name} | Device: {device}")
353
+ logger.info(
354
+ f" Loss: {loss_name} | Optimizer: {optimizer_name} | Scheduler: {scheduler_name}"
355
+ )
356
+ logger.info(f" Data shape: X={X.shape}, y={y.shape}")
357
+
358
+ # Setup cross-validation
359
+ if stratify:
360
+ # Bin targets for stratification (regression)
361
+ y_binned = np.digitize(
362
+ y[:, 0], np.percentile(y[:, 0], np.linspace(0, 100, stratify_bins + 1))
363
+ )
364
+ kfold = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)
365
+ splits = list(kfold.split(X, y_binned))
366
+ else:
367
+ kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)
368
+ splits = list(kfold.split(X))
369
+
370
+ # Import factories
371
+ from wavedl.models import build_model
372
+ from wavedl.utils import get_loss, get_optimizer, get_scheduler
373
+
374
+ # Run folds
375
+ fold_results = []
376
+
377
+ for fold, (train_idx, val_idx) in enumerate(splits):
378
+ logger.info(f"\n{'=' * 60}")
379
+ logger.info(f"📊 Fold {fold + 1}/{folds}")
380
+ logger.info(f" Train: {len(train_idx)} samples, Val: {len(val_idx)} samples")
381
+
382
+ # Split data
383
+ X_train, X_val = X[train_idx], X[val_idx]
384
+ y_train, y_val = y[train_idx], y[val_idx]
385
+
386
+ # Fit scaler on training data only
387
+ scaler = StandardScaler()
388
+ y_train_scaled = scaler.fit_transform(y_train)
389
+ y_val_scaled = scaler.transform(y_val)
390
+
391
+ # Create datasets and loaders with explicit spatial dimensionality
392
+ spatial_ndim = len(in_shape)
393
+ train_ds = CVDataset(
394
+ X_train.astype(np.float32),
395
+ y_train_scaled.astype(np.float32),
396
+ expected_spatial_ndim=spatial_ndim,
397
+ )
398
+ val_ds = CVDataset(
399
+ X_val.astype(np.float32),
400
+ y_val_scaled.astype(np.float32),
401
+ expected_spatial_ndim=spatial_ndim,
402
+ )
403
+
404
+ train_loader = DataLoader(
405
+ train_ds,
406
+ batch_size=batch_size,
407
+ shuffle=True,
408
+ num_workers=workers,
409
+ pin_memory=True,
410
+ )
411
+ val_loader = DataLoader(
412
+ val_ds,
413
+ batch_size=batch_size,
414
+ shuffle=False,
415
+ num_workers=workers,
416
+ pin_memory=True,
417
+ )
418
+
419
+ # Build model
420
+ model = build_model(model_name, in_shape=in_shape, out_size=out_size)
421
+ model = model.to(device)
422
+
423
+ # Setup training components
424
+ criterion = get_loss(loss_name)
425
+ optimizer = get_optimizer(
426
+ optimizer_name, model.parameters(), lr=lr, weight_decay=weight_decay
427
+ )
428
+ scheduler = get_scheduler(
429
+ scheduler_name,
430
+ optimizer,
431
+ epochs=epochs,
432
+ steps_per_epoch=len(train_loader) if scheduler_name == "onecycle" else None,
433
+ )
434
+
435
+ # Train fold
436
+ results = train_fold(
437
+ fold=fold,
438
+ model=model,
439
+ train_loader=train_loader,
440
+ val_loader=val_loader,
441
+ criterion=criterion,
442
+ optimizer=optimizer,
443
+ scheduler=scheduler,
444
+ device=device,
445
+ epochs=epochs,
446
+ patience=patience,
447
+ scaler=scaler,
448
+ logger=logger,
449
+ )
450
+
451
+ fold_results.append(results)
452
+
453
+ logger.info(
454
+ f" Fold {fold + 1} Results: R²={results['r2']:.4f}, MAE={results['mae_physical']:.4f}"
455
+ )
456
+
457
+ # Save fold model
458
+ fold_dir = os.path.join(output_dir, f"fold_{fold + 1}")
459
+ os.makedirs(fold_dir, exist_ok=True)
460
+ torch.save(model.state_dict(), os.path.join(fold_dir, "model.pth"))
461
+ with open(os.path.join(fold_dir, "scaler.pkl"), "wb") as f:
462
+ pickle.dump(scaler, f)
463
+
464
+ # ==============================================================================
465
+ # AGGREGATE RESULTS
466
+ # ==============================================================================
467
+ logger.info(f"\n{'=' * 60}")
468
+ logger.info("📈 Cross-Validation Summary")
469
+ logger.info("=" * 60)
470
+
471
+ r2_scores = [r["r2"] for r in fold_results]
472
+ mae_scores = [r["mae_physical"] for r in fold_results]
473
+ val_losses = [r["best_val_loss"] for r in fold_results]
474
+
475
+ summary = {
476
+ "config": {
477
+ "model": model_name,
478
+ "folds": folds,
479
+ "stratify": stratify,
480
+ "stratify_bins": stratify_bins,
481
+ "batch_size": batch_size,
482
+ "lr": lr,
483
+ "epochs": epochs,
484
+ "patience": patience,
485
+ "loss": loss_name,
486
+ "optimizer": optimizer_name,
487
+ "scheduler": scheduler_name,
488
+ },
489
+ "timestamp": datetime.now().isoformat(),
490
+ "folds": folds,
491
+ "r2_mean": float(np.mean(r2_scores)),
492
+ "r2_std": float(np.std(r2_scores)),
493
+ "mae_mean": float(np.mean(mae_scores)),
494
+ "mae_std": float(np.std(mae_scores)),
495
+ "val_loss_mean": float(np.mean(val_losses)),
496
+ "val_loss_std": float(np.std(val_losses)),
497
+ "fold_results": fold_results,
498
+ }
499
+
500
+ logger.info(f" R² Score: {summary['r2_mean']:.4f} ± {summary['r2_std']:.4f}")
501
+ logger.info(f" MAE (phys): {summary['mae_mean']:.4f} ± {summary['mae_std']:.4f}")
502
+ logger.info(
503
+ f" Val Loss: {summary['val_loss_mean']:.6f} ± {summary['val_loss_std']:.6f}"
504
+ )
505
+
506
+ # Per-target summary
507
+ for i in range(out_size):
508
+ r2_target = [r.get(f"r2_target_{i}", np.nan) for r in fold_results]
509
+ mae_target = [r.get(f"mae_target_{i}", np.nan) for r in fold_results]
510
+ logger.info(
511
+ f" Target {i}: R²={np.mean(r2_target):.4f}±{np.std(r2_target):.4f}, "
512
+ f"MAE={np.mean(mae_target):.4f}±{np.std(mae_target):.4f}"
513
+ )
514
+
515
+ # Save summary
516
+ with open(os.path.join(output_dir, "cv_summary.json"), "w") as f:
517
+ summary_save = summary.copy()
518
+ for r in summary_save["fold_results"]:
519
+ r["history"] = None # Too large
520
+ json.dump(summary_save, f, indent=2)
521
+
522
+ # Save detailed results as CSV
523
+ results_df = pd.DataFrame(
524
+ [{k: v for k, v in r.items() if k != "history"} for r in fold_results]
525
+ )
526
+ results_df.to_csv(os.path.join(output_dir, "cv_results.csv"), index=False)
527
+
528
+ logger.info(f"\n✅ Results saved to: {output_dir}")
529
+
530
+ return summary
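
For context, the module's entry point is run_cross_validation; a minimal usage sketch on synthetic data, assuming wavedl 1.6.1 is installed and that "cnn" is a valid key for build_model (registry names are not shown in this diff):

    import numpy as np
    from wavedl.utils.cross_validation import run_cross_validation

    # 200 single-channel 1D signals of length 256, two regression targets
    X = np.random.rand(200, 256).astype(np.float32)
    y = np.random.rand(200, 2).astype(np.float32)

    summary = run_cross_validation(
        X, y,
        model_name="cnn",        # hypothetical registry key
        in_shape=(256,),
        out_size=2,
        folds=5,
        epochs=10,
        patience=5,
        output_dir="./cv_results",
        workers=0,
    )
    print(f"R2: {summary['r2_mean']:.3f} ± {summary['r2_std']:.3f}")
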