wavedl-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/utils/cross_validation.py

@@ -0,0 +1,509 @@
"""
WaveDL - Cross-Validation Utilities
====================================

Internal module for K-fold cross-validation. Called by train.py when the --cv flag is used.

This module provides:
- CVDataset: In-memory dataset for CV
- train_fold: Single-fold training function
- run_cross_validation: Main CV orchestration

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""
import json
import logging
import os
import pickle
from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader

# ==============================================================================
# SIMPLE DATASET
# ==============================================================================
class CVDataset(torch.utils.data.Dataset):
    """Simple in-memory dataset for cross-validation."""

    def __init__(self, X: np.ndarray, y: np.ndarray, expected_spatial_ndim: int | None = None):
        """
        Initialize CV dataset with explicit channel-dimension handling.

        Args:
            X: Input data with shape (N, *spatial_dims) or (N, C, *spatial_dims)
            y: Target data (N, T)
            expected_spatial_ndim: Expected number of spatial dimensions (1, 2, or 3).
                If provided, uses explicit logic instead of heuristics.
                If None, falls back to ndim-based inference (legacy behavior).

        Channel Dimension Logic:
            - If X.ndim == expected_spatial_ndim + 1: Add channel dim (N, *spatial) -> (N, 1, *spatial)
            - If X.ndim == expected_spatial_ndim + 2: Already has channel (N, C, *spatial)
            - If expected_spatial_ndim is None: Use legacy ndim-based inference
        """
        if expected_spatial_ndim is not None:
            # Explicit mode: use expected_spatial_ndim to determine if channel exists
            if X.ndim == expected_spatial_ndim + 1:
                # Shape is (N, *spatial) - needs channel dimension
                X = np.expand_dims(X, axis=1)
            elif X.ndim == expected_spatial_ndim + 2:
                # Shape is (N, C, *spatial) - already has channel
                pass
            else:
                raise ValueError(
                    f"Input shape {X.shape} incompatible with expected_spatial_ndim={expected_spatial_ndim}. "
                    f"Expected ndim={expected_spatial_ndim + 1} or {expected_spatial_ndim + 2}, got {X.ndim}."
                )
        else:
            # Legacy mode: infer from ndim (for backwards compatibility)
            # Assumes single-channel data without explicit channel dimension
            if X.ndim == 2:  # 1D signals: (N, L) -> (N, 1, L)
                X = X[:, np.newaxis, :]
            elif X.ndim == 3:  # 2D images: (N, H, W) -> (N, 1, H, W)
                X = X[:, np.newaxis, :, :]
            elif X.ndim == 4:  # 3D volumes: (N, D, H, W) -> (N, 1, D, H, W)
                X = X[:, np.newaxis, :, :, :]
            # ndim >= 5 assumed to already have channel dimension

        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.X[idx], self.y[idx]

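For orientation, a minimal sketch (editorial, not part of the package) of how CVDataset resolves the channel dimension; the array shapes are illustrative:

# Hypothetical usage: CVDataset adds a channel axis when one is missing.
import numpy as np
from wavedl.utils.cross_validation import CVDataset

X = np.random.rand(100, 64, 64).astype(np.float32)  # (N, H, W): 2D images, no channel
y = np.random.rand(100, 3).astype(np.float32)       # (N, T): 3 regression targets

ds = CVDataset(X, y, expected_spatial_ndim=2)       # X.ndim == 2 + 1 -> channel added
print(ds.X.shape)                                   # torch.Size([100, 1, 64, 64])
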
# ==============================================================================
# SINGLE FOLD TRAINING
# ==============================================================================
def train_fold(
    fold: int,
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: torch.device,
    epochs: int,
    patience: int,
    scaler: StandardScaler,
    logger: logging.Logger,
) -> dict[str, Any]:
    """
    Train and evaluate a single CV fold.

    Args:
        fold: Fold index (0-based)
        model: PyTorch model
        train_loader: Training data loader
        val_loader: Validation data loader
        criterion: Loss function
        optimizer: Optimizer
        scheduler: LR scheduler
        device: Torch device
        epochs: Max epochs
        patience: Early stopping patience
        scaler: Target scaler (for physical units)
        logger: Logger instance

    Returns:
        Dictionary with fold results and metrics
    """
    best_val_loss = float("inf")
    patience_ctr = 0
    best_state = None
    history = []

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_samples = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            train_samples += x.size(0)

        avg_train_loss = train_loss / train_samples

        # Validation
        model.eval()
        val_loss = 0.0
        val_samples = 0
        all_preds = []
        all_targets = []

        with torch.inference_mode():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                loss = criterion(pred, y)

                val_loss += loss.item() * x.size(0)
                val_samples += x.size(0)

                all_preds.append(pred.cpu())
                all_targets.append(y.cpu())

        avg_val_loss = val_loss / val_samples

        # Compute metrics (guard for tiny datasets)
        y_pred = torch.cat(all_preds).numpy()
        y_true = torch.cat(all_targets).numpy()
        r2 = r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan")
        mae = np.abs((y_pred - y_true) * scaler.scale_).mean()

        history.append(
            {
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "r2": r2,
                "mae": mae,
            }
        )

        # LR scheduling
        if hasattr(scheduler, "step"):
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_ctr = 0
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            patience_ctr += 1

        if patience_ctr >= patience:
            logger.info(f" Fold {fold + 1}: Early stopping at epoch {epoch + 1}")
            break

    # Restore best model and compute final metrics
    if best_state:
        model.load_state_dict(best_state)

    model.eval()
    all_preds = []
    all_targets = []

    with torch.inference_mode():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            all_preds.append(pred.cpu())
            all_targets.append(y.cpu())

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_targets).numpy()

    # Inverse transform for physical units
    y_pred_phys = scaler.inverse_transform(y_pred)
    y_true_phys = scaler.inverse_transform(y_true)

    results = {
        "fold": fold + 1,
        "best_val_loss": best_val_loss,
        "r2": r2_score(y_true, y_pred) if len(y_true) >= 2 else float("nan"),
        "mae_normalized": mean_absolute_error(y_true, y_pred),
        "mae_physical": mean_absolute_error(y_true_phys, y_pred_phys),
        "epochs_trained": len(history),
        "history": history,
    }

    # Per-target metrics (guard for tiny folds)
    for i in range(y_true.shape[1]):
        if len(y_true) >= 2:
            results[f"r2_target_{i}"] = r2_score(y_true[:, i], y_pred[:, i])
        else:
            results[f"r2_target_{i}"] = float("nan")
        results[f"mae_target_{i}"] = mean_absolute_error(
            y_true_phys[:, i], y_pred_phys[:, i]
        )

    return results

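A note on the per-epoch MAE above: for a StandardScaler, inverse_transform(z) = z * scale_ + mean_, so the mean_ term cancels in a difference and |y_pred_phys - y_true_phys| = |(y_pred - y_true) * scale_|. A small self-contained check of that identity (synthetic numbers, not package code):

# Verify that scaling normalized residuals by scale_ recovers the physical-unit MAE.
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
y = rng.normal(5.0, 2.0, size=(50, 2))
scaler = StandardScaler().fit(y)
z_true = scaler.transform(y)
z_pred = z_true + rng.normal(0, 0.1, size=z_true.shape)

mae_fast = np.abs((z_pred - z_true) * scaler.scale_).mean()
mae_full = np.abs(scaler.inverse_transform(z_pred) - scaler.inverse_transform(z_true)).mean()
assert np.isclose(mae_fast, mae_full)
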
# ==============================================================================
# MAIN CV ORCHESTRATION
# ==============================================================================
def run_cross_validation(
    # Data
    X: np.ndarray,
    y: np.ndarray,
    # Model
    model_name: str,
    in_shape: tuple[int, ...],
    out_size: int,
    # CV settings
    folds: int = 5,
    stratify: bool = False,
    stratify_bins: int = 10,
    # Training settings
    batch_size: int = 128,
    lr: float = 1e-3,
    epochs: int = 100,
    patience: int = 20,
    weight_decay: float = 1e-4,
    # Components
    loss_name: str = "mse",
    optimizer_name: str = "adamw",
    scheduler_name: str = "plateau",
    # Output
    output_dir: str = "./cv_results",
    workers: int = 4,
    seed: int = 2025,
    logger: logging.Logger | None = None,
) -> dict[str, Any]:
    """
    Run K-fold cross-validation.

    Args:
        X: Input data
        y: Target data
        model_name: Model architecture name
        in_shape: Input shape (excluding batch and channel)
        out_size: Number of output targets
        folds: Number of CV folds
        stratify: Use stratified splitting
        stratify_bins: Number of bins for stratification
        batch_size: Batch size
        lr: Learning rate
        epochs: Max epochs per fold
        patience: Early stopping patience
        weight_decay: Weight decay
        loss_name: Loss function name
        optimizer_name: Optimizer name
        scheduler_name: Scheduler name
        output_dir: Output directory
        workers: DataLoader workers
        seed: Random seed
        logger: Logger instance

    Returns:
        Summary dictionary with aggregated results
    """
    # Setup
    os.makedirs(output_dir, exist_ok=True)

    if logger is None:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            datefmt="%H:%M:%S",
        )
        logger = logging.getLogger("CV-Trainer")

    # Set seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
    logger.info(f" Model: {model_name} | Device: {device}")
    logger.info(
        f" Loss: {loss_name} | Optimizer: {optimizer_name} | Scheduler: {scheduler_name}"
    )
    logger.info(f" Data shape: X={X.shape}, y={y.shape}")

    # Setup cross-validation
    if stratify:
        # Bin targets for stratification (regression)
        y_binned = np.digitize(
            y[:, 0], np.percentile(y[:, 0], np.linspace(0, 100, stratify_bins + 1))
        )
        kfold = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)
        splits = list(kfold.split(X, y_binned))
    else:
        kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)
        splits = list(kfold.split(X))

    # Import factories
    from wavedl.models import build_model
    from wavedl.utils import get_loss, get_optimizer, get_scheduler

    # Run folds
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(splits):
        logger.info(f"\n{'=' * 60}")
        logger.info(f"📊 Fold {fold + 1}/{folds}")
        logger.info(f" Train: {len(train_idx)} samples, Val: {len(val_idx)} samples")

        # Split data
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Fit scaler on training data only
        scaler = StandardScaler()
        y_train_scaled = scaler.fit_transform(y_train)
        y_val_scaled = scaler.transform(y_val)

        # Create datasets and loaders with explicit spatial dimensionality
        spatial_ndim = len(in_shape)
        train_ds = CVDataset(
            X_train.astype(np.float32),
            y_train_scaled.astype(np.float32),
            expected_spatial_ndim=spatial_ndim,
        )
        val_ds = CVDataset(
            X_val.astype(np.float32),
            y_val_scaled.astype(np.float32),
            expected_spatial_ndim=spatial_ndim,
        )

        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers,
            pin_memory=True,
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=True,
        )

        # Build model
        model = build_model(model_name, in_shape=in_shape, out_size=out_size)
        model = model.to(device)

        # Setup training components
        criterion = get_loss(loss_name)
        optimizer = get_optimizer(
            optimizer_name, model.parameters(), lr=lr, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name,
            optimizer,
            epochs=epochs,
            steps_per_epoch=len(train_loader) if scheduler_name == "onecycle" else None,
        )

        # Train fold
        results = train_fold(
            fold=fold,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epochs=epochs,
            patience=patience,
            scaler=scaler,
            logger=logger,
        )

        fold_results.append(results)

        logger.info(
            f" Fold {fold + 1} Results: R²={results['r2']:.4f}, MAE={results['mae_physical']:.4f}"
        )

        # Save fold model
        fold_dir = os.path.join(output_dir, f"fold_{fold + 1}")
        os.makedirs(fold_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(fold_dir, "model.pth"))
        with open(os.path.join(fold_dir, "scaler.pkl"), "wb") as f:
            pickle.dump(scaler, f)

    # ==============================================================================
    # AGGREGATE RESULTS
    # ==============================================================================
    logger.info(f"\n{'=' * 60}")
    logger.info("📈 Cross-Validation Summary")
    logger.info("=" * 60)

    r2_scores = [r["r2"] for r in fold_results]
    mae_scores = [r["mae_physical"] for r in fold_results]
    val_losses = [r["best_val_loss"] for r in fold_results]

    summary = {
        "config": {
            "model": model_name,
            "folds": folds,
            "stratify": stratify,
            "stratify_bins": stratify_bins,
            "batch_size": batch_size,
            "lr": lr,
            "epochs": epochs,
            "patience": patience,
            "loss": loss_name,
            "optimizer": optimizer_name,
            "scheduler": scheduler_name,
        },
        "timestamp": datetime.now().isoformat(),
        "folds": folds,
        "r2_mean": float(np.mean(r2_scores)),
        "r2_std": float(np.std(r2_scores)),
        "mae_mean": float(np.mean(mae_scores)),
        "mae_std": float(np.std(mae_scores)),
        "val_loss_mean": float(np.mean(val_losses)),
        "val_loss_std": float(np.std(val_losses)),
        "fold_results": fold_results,
    }

    logger.info(f" R² Score: {summary['r2_mean']:.4f} ± {summary['r2_std']:.4f}")
    logger.info(f" MAE (phys): {summary['mae_mean']:.4f} ± {summary['mae_std']:.4f}")
    logger.info(
        f" Val Loss: {summary['val_loss_mean']:.6f} ± {summary['val_loss_std']:.6f}"
    )

    # Per-target summary
    for i in range(out_size):
        r2_target = [r.get(f"r2_target_{i}", np.nan) for r in fold_results]
        mae_target = [r.get(f"mae_target_{i}", np.nan) for r in fold_results]
        logger.info(
            f" Target {i}: R²={np.mean(r2_target):.4f}±{np.std(r2_target):.4f}, "
            f"MAE={np.mean(mae_target):.4f}±{np.std(mae_target):.4f}"
        )

    # Save summary
    with open(os.path.join(output_dir, "cv_summary.json"), "w") as f:
        summary_save = summary.copy()
        for r in summary_save["fold_results"]:
            r["history"] = None  # Too large
        json.dump(summary_save, f, indent=2)

    # Save detailed results as CSV
    results_df = pd.DataFrame(
        [{k: v for k, v in r.items() if k != "history"} for r in fold_results]
    )
    results_df.to_csv(os.path.join(output_dir, "cv_results.csv"), index=False)

    logger.info(f"\n✅ Results saved to: {output_dir}")

    return summary
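
To show how the pieces fit together, a minimal driver sketch (editorial, illustrative only; the synthetic data, shapes, and the model name "cnn" are assumptions rather than package defaults, though wavedl/models/cnn.py is listed in the wheel above):

# Hypothetical end-to-end call: 5-fold CV on synthetic 1D signals.
import numpy as np
from wavedl.utils.cross_validation import run_cross_validation

X = np.random.rand(500, 1024).astype(np.float32)  # (N, L): 1D signals
y = np.random.rand(500, 2).astype(np.float32)     # (N, T): two targets

summary = run_cross_validation(
    X, y,
    model_name="cnn",     # assumed registry name
    in_shape=(1024,),     # spatial dims only; CVDataset adds the channel axis
    out_size=2,
    folds=5,
    epochs=10,
    output_dir="./cv_results",
)
print(summary["r2_mean"], summary["mae_mean"])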