wavedl-1.6.0-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
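To explore this diff locally, the changed files can be pulled straight out of the two wheels, since a wheel is just a zip archive. The short Python sketch below compares wavedl/train.py between the two releases; it assumes both wheels have already been downloaded (for example with `pip download wavedl==1.6.0 --no-deps -d wheels/` and the same for 1.6.1), and the local paths are illustrative, not part of this diff page.

"""Extract wavedl/train.py from two wheels and print a unified diff (sketch)."""
import difflib
import zipfile

OLD_WHEEL = "wheels/wavedl-1.6.0-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "wheels/wavedl-1.6.1-py3-none-any.whl"  # assumed local path
MEMBER = "wavedl/train.py"


def read_member(wheel_path: str, member: str) -> list[str]:
    # A wheel is a zip archive; read one member from it as text lines.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)


old_lines = read_member(OLD_WHEEL, MEMBER)
new_lines = read_member(NEW_WHEEL, MEMBER)

for line in difflib.unified_diff(
    old_lines, new_lines, fromfile=f"1.6.0/{MEMBER}", tofile=f"1.6.1/{MEMBER}"
):
    print(line, end="")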
wavedl/train.py
CHANGED
|
@@ -1,1430 +1,1430 @@
|
|
|
1
|
-
"""
|
|
2
|
-
WaveDL - Deep Learning for Wave-based Inverse Problems
|
|
3
|
-
=======================================================
|
|
4
|
-
Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
|
|
5
|
-
|
|
6
|
-
A modular training framework for wave-based inverse problems and regression:
|
|
7
|
-
1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
|
|
8
|
-
2. Dynamic Model Selection: Use --model flag to select any registered architecture
|
|
9
|
-
3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
|
|
10
|
-
4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
|
|
11
|
-
5. Robust Checkpointing: Resume training, periodic saves, and training curves
|
|
12
|
-
6. Deep Observability: WandB integration with scatter analysis
|
|
13
|
-
|
|
14
|
-
Usage:
|
|
15
|
-
# Recommended: Using the HPC launcher (handles accelerate configuration)
|
|
16
|
-
wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
|
|
17
|
-
|
|
18
|
-
# Or direct training module (use --precision, not --mixed_precision)
|
|
19
|
-
accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
|
|
20
|
-
|
|
21
|
-
# Multi-GPU with explicit config
|
|
22
|
-
wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
|
|
23
|
-
|
|
24
|
-
# Resume from checkpoint
|
|
25
|
-
accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
|
|
26
|
-
|
|
27
|
-
# List available models
|
|
28
|
-
wavedl-train --list_models
|
|
29
|
-
|
|
30
|
-
Note:
|
|
31
|
-
- wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
|
|
32
|
-
- wavedl.train: Uses --precision (internal module flag)
|
|
33
|
-
Both control the same behavior; use the appropriate flag for your entry point.
|
|
34
|
-
|
|
35
|
-
Author: Ductho Le (ductho.le@outlook.com)
|
|
36
|
-
"""
|
|
37
|
-
|
|
38
|
-
from __future__ import annotations
|
|
39
|
-
|
|
40
|
-
# =============================================================================
|
|
41
|
-
# HPC Environment Setup (MUST be before any library imports)
|
|
42
|
-
# =============================================================================
|
|
43
|
-
# Auto-configure writable cache directories when home is not writable.
|
|
44
|
-
# Uses current working directory as fallback - works on HPC and local machines.
|
|
45
|
-
import os
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
49
|
-
"""Set cache directory to CWD if home is not writable."""
|
|
50
|
-
if env_var in os.environ:
|
|
51
|
-
return # User already set, respect their choice
|
|
52
|
-
|
|
53
|
-
# Check if home is writable
|
|
54
|
-
home = os.path.expanduser("~")
|
|
55
|
-
if os.access(home, os.W_OK):
|
|
56
|
-
return # Home is writable, let library use defaults
|
|
57
|
-
|
|
58
|
-
# Home not writable - use current working directory
|
|
59
|
-
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
60
|
-
os.makedirs(cache_path, exist_ok=True)
|
|
61
|
-
os.environ[env_var] = cache_path
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
# Configure cache directories (before any library imports)
|
|
65
|
-
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
66
|
-
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
67
|
-
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
68
|
-
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
69
|
-
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
|
-
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def _setup_per_rank_compile_cache() -> None:
|
|
74
|
-
"""Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
|
|
75
|
-
|
|
76
|
-
When using torch.compile with multiple GPUs, all processes try to write to
|
|
77
|
-
the same cache directory, causing 'Directory is not empty - skipping!' warnings.
|
|
78
|
-
This gives each GPU rank its own isolated cache subdirectory.
|
|
79
|
-
"""
|
|
80
|
-
# Get local rank from environment (set by accelerate/torchrun)
|
|
81
|
-
local_rank = os.environ.get("LOCAL_RANK", "0")
|
|
82
|
-
|
|
83
|
-
# Get cache base from environment or use CWD
|
|
84
|
-
cache_base = os.environ.get(
|
|
85
|
-
"TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
|
|
86
|
-
)
|
|
87
|
-
|
|
88
|
-
# Set per-rank cache directories
|
|
89
|
-
os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
|
|
90
|
-
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
|
|
91
|
-
os.environ.get(
|
|
92
|
-
"TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
|
|
93
|
-
),
|
|
94
|
-
f"rank_{local_rank}",
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
# Create directories
|
|
98
|
-
os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
|
|
99
|
-
os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
# Setup per-rank compile caches (before torch imports)
|
|
103
|
-
_setup_per_rank_compile_cache()
|
|
104
|
-
|
|
105
|
-
# =============================================================================
|
|
106
|
-
# Standard imports (after environment setup)
|
|
107
|
-
# =============================================================================
|
|
108
|
-
import argparse
|
|
109
|
-
import logging
|
|
110
|
-
import pickle
|
|
111
|
-
import shutil
|
|
112
|
-
import sys
|
|
113
|
-
import time
|
|
114
|
-
import warnings
|
|
115
|
-
from typing import Any
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
# Suppress Pydantic warnings from accelerate's internal Field() usage
|
|
119
|
-
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
|
120
|
-
|
|
121
|
-
import matplotlib.pyplot as plt
|
|
122
|
-
import numpy as np
|
|
123
|
-
import pandas as pd
|
|
124
|
-
import torch
|
|
125
|
-
import torch.distributed as dist
|
|
126
|
-
from accelerate import Accelerator
|
|
127
|
-
from accelerate.utils import set_seed
|
|
128
|
-
from sklearn.metrics import r2_score
|
|
129
|
-
from tqdm.auto import tqdm
|
|
130
|
-
|
|
131
|
-
from wavedl.models import build_model, get_model, list_models
|
|
132
|
-
from wavedl.utils import (
|
|
133
|
-
FIGURE_DPI,
|
|
134
|
-
MetricTracker,
|
|
135
|
-
broadcast_early_stop,
|
|
136
|
-
calc_pearson,
|
|
137
|
-
create_training_curves,
|
|
138
|
-
get_loss,
|
|
139
|
-
get_lr,
|
|
140
|
-
get_optimizer,
|
|
141
|
-
get_scheduler,
|
|
142
|
-
is_epoch_based,
|
|
143
|
-
list_losses,
|
|
144
|
-
list_optimizers,
|
|
145
|
-
list_schedulers,
|
|
146
|
-
plot_scientific_scatter,
|
|
147
|
-
prepare_data,
|
|
148
|
-
)
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
try:
|
|
152
|
-
import wandb
|
|
153
|
-
|
|
154
|
-
WANDB_AVAILABLE = True
|
|
155
|
-
except ImportError:
|
|
156
|
-
WANDB_AVAILABLE = False
|
|
157
|
-
|
|
158
|
-
# ==============================================================================
|
|
159
|
-
# RUNTIME CONFIGURATION (post-import)
|
|
160
|
-
# ==============================================================================
|
|
161
|
-
# Configure matplotlib paths for HPC systems without writable home directories
|
|
162
|
-
os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
|
|
163
|
-
os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
|
|
164
|
-
|
|
165
|
-
# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
|
|
166
|
-
# from torch/numpy about NaN, dtype, and numerical issues.
|
|
167
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
168
|
-
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
169
|
-
# Pydantic v1/v2 compatibility warnings
|
|
170
|
-
warnings.filterwarnings("ignore", module="pydantic")
|
|
171
|
-
warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
|
|
172
|
-
# Transformer library warnings (loading configs, etc.)
|
|
173
|
-
warnings.filterwarnings("ignore", module="transformers")
|
|
174
|
-
# Accelerate verbose messages
|
|
175
|
-
warnings.filterwarnings("ignore", module="accelerate")
|
|
176
|
-
# torch.compile backend selection warnings
|
|
177
|
-
warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
|
|
178
|
-
warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
|
|
179
|
-
# Note: UserWarning from torch/numpy core is NOT suppressed to preserve
|
|
180
|
-
# legitimate warnings about NaN values, dtype mismatches, etc.
|
|
181
|
-
|
|
182
|
-
# ==============================================================================
|
|
183
|
-
# GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
|
|
184
|
-
# ==============================================================================
|
|
185
|
-
# Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
|
|
186
|
-
torch.backends.cuda.matmul.allow_tf32 = True
|
|
187
|
-
torch.backends.cudnn.allow_tf32 = True
|
|
188
|
-
torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
|
|
189
|
-
|
|
190
|
-
# Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
|
|
191
|
-
# Note: First few batches may be slower due to benchmarking
|
|
192
|
-
torch.backends.cudnn.benchmark = True
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
# ==============================================================================
|
|
196
|
-
# LOGGING UTILITIES
|
|
197
|
-
# ==============================================================================
|
|
198
|
-
from contextlib import contextmanager
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
@contextmanager
|
|
202
|
-
def suppress_accelerate_logging():
|
|
203
|
-
"""Temporarily suppress accelerate's verbose checkpoint save messages."""
|
|
204
|
-
accelerate_logger = logging.getLogger("accelerate.checkpointing")
|
|
205
|
-
original_level = accelerate_logger.level
|
|
206
|
-
accelerate_logger.setLevel(logging.WARNING)
|
|
207
|
-
try:
|
|
208
|
-
yield
|
|
209
|
-
finally:
|
|
210
|
-
accelerate_logger.setLevel(original_level)
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
# ==============================================================================
|
|
214
|
-
# ARGUMENT PARSING
|
|
215
|
-
# ==============================================================================
|
|
216
|
-
def parse_args() -> argparse.Namespace:
|
|
217
|
-
"""Parse command-line arguments with comprehensive options."""
|
|
218
|
-
parser = argparse.ArgumentParser(
|
|
219
|
-
description="Universal DDP Training Pipeline",
|
|
220
|
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
221
|
-
)
|
|
222
|
-
|
|
223
|
-
# Model Selection
|
|
224
|
-
parser.add_argument(
|
|
225
|
-
"--model",
|
|
226
|
-
type=str,
|
|
227
|
-
default="cnn",
|
|
228
|
-
help=f"Model architecture to train. Available: {list_models()}",
|
|
229
|
-
)
|
|
230
|
-
parser.add_argument(
|
|
231
|
-
"--list_models", action="store_true", help="List all available models and exit"
|
|
232
|
-
)
|
|
233
|
-
parser.add_argument(
|
|
234
|
-
"--import",
|
|
235
|
-
dest="import_modules",
|
|
236
|
-
type=str,
|
|
237
|
-
nargs="+",
|
|
238
|
-
default=[],
|
|
239
|
-
help="Python modules to import before training (for custom models)",
|
|
240
|
-
)
|
|
241
|
-
parser.add_argument(
|
|
242
|
-
"--no_pretrained",
|
|
243
|
-
dest="pretrained",
|
|
244
|
-
action="store_false",
|
|
245
|
-
help="Train from scratch without pretrained weights (default: use pretrained)",
|
|
246
|
-
)
|
|
247
|
-
parser.set_defaults(pretrained=True)
|
|
248
|
-
|
|
249
|
-
# Configuration File
|
|
250
|
-
parser.add_argument(
|
|
251
|
-
"--config",
|
|
252
|
-
type=str,
|
|
253
|
-
default=None,
|
|
254
|
-
help="Path to YAML config file. CLI args override config values.",
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
# Hyperparameters
|
|
258
|
-
parser.add_argument(
|
|
259
|
-
"--batch_size", type=int, default=128, help="Batch size per GPU"
|
|
260
|
-
)
|
|
261
|
-
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
|
|
262
|
-
parser.add_argument(
|
|
263
|
-
"--epochs", type=int, default=1000, help="Maximum training epochs"
|
|
264
|
-
)
|
|
265
|
-
parser.add_argument(
|
|
266
|
-
"--patience", type=int, default=20, help="Early stopping patience"
|
|
267
|
-
)
|
|
268
|
-
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
|
269
|
-
parser.add_argument(
|
|
270
|
-
"--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
# Loss Function
|
|
274
|
-
parser.add_argument(
|
|
275
|
-
"--loss",
|
|
276
|
-
type=str,
|
|
277
|
-
default="mse",
|
|
278
|
-
choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
|
|
279
|
-
help=f"Loss function for training. Available: {list_losses()}",
|
|
280
|
-
)
|
|
281
|
-
parser.add_argument(
|
|
282
|
-
"--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
|
|
283
|
-
)
|
|
284
|
-
parser.add_argument(
|
|
285
|
-
"--loss_weights",
|
|
286
|
-
type=str,
|
|
287
|
-
default=None,
|
|
288
|
-
help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
|
|
289
|
-
)
|
|
290
|
-
|
|
291
|
-
# Optimizer
|
|
292
|
-
parser.add_argument(
|
|
293
|
-
"--optimizer",
|
|
294
|
-
type=str,
|
|
295
|
-
default="adamw",
|
|
296
|
-
choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
|
|
297
|
-
help=f"Optimizer for training. Available: {list_optimizers()}",
|
|
298
|
-
)
|
|
299
|
-
parser.add_argument(
|
|
300
|
-
"--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
|
|
301
|
-
)
|
|
302
|
-
parser.add_argument(
|
|
303
|
-
"--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
|
|
304
|
-
)
|
|
305
|
-
parser.add_argument(
|
|
306
|
-
"--betas",
|
|
307
|
-
type=str,
|
|
308
|
-
default="0.9,0.999",
|
|
309
|
-
help="Betas for Adam variants (comma-separated)",
|
|
310
|
-
)
|
|
311
|
-
|
|
312
|
-
# Learning Rate Scheduler
|
|
313
|
-
parser.add_argument(
|
|
314
|
-
"--scheduler",
|
|
315
|
-
type=str,
|
|
316
|
-
default="plateau",
|
|
317
|
-
choices=[
|
|
318
|
-
"plateau",
|
|
319
|
-
"cosine",
|
|
320
|
-
"cosine_restarts",
|
|
321
|
-
"onecycle",
|
|
322
|
-
"step",
|
|
323
|
-
"multistep",
|
|
324
|
-
"exponential",
|
|
325
|
-
"linear_warmup",
|
|
326
|
-
],
|
|
327
|
-
help=f"LR scheduler. Available: {list_schedulers()}",
|
|
328
|
-
)
|
|
329
|
-
parser.add_argument(
|
|
330
|
-
"--scheduler_patience",
|
|
331
|
-
type=int,
|
|
332
|
-
default=10,
|
|
333
|
-
help="Patience for ReduceLROnPlateau",
|
|
334
|
-
)
|
|
335
|
-
parser.add_argument(
|
|
336
|
-
"--min_lr", type=float, default=1e-6, help="Minimum learning rate"
|
|
337
|
-
)
|
|
338
|
-
parser.add_argument(
|
|
339
|
-
"--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
|
|
340
|
-
)
|
|
341
|
-
parser.add_argument(
|
|
342
|
-
"--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
|
|
343
|
-
)
|
|
344
|
-
parser.add_argument(
|
|
345
|
-
"--step_size", type=int, default=30, help="Step size for StepLR"
|
|
346
|
-
)
|
|
347
|
-
parser.add_argument(
|
|
348
|
-
"--milestones",
|
|
349
|
-
type=str,
|
|
350
|
-
default=None,
|
|
351
|
-
help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
|
|
352
|
-
)
|
|
353
|
-
|
|
354
|
-
# Data
|
|
355
|
-
parser.add_argument(
|
|
356
|
-
"--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
|
|
357
|
-
)
|
|
358
|
-
parser.add_argument(
|
|
359
|
-
"--workers",
|
|
360
|
-
type=int,
|
|
361
|
-
default=-1,
|
|
362
|
-
help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
|
|
363
|
-
)
|
|
364
|
-
parser.add_argument("--seed", type=int, default=2025, help="Random seed")
|
|
365
|
-
parser.add_argument(
|
|
366
|
-
"--deterministic",
|
|
367
|
-
action="store_true",
|
|
368
|
-
help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
|
|
369
|
-
)
|
|
370
|
-
parser.add_argument(
|
|
371
|
-
"--cache_validate",
|
|
372
|
-
type=str,
|
|
373
|
-
default="sha256",
|
|
374
|
-
choices=["sha256", "fast", "size"],
|
|
375
|
-
help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
|
|
376
|
-
)
|
|
377
|
-
parser.add_argument(
|
|
378
|
-
"--single_channel",
|
|
379
|
-
action="store_true",
|
|
380
|
-
help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
|
|
381
|
-
)
|
|
382
|
-
|
|
383
|
-
# Cross-Validation
|
|
384
|
-
parser.add_argument(
|
|
385
|
-
"--cv",
|
|
386
|
-
type=int,
|
|
387
|
-
default=0,
|
|
388
|
-
help="Enable K-fold cross-validation with K folds (0=disabled)",
|
|
389
|
-
)
|
|
390
|
-
parser.add_argument(
|
|
391
|
-
"--cv_stratify",
|
|
392
|
-
action="store_true",
|
|
393
|
-
help="Use stratified splitting for cross-validation",
|
|
394
|
-
)
|
|
395
|
-
parser.add_argument(
|
|
396
|
-
"--cv_bins",
|
|
397
|
-
type=int,
|
|
398
|
-
default=10,
|
|
399
|
-
help="Number of bins for stratified CV (only with --cv_stratify)",
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
# Checkpointing & Resume
|
|
403
|
-
parser.add_argument(
|
|
404
|
-
"--resume", type=str, default=None, help="Checkpoint directory to resume from"
|
|
405
|
-
)
|
|
406
|
-
parser.add_argument(
|
|
407
|
-
"--save_every",
|
|
408
|
-
type=int,
|
|
409
|
-
default=50,
|
|
410
|
-
help="Save checkpoint every N epochs (0=disable)",
|
|
411
|
-
)
|
|
412
|
-
parser.add_argument(
|
|
413
|
-
"--output_dir", type=str, default=".", help="Output directory for checkpoints"
|
|
414
|
-
)
|
|
415
|
-
parser.add_argument(
|
|
416
|
-
"--fresh",
|
|
417
|
-
action="store_true",
|
|
418
|
-
help="Force fresh training, ignore existing checkpoints",
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
# Performance
|
|
422
|
-
parser.add_argument(
|
|
423
|
-
"--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
|
|
424
|
-
)
|
|
425
|
-
parser.add_argument(
|
|
426
|
-
"--precision",
|
|
427
|
-
type=str,
|
|
428
|
-
default="bf16",
|
|
429
|
-
choices=["bf16", "fp16", "no"],
|
|
430
|
-
help="Mixed precision mode",
|
|
431
|
-
)
|
|
432
|
-
# Alias for consistency with wavedl-hpc (--mixed_precision)
|
|
433
|
-
parser.add_argument(
|
|
434
|
-
"--mixed_precision",
|
|
435
|
-
dest="precision",
|
|
436
|
-
type=str,
|
|
437
|
-
choices=["bf16", "fp16", "no"],
|
|
438
|
-
help=argparse.SUPPRESS, # Hidden: use --precision instead
|
|
439
|
-
)
|
|
440
|
-
|
|
441
|
-
# Physical Constraints
|
|
442
|
-
parser.add_argument(
|
|
443
|
-
"--constraint",
|
|
444
|
-
type=str,
|
|
445
|
-
nargs="+",
|
|
446
|
-
default=[],
|
|
447
|
-
help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
|
|
448
|
-
)
|
|
449
|
-
|
|
450
|
-
parser.add_argument(
|
|
451
|
-
"--constraint_file",
|
|
452
|
-
type=str,
|
|
453
|
-
default=None,
|
|
454
|
-
help="Python file with constraint(pred, inputs) function",
|
|
455
|
-
)
|
|
456
|
-
parser.add_argument(
|
|
457
|
-
"--constraint_weight",
|
|
458
|
-
type=float,
|
|
459
|
-
nargs="+",
|
|
460
|
-
default=[0.1],
|
|
461
|
-
help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
|
|
462
|
-
)
|
|
463
|
-
parser.add_argument(
|
|
464
|
-
"--constraint_reduction",
|
|
465
|
-
type=str,
|
|
466
|
-
default="mse",
|
|
467
|
-
choices=["mse", "mae"],
|
|
468
|
-
help="Reduction mode for constraint penalties",
|
|
469
|
-
)
|
|
470
|
-
|
|
471
|
-
# Logging
|
|
472
|
-
parser.add_argument(
|
|
473
|
-
"--wandb", action="store_true", help="Enable Weights & Biases logging"
|
|
474
|
-
)
|
|
475
|
-
parser.add_argument(
|
|
476
|
-
"--wandb_watch",
|
|
477
|
-
action="store_true",
|
|
478
|
-
help="Enable WandB gradient watching (adds overhead, useful for debugging)",
|
|
479
|
-
)
|
|
480
|
-
parser.add_argument(
|
|
481
|
-
"--project_name", type=str, default="DL-Training", help="WandB project name"
|
|
482
|
-
)
|
|
483
|
-
parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
|
|
484
|
-
|
|
485
|
-
args = parser.parse_args()
|
|
486
|
-
return args, parser # Returns (Namespace, ArgumentParser)
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
# ==============================================================================
|
|
490
|
-
# MAIN TRAINING FUNCTION
|
|
491
|
-
# ==============================================================================
|
|
492
|
-
def main():
|
|
493
|
-
args, parser = parse_args()
|
|
494
|
-
|
|
495
|
-
# Import custom model modules if specified
|
|
496
|
-
if args.import_modules:
|
|
497
|
-
import importlib
|
|
498
|
-
|
|
499
|
-
for module_name in args.import_modules:
|
|
500
|
-
try:
|
|
501
|
-
# Handle both module names (my_model) and file paths (./my_model.py)
|
|
502
|
-
if module_name.endswith(".py"):
|
|
503
|
-
# Import from file path with unique module name
|
|
504
|
-
import importlib.util
|
|
505
|
-
|
|
506
|
-
# Derive unique module name from filename to avoid collisions
|
|
507
|
-
base_name = os.path.splitext(os.path.basename(module_name))[0]
|
|
508
|
-
unique_name = f"wavedl_custom_{base_name}"
|
|
509
|
-
|
|
510
|
-
spec = importlib.util.spec_from_file_location(
|
|
511
|
-
unique_name, module_name
|
|
512
|
-
)
|
|
513
|
-
if spec and spec.loader:
|
|
514
|
-
module = importlib.util.module_from_spec(spec)
|
|
515
|
-
sys.modules[unique_name] = module
|
|
516
|
-
spec.loader.exec_module(module)
|
|
517
|
-
print(f"✓ Imported custom module from: {module_name}")
|
|
518
|
-
else:
|
|
519
|
-
# Import as regular module
|
|
520
|
-
importlib.import_module(module_name)
|
|
521
|
-
print(f"✓ Imported module: {module_name}")
|
|
522
|
-
except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
|
|
523
|
-
print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
|
|
524
|
-
if isinstance(e, FileNotFoundError):
|
|
525
|
-
print(" File does not exist. Check the path.", file=sys.stderr)
|
|
526
|
-
elif isinstance(e, SyntaxError):
|
|
527
|
-
print(
|
|
528
|
-
f" Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
|
|
529
|
-
)
|
|
530
|
-
elif isinstance(e, PermissionError):
|
|
531
|
-
print(
|
|
532
|
-
" Permission denied. Check file permissions.", file=sys.stderr
|
|
533
|
-
)
|
|
534
|
-
else:
|
|
535
|
-
print(
|
|
536
|
-
" Make sure the module is in your Python path or current directory.",
|
|
537
|
-
file=sys.stderr,
|
|
538
|
-
)
|
|
539
|
-
sys.exit(1)
|
|
540
|
-
|
|
541
|
-
# Handle --list_models flag
|
|
542
|
-
if args.list_models:
|
|
543
|
-
print("Available models:")
|
|
544
|
-
for name in list_models():
|
|
545
|
-
ModelClass = get_model(name)
|
|
546
|
-
# Get first non-empty docstring line
|
|
547
|
-
if ModelClass.__doc__:
|
|
548
|
-
lines = [
|
|
549
|
-
l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
|
|
550
|
-
]
|
|
551
|
-
doc_first_line = lines[0] if lines else "No description"
|
|
552
|
-
else:
|
|
553
|
-
doc_first_line = "No description"
|
|
554
|
-
print(f" - {name}: {doc_first_line}")
|
|
555
|
-
sys.exit(0)
|
|
556
|
-
|
|
557
|
-
# Load and merge config file if provided
|
|
558
|
-
if args.config:
|
|
559
|
-
from wavedl.utils.config import (
|
|
560
|
-
load_config,
|
|
561
|
-
merge_config_with_args,
|
|
562
|
-
validate_config,
|
|
563
|
-
)
|
|
564
|
-
|
|
565
|
-
print(f"📄 Loading config from: {args.config}")
|
|
566
|
-
config = load_config(args.config)
|
|
567
|
-
|
|
568
|
-
# Validate config values
|
|
569
|
-
warnings_list = validate_config(config)
|
|
570
|
-
for w in warnings_list:
|
|
571
|
-
print(f" ⚠ {w}")
|
|
572
|
-
|
|
573
|
-
# Merge config with CLI args (CLI takes precedence via parser defaults detection)
|
|
574
|
-
args = merge_config_with_args(config, args, parser=parser)
|
|
575
|
-
|
|
576
|
-
# Handle --cv flag (cross-validation mode)
|
|
577
|
-
if args.cv > 0:
|
|
578
|
-
print(f"🔄 Cross-Validation Mode: {args.cv} folds")
|
|
579
|
-
from wavedl.utils.cross_validation import run_cross_validation
|
|
580
|
-
|
|
581
|
-
# Load data for CV using memory-efficient loader
|
|
582
|
-
from wavedl.utils.data import DataSource, get_data_source
|
|
583
|
-
|
|
584
|
-
data_format = DataSource.detect_format(args.data_path)
|
|
585
|
-
source = get_data_source(data_format)
|
|
586
|
-
|
|
587
|
-
# Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
|
|
588
|
-
_cv_handle = None
|
|
589
|
-
if hasattr(source, "load_mmap"):
|
|
590
|
-
_cv_handle = source.load_mmap(args.data_path)
|
|
591
|
-
X, y = _cv_handle.inputs, _cv_handle.outputs
|
|
592
|
-
else:
|
|
593
|
-
X, y = source.load(args.data_path)
|
|
594
|
-
|
|
595
|
-
# Handle sparse matrices (must materialize for CV shuffling)
|
|
596
|
-
if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
|
|
597
|
-
X = np.stack([x.toarray() for x in X])
|
|
598
|
-
|
|
599
|
-
# Normalize target shape: (N,) -> (N, 1) for consistency
|
|
600
|
-
if y.ndim == 1:
|
|
601
|
-
y = y.reshape(-1, 1)
|
|
602
|
-
|
|
603
|
-
# Validate and determine input shape (consistent with prepare_data)
|
|
604
|
-
# Check for ambiguous shapes that could be multi-channel or shallow 3D volume
|
|
605
|
-
sample_shape = X.shape[1:] # Per-sample shape
|
|
606
|
-
|
|
607
|
-
# Same heuristic as prepare_data: detect ambiguous 3D shapes
|
|
608
|
-
is_ambiguous_shape = (
|
|
609
|
-
len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
|
|
610
|
-
and sample_shape[0] <= 16 # First dim looks like channels
|
|
611
|
-
and sample_shape[1] > 16
|
|
612
|
-
and sample_shape[2] > 16 # Both spatial dims are large
|
|
613
|
-
)
|
|
614
|
-
|
|
615
|
-
if is_ambiguous_shape and not args.single_channel:
|
|
616
|
-
raise ValueError(
|
|
617
|
-
f"Ambiguous input shape detected: sample shape {sample_shape}. "
|
|
618
|
-
f"This could be either:\n"
|
|
619
|
-
f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
|
|
620
|
-
f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
|
|
621
|
-
f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
|
|
622
|
-
f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
|
|
623
|
-
)
|
|
624
|
-
|
|
625
|
-
# in_shape = spatial dimensions for model registry (channel added during training)
|
|
626
|
-
in_shape = sample_shape
|
|
627
|
-
|
|
628
|
-
# Run cross-validation
|
|
629
|
-
try:
|
|
630
|
-
run_cross_validation(
|
|
631
|
-
X=X,
|
|
632
|
-
y=y,
|
|
633
|
-
model_name=args.model,
|
|
634
|
-
in_shape=in_shape,
|
|
635
|
-
out_size=y.shape[1],
|
|
636
|
-
folds=args.cv,
|
|
637
|
-
stratify=args.cv_stratify,
|
|
638
|
-
stratify_bins=args.cv_bins,
|
|
639
|
-
batch_size=args.batch_size,
|
|
640
|
-
lr=args.lr,
|
|
641
|
-
epochs=args.epochs,
|
|
642
|
-
patience=args.patience,
|
|
643
|
-
weight_decay=args.weight_decay,
|
|
644
|
-
loss_name=args.loss,
|
|
645
|
-
optimizer_name=args.optimizer,
|
|
646
|
-
scheduler_name=args.scheduler,
|
|
647
|
-
output_dir=args.output_dir,
|
|
648
|
-
workers=args.workers,
|
|
649
|
-
seed=args.seed,
|
|
650
|
-
)
|
|
651
|
-
finally:
|
|
652
|
-
# Clean up file handle if HDF5/MAT
|
|
653
|
-
if _cv_handle is not None and hasattr(_cv_handle, "close"):
|
|
654
|
-
try:
|
|
655
|
-
_cv_handle.close()
|
|
656
|
-
except Exception as e:
|
|
657
|
-
logging.debug(f"Failed to close CV data handle: {e}")
|
|
658
|
-
return
|
|
659
|
-
|
|
660
|
-
# ==========================================================================
|
|
661
|
-
# SYSTEM INITIALIZATION
|
|
662
|
-
# ==========================================================================
|
|
663
|
-
# Initialize Accelerator for DDP and mixed precision
|
|
664
|
-
accelerator = Accelerator(
|
|
665
|
-
mixed_precision=args.precision,
|
|
666
|
-
log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
|
|
667
|
-
)
|
|
668
|
-
set_seed(args.seed)
|
|
669
|
-
|
|
670
|
-
# Deterministic mode for scientific reproducibility
|
|
671
|
-
# Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
|
|
672
|
-
if args.deterministic:
|
|
673
|
-
torch.backends.cudnn.benchmark = False
|
|
674
|
-
torch.backends.cudnn.deterministic = True
|
|
675
|
-
torch.backends.cuda.matmul.allow_tf32 = False
|
|
676
|
-
torch.backends.cudnn.allow_tf32 = False
|
|
677
|
-
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
678
|
-
if accelerator.is_main_process:
|
|
679
|
-
print("🔒 Deterministic mode enabled (slower but reproducible)")
|
|
680
|
-
|
|
681
|
-
# Configure logging (rank 0 only prints to console)
|
|
682
|
-
logging.basicConfig(
|
|
683
|
-
level=logging.INFO if accelerator.is_main_process else logging.ERROR,
|
|
684
|
-
format="%(asctime)s | %(levelname)s | %(message)s",
|
|
685
|
-
datefmt="%H:%M:%S",
|
|
686
|
-
)
|
|
687
|
-
logger = logging.getLogger("Trainer")
|
|
688
|
-
|
|
689
|
-
# Ensure output directory exists (critical for cache files, checkpoints, etc.)
|
|
690
|
-
os.makedirs(args.output_dir, exist_ok=True)
|
|
691
|
-
|
|
692
|
-
# Auto-detect optimal DataLoader workers if not specified
|
|
693
|
-
if args.workers < 0:
|
|
694
|
-
cpu_count = os.cpu_count() or 4
|
|
695
|
-
num_gpus = accelerator.num_processes
|
|
696
|
-
# Heuristic: 4-16 workers per GPU, bounded by available CPU cores
|
|
697
|
-
# Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
|
|
698
|
-
args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
|
|
699
|
-
if accelerator.is_main_process:
|
|
700
|
-
logger.info(
|
|
701
|
-
f"⚙️ Auto-detected workers: {args.workers} per GPU "
|
|
702
|
-
f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
|
|
703
|
-
)
|
|
704
|
-
|
|
705
|
-
if accelerator.is_main_process:
|
|
706
|
-
logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
|
|
707
|
-
logger.info(
|
|
708
|
-
f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
|
|
709
|
-
)
|
|
710
|
-
logger.info(
|
|
711
|
-
f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
|
|
712
|
-
)
|
|
713
|
-
logger.info(f" Early Stopping Patience: {args.patience} epochs")
|
|
714
|
-
if args.save_every > 0:
|
|
715
|
-
logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
|
|
716
|
-
if args.resume:
|
|
717
|
-
logger.info(f" 📂 Resuming from: {args.resume}")
|
|
718
|
-
|
|
719
|
-
# Initialize WandB
|
|
720
|
-
if args.wandb and WANDB_AVAILABLE:
|
|
721
|
-
accelerator.init_trackers(
|
|
722
|
-
project_name=args.project_name,
|
|
723
|
-
config=vars(args),
|
|
724
|
-
init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
|
|
725
|
-
)
|
|
726
|
-
|
|
727
|
-
# ==========================================================================
|
|
728
|
-
# DATA & MODEL LOADING
|
|
729
|
-
# ==========================================================================
|
|
730
|
-
train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
|
|
731
|
-
args, logger, accelerator, cache_dir=args.output_dir
|
|
732
|
-
)
|
|
733
|
-
|
|
734
|
-
# Build model using registry
|
|
735
|
-
model = build_model(
|
|
736
|
-
args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
|
|
737
|
-
)
|
|
738
|
-
|
|
739
|
-
if accelerator.is_main_process:
|
|
740
|
-
param_info = model.parameter_summary()
|
|
741
|
-
logger.info(
|
|
742
|
-
f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
|
|
743
|
-
)
|
|
744
|
-
logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
|
|
745
|
-
|
|
746
|
-
# Optional WandB model watching (opt-in due to overhead on large models)
|
|
747
|
-
if (
|
|
748
|
-
args.wandb
|
|
749
|
-
and args.wandb_watch
|
|
750
|
-
and WANDB_AVAILABLE
|
|
751
|
-
and accelerator.is_main_process
|
|
752
|
-
):
|
|
753
|
-
wandb.watch(model, log="gradients", log_freq=100)
|
|
754
|
-
logger.info(" 📊 WandB gradient watching enabled")
|
|
755
|
-
|
|
756
|
-
# Torch 2.0 compilation (requires compatible Triton on GPU)
|
|
757
|
-
if args.compile:
|
|
758
|
-
try:
|
|
759
|
-
# Test if Triton is available - just import the package
|
|
760
|
-
# Different Triton versions have different internal APIs, so just check base import
|
|
761
|
-
import triton
|
|
762
|
-
|
|
763
|
-
model = torch.compile(model)
|
|
764
|
-
if accelerator.is_main_process:
|
|
765
|
-
logger.info(" ✔ torch.compile enabled (Triton backend)")
|
|
766
|
-
except ImportError as e:
|
|
767
|
-
if accelerator.is_main_process:
|
|
768
|
-
if "triton" in str(e).lower():
|
|
769
|
-
logger.warning(
|
|
770
|
-
" ⚠ Triton not installed or incompatible version - torch.compile disabled. "
|
|
771
|
-
"Training will proceed without compilation."
|
|
772
|
-
)
|
|
773
|
-
else:
|
|
774
|
-
logger.warning(
|
|
775
|
-
f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
|
|
776
|
-
)
|
|
777
|
-
except Exception as e:
|
|
778
|
-
if accelerator.is_main_process:
|
|
779
|
-
logger.warning(
|
|
780
|
-
f" ⚠ torch.compile failed: {e}. Continuing without compilation."
|
|
781
|
-
)
|
|
782
|
-
|
|
783
|
-
# ==========================================================================
|
|
784
|
-
# OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
|
|
785
|
-
# ==========================================================================
|
|
786
|
-
# Parse comma-separated arguments with validation
|
|
787
|
-
try:
|
|
788
|
-
betas_list = [float(x.strip()) for x in args.betas.split(",")]
|
|
789
|
-
if len(betas_list) != 2:
|
|
790
|
-
raise ValueError(
|
|
791
|
-
f"--betas must have exactly 2 values, got {len(betas_list)}"
|
|
792
|
-
)
|
|
793
|
-
if not all(0.0 <= b < 1.0 for b in betas_list):
|
|
794
|
-
raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
|
|
795
|
-
betas = tuple(betas_list)
|
|
796
|
-
except ValueError as e:
|
|
797
|
-
raise ValueError(
|
|
798
|
-
f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
|
|
799
|
-
)
|
|
800
|
-
|
|
801
|
-
loss_weights = None
|
|
802
|
-
if args.loss_weights:
|
|
803
|
-
loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
|
|
804
|
-
milestones = None
|
|
805
|
-
if args.milestones:
|
|
806
|
-
milestones = [int(x.strip()) for x in args.milestones.split(",")]
|
|
807
|
-
|
|
808
|
-
# Create optimizer using factory
|
|
809
|
-
optimizer = get_optimizer(
|
|
810
|
-
name=args.optimizer,
|
|
811
|
-
params=model.get_optimizer_groups(args.lr, args.weight_decay),
|
|
812
|
-
lr=args.lr,
|
|
813
|
-
weight_decay=args.weight_decay,
|
|
814
|
-
momentum=args.momentum,
|
|
815
|
-
nesterov=args.nesterov,
|
|
816
|
-
betas=betas,
|
|
817
|
-
)
|
|
818
|
-
|
|
819
|
-
# Create loss function using factory
|
|
820
|
-
criterion = get_loss(
|
|
821
|
-
name=args.loss,
|
|
822
|
-
weights=loss_weights,
|
|
823
|
-
delta=args.huber_delta,
|
|
824
|
-
)
|
|
825
|
-
# Move criterion to device (important for WeightedMSELoss buffer)
|
|
826
|
-
criterion = criterion.to(accelerator.device)
|
|
827
|
-
|
|
828
|
-
# ==========================================================================
|
|
829
|
-
# PHYSICAL CONSTRAINTS INTEGRATION
|
|
830
|
-
# ==========================================================================
|
|
831
|
-
from wavedl.utils.constraints import (
|
|
832
|
-
PhysicsConstrainedLoss,
|
|
833
|
-
build_constraints,
|
|
834
|
-
)
|
|
835
|
-
|
|
836
|
-
# Build soft constraints
|
|
837
|
-
soft_constraints = build_constraints(
|
|
838
|
-
expressions=args.constraint,
|
|
839
|
-
file_path=args.constraint_file,
|
|
840
|
-
reduction=args.constraint_reduction,
|
|
841
|
-
)
|
|
842
|
-
|
|
843
|
-
# Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
|
|
844
|
-
if soft_constraints:
|
|
845
|
-
# Pass output scaler so constraints can be evaluated in physical space
|
|
846
|
-
output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
|
|
847
|
-
output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
|
|
848
|
-
criterion = PhysicsConstrainedLoss(
|
|
849
|
-
criterion,
|
|
850
|
-
soft_constraints,
|
|
851
|
-
weights=args.constraint_weight,
|
|
852
|
-
output_mean=output_mean,
|
|
853
|
-
output_std=output_std,
|
|
854
|
-
)
|
|
855
|
-
if accelerator.is_main_process:
|
|
856
|
-
logger.info(
|
|
857
|
-
f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
|
|
858
|
-
f"with weight(s) {args.constraint_weight}"
|
|
859
|
-
)
|
|
860
|
-
if output_mean is not None:
|
|
861
|
-
logger.info(
|
|
862
|
-
" 📐 Constraints evaluated in physical space (denormalized)"
|
|
863
|
-
)
|
|
864
|
-
|
|
865
|
-
# Track if scheduler should step per batch (OneCycleLR) or per epoch
|
|
866
|
-
scheduler_step_per_batch = not is_epoch_based(args.scheduler)
|
|
867
|
-
|
|
868
|
-
# ==========================================================================
|
|
869
|
-
# DDP Preparation Strategy:
|
|
870
|
-
# - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
|
|
871
|
-
# the correct sharded batch count, then create scheduler
|
|
872
|
-
# - For epoch-based schedulers: create scheduler before prepare (no issue)
|
|
873
|
-
# ==========================================================================
|
|
874
|
-
if scheduler_step_per_batch:
|
|
875
|
-
# BATCH-BASED SCHEDULER (e.g., OneCycleLR)
|
|
876
|
-
# Prepare model, optimizer, dataloaders FIRST to get sharded loader length
|
|
877
|
-
model, optimizer, train_dl, val_dl = accelerator.prepare(
|
|
878
|
-
model, optimizer, train_dl, val_dl
|
|
879
|
-
)
|
|
880
|
-
|
|
881
|
-
# Now create scheduler with the CORRECT sharded steps_per_epoch
|
|
882
|
-
steps_per_epoch = len(train_dl) # Post-DDP sharded length
|
|
883
|
-
scheduler = get_scheduler(
|
|
884
|
-
name=args.scheduler,
|
|
885
|
-
optimizer=optimizer,
|
|
886
|
-
epochs=args.epochs,
|
|
887
|
-
steps_per_epoch=steps_per_epoch,
|
|
888
|
-
min_lr=args.min_lr,
|
|
889
|
-
patience=args.scheduler_patience,
|
|
890
|
-
factor=args.scheduler_factor,
|
|
891
|
-
gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
|
|
892
|
-
step_size=args.step_size,
|
|
893
|
-
milestones=milestones,
|
|
894
|
-
warmup_epochs=args.warmup_epochs,
|
|
895
|
-
)
|
|
896
|
-
# Prepare scheduler separately (Accelerator wraps it for state saving)
|
|
897
|
-
scheduler = accelerator.prepare(scheduler)
|
|
898
|
-
else:
|
|
899
|
-
# EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
|
|
900
|
-
# No batch count dependency - create scheduler before prepare
|
|
901
|
-
scheduler = get_scheduler(
|
|
902
|
-
name=args.scheduler,
|
|
903
|
-
optimizer=optimizer,
|
|
904
|
-
epochs=args.epochs,
|
|
905
|
-
steps_per_epoch=None,
|
|
906
|
-
min_lr=args.min_lr,
|
|
907
|
-
patience=args.scheduler_patience,
|
|
908
|
-
factor=args.scheduler_factor,
|
|
909
|
-
gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
|
|
910
|
-
step_size=args.step_size,
|
|
911
|
-
milestones=milestones,
|
|
912
|
-
warmup_epochs=args.warmup_epochs,
|
|
913
|
-
)
|
|
914
|
-
|
|
915
|
-
# For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
|
|
916
|
-
# because accelerator wraps scheduler.step() to sync across processes,
|
|
917
|
-
# which defeats our rank-0-only stepping for correct patience counting.
|
|
918
|
-
# Other schedulers are safe to prepare (no internal state affected by multi-call).
|
|
919
|
-
if args.scheduler == "plateau":
|
|
920
|
-
model, optimizer, train_dl, val_dl = accelerator.prepare(
|
|
921
|
-
model, optimizer, train_dl, val_dl
|
|
922
|
-
)
|
|
923
|
-
# Scheduler stays unwrapped - we handle sync manually in training loop
|
|
924
|
-
# But register it for checkpointing so state is saved/loaded on resume
|
|
925
|
-
accelerator.register_for_checkpointing(scheduler)
|
|
926
|
-
else:
|
|
927
|
-
model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
|
|
928
|
-
model, optimizer, train_dl, val_dl, scheduler
|
|
929
|
-
)
|
|
930
|
-
|
|
931
|
-
# ==========================================================================
|
|
932
|
-
# AUTO-RESUME / RESUME FROM CHECKPOINT
|
|
933
|
-
# ==========================================================================
|
|
934
|
-
start_epoch = 0
|
|
935
|
-
best_val_loss = float("inf")
|
|
936
|
-
patience_ctr = 0
|
|
937
|
-
history: list[dict[str, Any]] = []
|
|
938
|
-
|
|
939
|
-
# Define checkpoint paths
|
|
940
|
-
best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
|
|
941
|
-
complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
|
|
942
|
-
|
|
943
|
-
# Auto-resume logic (if not --fresh and no explicit --resume)
|
|
944
|
-
if not args.fresh and args.resume is None:
|
|
945
|
-
if os.path.exists(complete_flag_path):
|
|
946
|
-
# Training already completed
|
|
947
|
-
if accelerator.is_main_process:
|
|
948
|
-
logger.info(
|
|
949
|
-
"✅ Training already completed (early stopping). Use --fresh to retrain."
|
|
950
|
-
)
|
|
951
|
-
return # Exit gracefully
|
|
952
|
-
elif os.path.exists(best_ckpt_path):
|
|
953
|
-
# Incomplete training found - auto-resume
|
|
954
|
-
args.resume = best_ckpt_path
|
|
955
|
-
if accelerator.is_main_process:
|
|
956
|
-
logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
|
|
957
|
-
|
|
958
|
-
if args.resume:
|
|
959
|
-
if os.path.exists(args.resume):
|
|
960
|
-
logger.info(f"🔄 Loading checkpoint from: {args.resume}")
|
|
961
|
-
accelerator.load_state(args.resume)
|
|
962
|
-
|
|
963
|
-
# Restore training metadata
|
|
964
|
-
meta_path = os.path.join(args.resume, "training_meta.pkl")
|
|
965
|
-
if os.path.exists(meta_path):
|
|
966
|
-
with open(meta_path, "rb") as f:
|
|
967
|
-
meta = pickle.load(f)
|
|
968
|
-
start_epoch = meta.get("epoch", 0)
|
|
969
|
-
best_val_loss = meta.get("best_val_loss", float("inf"))
|
|
970
|
-
patience_ctr = meta.get("patience_ctr", 0)
|
|
971
|
-
logger.info(
|
|
972
|
-
f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
|
|
973
|
-
)
|
|
974
|
-
else:
|
|
975
|
-
logger.warning(
|
|
976
|
-
" ⚠️ training_meta.pkl not found, starting from epoch 0"
|
|
977
|
-
)
|
|
978
|
-
|
|
979
|
-
# Restore history
|
|
980
|
-
history_path = os.path.join(args.output_dir, "training_history.csv")
|
|
981
|
-
if os.path.exists(history_path):
|
|
982
|
-
history = pd.read_csv(history_path).to_dict("records")
|
|
983
|
-
logger.info(f" ✅ Loaded {len(history)} epochs from history")
|
|
984
|
-
else:
|
|
985
|
-
raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
|
|
986
|
-
|
|
987
|
-
# ==========================================================================
|
|
988
|
-
# PHYSICAL METRIC SETUP
|
|
989
|
-
# ==========================================================================
|
|
990
|
-
# Physical MAE = normalized MAE * scaler.scale_
|
|
991
|
-
phys_scale = torch.tensor(
|
|
992
|
-
scaler.scale_, device=accelerator.device, dtype=torch.float32
|
|
993
|
-
)
|
|
994
|
-
|
|
995
|
-
# ==========================================================================
|
|
996
|
-
# TRAINING LOOP
|
|
997
|
-
# ==========================================================================
|
|
998
|
-
# Dynamic console header
|
|
999
|
-
if accelerator.is_main_process:
|
|
1000
|
-
base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
|
|
1001
|
-
param_cols = [f"MAE_P{i}" for i in range(out_dim)]
|
|
1002
|
-
header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
|
|
1003
|
-
*base_cols
|
|
1004
|
-
)
|
|
1005
|
-
header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
|
|
1006
|
-
logger.info("=" * len(header))
|
|
1007
|
-
logger.info(header)
|
|
1008
|
-
logger.info("=" * len(header))
|
|
1009
|
-
|
|
1010
|
-
try:
|
|
1011
|
-
total_training_time = 0.0
|
|
1012
|
-
|
|
1013
|
-
for epoch in range(start_epoch, args.epochs):
|
|
1014
|
-
epoch_start_time = time.time()
|
|
1015
|
-
|
|
1016
|
-
# ==================== TRAINING PHASE ====================
|
|
1017
|
-
model.train()
|
|
1018
|
-
# Use GPU tensor for loss accumulation to avoid .item() sync per batch
|
|
1019
|
-
train_loss_sum = torch.tensor(0.0, device=accelerator.device)
|
|
1020
|
-
train_samples = 0
|
|
1021
|
-
grad_norm_tracker = MetricTracker()
|
|
1022
|
-
|
|
1023
|
-
pbar = tqdm(
|
|
1024
|
-
train_dl,
|
|
1025
|
-
disable=not accelerator.is_main_process,
|
|
1026
|
-
leave=False,
|
|
1027
|
-
desc=f"Epoch {epoch + 1}",
|
|
1028
|
-
)
|
|
1029
|
-
|
|
1030
|
-
for x, y in pbar:
|
|
1031
|
-
with accelerator.accumulate(model):
|
|
1032
|
-
# Use mixed precision for forward pass (respects --precision flag)
|
|
1033
|
-
with accelerator.autocast():
|
|
1034
|
-
pred = model(x)
|
|
1035
|
-
# Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
|
|
1036
|
-
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
1037
|
-
loss = criterion(pred, y, x)
|
|
1038
|
-
else:
|
|
1039
|
-
loss = criterion(pred, y)
|
|
1040
|
-
|
|
1041
|
-
accelerator.backward(loss)
|
|
1042
|
-
|
|
1043
|
-
if accelerator.sync_gradients:
|
|
1044
|
-
grad_norm = accelerator.clip_grad_norm_(
|
|
1045
|
-
model.parameters(), args.grad_clip
|
|
1046
|
-
)
|
|
1047
|
-
if grad_norm is not None:
|
|
1048
|
-
grad_norm_tracker.update(grad_norm.item())
|
|
1049
|
-
|
|
1050
|
-
optimizer.step()
|
|
1051
|
-
optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
|
|
1052
|
-
|
|
1053
|
-
# Per-batch LR scheduling (e.g., OneCycleLR)
|
|
1054
|
-
if scheduler_step_per_batch:
|
|
1055
|
-
scheduler.step()
|
|
1056
|
-
|
|
1057
|
-
# Accumulate as tensors to avoid .item() sync per batch
|
|
1058
|
-
train_loss_sum += loss.detach() * x.size(0)
|
|
1059
|
-
train_samples += x.size(0)
|
|
1060
|
-
|
|
1061
|
-
# Single .item() call at end of epoch (reduces GPU sync overhead)
|
|
1062
|
-
train_loss_scalar = train_loss_sum.item()
|
|
1063
|
-
|
|
1064
|
-
# Synchronize training metrics across GPUs
|
|
1065
|
-
global_loss = accelerator.reduce(
|
|
1066
|
-
torch.tensor([train_loss_scalar], device=accelerator.device),
|
|
1067
|
-
reduction="sum",
|
|
1068
|
-
).item()
|
|
1069
|
-
global_samples = accelerator.reduce(
|
|
1070
|
-
torch.tensor([train_samples], device=accelerator.device),
|
|
1071
|
-
reduction="sum",
|
|
1072
|
-
).item()
|
|
1073
|
-
avg_train_loss = global_loss / global_samples
|
|
1074
|
-
|
|
1075
|
-
# ==================== VALIDATION PHASE ====================
|
|
1076
|
-
model.eval()
|
|
1077
|
-
# Use GPU tensor for loss accumulation (consistent with training phase)
|
|
1078
|
-
val_loss_sum = torch.tensor(0.0, device=accelerator.device)
|
|
1079
|
-
val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
|
|
1080
|
-
val_samples = 0
|
|
1081
|
-
|
|
1082
|
-
# Accumulate predictions locally ON CPU to prevent GPU OOM
|
|
1083
|
-
local_preds = []
|
|
1084
|
-
local_targets = []
|
|
1085
|
-
|
|
1086
|
-
with torch.inference_mode():
|
|
1087
|
-
for x, y in val_dl:
|
|
1088
|
-
# Use mixed precision for validation (consistent with training)
|
|
1089
|
-
with accelerator.autocast():
|
|
1090
|
-
pred = model(x)
|
|
1091
|
-
# Pass inputs for input-dependent constraints
|
|
1092
|
-
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
1093
|
-
loss = criterion(pred, y, x)
|
|
1094
|
-
else:
|
|
1095
|
-
loss = criterion(pred, y)
|
|
1096
|
-
|
|
1097
|
-
val_loss_sum += loss.detach() * x.size(0)
|
|
1098
|
-
val_samples += x.size(0)
|
|
1099
|
-
|
|
1100
|
-
# Physical MAE
|
|
1101
|
-
mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
|
|
1102
|
-
val_mae_sum += mae_batch
|
|
1103
|
-
|
|
1104
|
-
# Store on CPU (critical for large val sets)
|
|
1105
|
-
local_preds.append(pred.detach().cpu())
|
|
1106
|
-
local_targets.append(y.detach().cpu())
|
|
1107
|
-
|
|
1108
|
-
# Concatenate locally (keep on GPU for gather_for_metrics compatibility)
|
|
1109
|
-
local_preds_cat = torch.cat(local_preds)
|
|
1110
|
-
local_targets_cat = torch.cat(local_targets)
|
|
1111
|
-
|
|
1112
|
-
# Gather predictions and targets using Accelerate's CPU-efficient utility
|
|
1113
|
-
# gather_for_metrics handles:
|
|
1114
|
-
# - DDP padding removal (no need to trim manually)
|
|
1115
|
-
# - Efficient cross-rank gathering without GPU memory spike
|
|
1116
|
-
# - Returns concatenated tensors on CPU for metric computation
|
|
1117
|
-
if accelerator.num_processes > 1:
|
|
1118
|
-
# Move to GPU for gather (required by NCCL), then back to CPU
|
|
1119
|
-
# gather_for_metrics is more memory-efficient than manual gather
|
|
1120
|
-
# as it processes in chunks internally
|
|
1121
|
-
gathered_preds = accelerator.gather_for_metrics(
|
|
1122
|
-
local_preds_cat.to(accelerator.device)
|
|
1123
|
-
).cpu()
|
|
1124
|
-
gathered_targets = accelerator.gather_for_metrics(
|
|
1125
|
-
local_targets_cat.to(accelerator.device)
|
|
1126
|
-
).cpu()
|
|
1127
|
-
else:
|
|
1128
|
-
# Single-GPU mode: no gathering needed
|
|
1129
|
-
gathered_preds = local_preds_cat
|
|
1130
|
-
gathered_targets = local_targets_cat
|
|
1131
|
-
|
|
1132
|
-
# Synchronize validation metrics (scalars only - efficient)
|
|
1133
|
-
val_loss_scalar = val_loss_sum.item()
|
|
1134
|
-
val_metrics = torch.cat(
|
|
1135
|
-
[
|
|
1136
|
-
torch.tensor([val_loss_scalar], device=accelerator.device),
|
|
1137
|
-
val_mae_sum,
|
|
1138
|
-
]
|
|
1139
|
-
)
|
|
1140
|
-
val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
|
|
1141
|
-
|
|
1142
|
-
total_val_samples = accelerator.reduce(
|
|
1143
|
-
torch.tensor([val_samples], device=accelerator.device), reduction="sum"
|
|
1144
|
-
).item()
|
|
1145
|
-
|
|
1146
|
-
avg_val_loss = val_metrics_sync[0].item() / total_val_samples
|
|
1147
|
-
# Cast to float32 before numpy (bf16 tensors can't convert directly)
|
|
1148
|
-
avg_mae_per_param = (
|
|
1149
|
-
(val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
|
|
1150
|
-
)
|
|
1151
|
-
avg_mae = avg_mae_per_param.mean()
|
|
1152
|
-
|
|
1153
|
-
# ==================== LOGGING & CHECKPOINTING ====================
|
|
1154
|
-
if accelerator.is_main_process:
|
|
1155
|
-
# Scientific metrics - cast to float32 before numpy
|
|
1156
|
-
# gather_for_metrics already handles DDP padding removal
|
|
1157
|
-
y_pred = gathered_preds.float().numpy()
|
|
1158
|
-
y_true = gathered_targets.float().numpy()
|
|
1159
|
-
|
|
1160
|
-
# Guard against tiny validation sets (R² undefined for <2 samples)
|
|
1161
|
-
if len(y_true) >= 2:
|
|
1162
|
-
r2 = r2_score(y_true, y_pred)
|
|
1163
|
-
else:
|
|
1164
|
-
r2 = float("nan")
|
|
1165
|
-
pcc = calc_pearson(y_true, y_pred)
|
|
1166
|
-
current_lr = get_lr(optimizer)
|
|
1167
|
-
|
|
1168
|
-
# Update history
|
|
1169
|
-
epoch_end_time = time.time()
|
|
1170
|
-
epoch_time = epoch_end_time - epoch_start_time
|
|
1171
|
-
total_training_time += epoch_time
|
|
1172
|
-
|
|
1173
|
-
epoch_stats = {
|
|
1174
|
-
"epoch": epoch + 1,
|
|
1175
|
-
"train_loss": avg_train_loss,
|
|
1176
|
-
"val_loss": avg_val_loss,
|
|
1177
|
-
"val_r2": r2,
|
|
1178
|
-
"val_pearson": pcc,
|
|
1179
|
-
"val_mae_avg": avg_mae,
|
|
1180
|
-
"grad_norm": grad_norm_tracker.avg,
|
|
1181
|
-
"lr": current_lr,
|
|
1182
|
-
"epoch_time": round(epoch_time, 2),
|
|
1183
|
-
"total_time": round(total_training_time, 2),
|
|
1184
|
-
}
|
|
1185
|
-
for i, mae in enumerate(avg_mae_per_param):
|
|
1186
|
-
epoch_stats[f"MAE_Phys_P{i}"] = mae
|
|
1187
|
-
|
|
1188
|
-
history.append(epoch_stats)
|
|
1189
|
-
|
|
1190
|
-
# Console display
|
|
1191
|
-
base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
|
|
1192
|
-
param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
|
|
1193
|
-
logger.info(f"{base_str} | {param_str}")
|
|
1194
|
-
|
|
1195
|
-
# WandB logging
|
|
1196
|
-
if args.wandb and WANDB_AVAILABLE:
|
|
1197
|
-
log_dict = {
|
|
1198
|
-
"main/train_loss": avg_train_loss,
|
|
1199
|
-
"main/val_loss": avg_val_loss,
|
|
1200
|
-
"metrics/r2_score": r2,
|
|
1201
|
-
"metrics/pearson_corr": pcc,
|
|
1202
|
-
"metrics/mae_avg": avg_mae,
|
|
1203
|
-
"system/grad_norm": grad_norm_tracker.avg,
|
|
1204
|
-
"hyper/lr": current_lr,
|
|
1205
|
-
}
|
|
1206
|
-
for i, mae in enumerate(avg_mae_per_param):
|
|
1207
|
-
log_dict[f"mae_detailed/P{i}"] = mae
|
|
1208
|
-
|
|
1209
|
-
# Periodic scatter plots
|
|
1210
|
-
if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
|
|
1211
|
-
real_true = scaler.inverse_transform(y_true)
|
|
1212
|
-
real_pred = scaler.inverse_transform(y_pred)
|
|
1213
|
-
fig = plot_scientific_scatter(real_true, real_pred)
|
|
1214
|
-
log_dict["plots/scatter_analysis"] = wandb.Image(fig)
|
|
1215
|
-
plt.close(fig)
|
|
1216
|
-
|
|
1217
|
-
accelerator.log(log_dict)
|
|
1218
|
-
|
|
1219
|
-
# ==========================================================================
|
|
1220
|
-
# DDP-SAFE CHECKPOINT LOGIC
|
|
1221
|
-
# ==========================================================================
|
|
1222
|
-
# Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
|
|
1223
|
-
is_best_epoch = False
|
|
1224
|
-
if accelerator.is_main_process:
|
|
1225
|
-
if avg_val_loss < best_val_loss:
|
|
1226
|
-
is_best_epoch = True
|
|
1227
|
-
|
|
1228
|
-
# Step 2: Broadcast decision to all ranks (required for save_state)
|
|
1229
|
-
is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
|
|
1230
|
-
|
|
1231
|
-
# Step 3: Save checkpoint with all ranks participating
|
|
1232
|
-
if is_best_epoch:
|
|
1233
|
-
ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
|
|
1234
|
-
with suppress_accelerate_logging():
|
|
1235
|
-
accelerator.save_state(ckpt_dir, safe_serialization=False)
|
|
1236
|
-
|
|
1237
|
-
# Step 4: Rank 0 handles metadata and updates tracking variables
|
|
1238
|
-
if accelerator.is_main_process:
|
|
1239
|
-
best_val_loss = avg_val_loss # Update AFTER checkpoint saved
|
|
1240
|
-
patience_ctr = 0
|
|
1241
|
-
|
|
1242
|
-
with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
|
|
1243
|
-
pickle.dump(
|
|
1244
|
-
{
|
|
1245
|
-
"epoch": epoch + 1,
|
|
1246
|
-
"best_val_loss": best_val_loss,
|
|
1247
|
-
"patience_ctr": patience_ctr,
|
|
1248
|
-
# Model info for auto-detection during inference
|
|
1249
|
-
"model_name": args.model,
|
|
1250
|
-
"in_shape": in_shape,
|
|
1251
|
-
"out_dim": out_dim,
|
|
1252
|
-
},
|
|
1253
|
-
f,
|
|
1254
|
-
)
|
|
1255
|
-
|
|
1256
|
-
# Unwrap model for saving (handle torch.compile compatibility)
|
|
1257
|
-
try:
|
|
1258
|
-
unwrapped = accelerator.unwrap_model(model)
|
|
1259
|
-
except KeyError:
|
|
1260
|
-
# torch.compile model may not have _orig_mod in expected location
|
|
1261
|
-
# Fall back to getting the module directly
|
|
1262
|
-
unwrapped = model.module if hasattr(model, "module") else model
|
|
1263
|
-
# If still compiled, try to get the underlying model
|
|
1264
|
-
if hasattr(unwrapped, "_orig_mod"):
|
|
1265
|
-
unwrapped = unwrapped._orig_mod
|
|
1266
|
-
|
|
1267
|
-
torch.save(
|
|
1268
|
-
unwrapped.state_dict(),
|
|
1269
|
-
os.path.join(args.output_dir, "best_model_weights.pth"),
|
|
1270
|
-
)
|
|
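The try/except above exists because accelerator.unwrap_model can trip over a torch.compile wrapper. The same peeling order (DDP's .module first, then the compile wrapper's _orig_mod) condensed into one helper; a sketch with a name of our choosing, not wavedl's exact code path:

import torch

def unwrap_for_saving(model: torch.nn.Module) -> torch.nn.Module:
    """Peel DDP and torch.compile wrappers so the saved state_dict has clean keys."""
    if hasattr(model, "module"):        # DistributedDataParallel wrapper
        model = model.module
    if hasattr(model, "_orig_mod"):     # torch.compile (OptimizedModule) wrapper
        model = model._orig_mod
    return model

# torch.save(unwrap_for_saving(model).state_dict(), "best_model_weights.pth")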
1271
|
-
|
|
1272
|
-
# Copy scaler to checkpoint for portability
|
|
1273
|
-
scaler_src = os.path.join(args.output_dir, "scaler.pkl")
|
|
1274
|
-
scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
|
|
1275
|
-
if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
|
|
1276
|
-
shutil.copy2(scaler_src, scaler_dst)
|
|
1277
|
-
|
|
1278
|
-
logger.info(
|
|
1279
|
-
f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
|
|
1280
|
-
)
|
|
1281
|
-
|
|
1282
|
-
# Also save CSV on best model (ensures progress is saved)
|
|
1283
|
-
pd.DataFrame(history).to_csv(
|
|
1284
|
-
os.path.join(args.output_dir, "training_history.csv"),
|
|
1285
|
-
index=False,
|
|
1286
|
-
)
|
|
1287
|
-
else:
|
|
1288
|
-
if accelerator.is_main_process:
|
|
1289
|
-
patience_ctr += 1
|
|
1290
|
-
|
|
1291
|
-
# Periodic checkpoint (all ranks participate in save_state)
|
|
1292
|
-
periodic_checkpoint_needed = (
|
|
1293
|
-
args.save_every > 0 and (epoch + 1) % args.save_every == 0
|
|
1294
|
-
)
|
|
1295
|
-
if periodic_checkpoint_needed:
|
|
1296
|
-
ckpt_name = f"epoch_{epoch + 1}_checkpoint"
|
|
1297
|
-
ckpt_dir = os.path.join(args.output_dir, ckpt_name)
|
|
1298
|
-
with suppress_accelerate_logging():
|
|
1299
|
-
accelerator.save_state(ckpt_dir, safe_serialization=False)
|
|
1300
|
-
|
|
1301
|
-
if accelerator.is_main_process:
|
|
1302
|
-
with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
|
|
1303
|
-
pickle.dump(
|
|
1304
|
-
{
|
|
1305
|
-
"epoch": epoch + 1,
|
|
1306
|
-
"best_val_loss": best_val_loss,
|
|
1307
|
-
"patience_ctr": patience_ctr,
|
|
1308
|
-
# Model info for auto-detection during inference
|
|
1309
|
-
"model_name": args.model,
|
|
1310
|
-
"in_shape": in_shape,
|
|
1311
|
-
"out_dim": out_dim,
|
|
1312
|
-
},
|
|
1313
|
-
f,
|
|
1314
|
-
)
|
|
1315
|
-
logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
|
|
1316
|
-
|
|
1317
|
-
# Save CSV with each checkpoint (keeps logs in sync with model state)
|
|
1318
|
-
pd.DataFrame(history).to_csv(
|
|
1319
|
-
os.path.join(args.output_dir, "training_history.csv"),
|
|
1320
|
-
index=False,
|
|
1321
|
-
)
|
|
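Both the best and the periodic checkpoints above carry a small training_meta.pkl next to the accelerate state, so a resumed run (or later inference) can recover the epoch counter, best loss, and model signature without re-parsing CLI flags. A minimal round-trip of that file with the same keys the trainer writes; the helper names are ours:

import os
import pickle

def write_training_meta(ckpt_dir, epoch, best_val_loss, patience_ctr,
                        model_name, in_shape, out_dim):
    """Persist resume/inference metadata alongside an accelerate checkpoint."""
    meta = {
        "epoch": epoch,
        "best_val_loss": best_val_loss,
        "patience_ctr": patience_ctr,
        "model_name": model_name,   # lets inference rebuild the right architecture
        "in_shape": in_shape,
        "out_dim": out_dim,
    }
    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
        pickle.dump(meta, f)

def read_training_meta(ckpt_dir):
    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "rb") as f:
        return pickle.load(f)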
1322
|
-
|
|
1323
|
-
# Learning rate scheduling (epoch-based schedulers only)
|
|
1324
|
-
# NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
|
|
1325
|
-
# to avoid patience counter being incremented by all GPU processes.
|
|
1326
|
-
# Then we sync the new LR to all processes to keep them consistent.
|
|
1327
|
-
if not scheduler_step_per_batch:
|
|
1328
|
-
if args.scheduler == "plateau":
|
|
1329
|
-
# Step only on main process to avoid multi-GPU patience bug
|
|
1330
|
-
if accelerator.is_main_process:
|
|
1331
|
-
scheduler.step(avg_val_loss)
|
|
1332
|
-
|
|
1333
|
-
# Sync LR across all processes after main process updates it
|
|
1334
|
-
accelerator.wait_for_everyone()
|
|
1335
|
-
|
|
1336
|
-
# Broadcast new LR from rank 0 to all processes
|
|
1337
|
-
if dist.is_initialized():
|
|
1338
|
-
if accelerator.is_main_process:
|
|
1339
|
-
new_lr = optimizer.param_groups[0]["lr"]
|
|
1340
|
-
else:
|
|
1341
|
-
new_lr = 0.0
|
|
1342
|
-
new_lr_tensor = torch.tensor(
|
|
1343
|
-
new_lr, device=accelerator.device, dtype=torch.float32
|
|
1344
|
-
)
|
|
1345
|
-
dist.broadcast(new_lr_tensor, src=0)
|
|
1346
|
-
# Update LR on non-main processes
|
|
1347
|
-
if not accelerator.is_main_process:
|
|
1348
|
-
for param_group in optimizer.param_groups:
|
|
1349
|
-
param_group["lr"] = new_lr_tensor.item()
|
|
1350
|
-
else:
|
|
1351
|
-
scheduler.step()
|
|
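With ReduceLROnPlateau under DDP, letting every rank call scheduler.step(val_loss) would advance the patience counter once per GPU each epoch, so the code above steps on rank 0 only and pushes the resulting LR to the other ranks. The same synchronization condensed into one function; a sketch (name ours) that assumes the process group is initialized in the multi-GPU case:

import torch
import torch.distributed as dist

def step_plateau_rank0_and_sync(scheduler, optimizer, val_loss, is_main, device):
    """Step ReduceLROnPlateau once (rank 0) and broadcast the new LR to all ranks."""
    if is_main:
        scheduler.step(val_loss)            # patience counted exactly once
    if not dist.is_initialized():
        return                              # single-GPU: nothing to sync
    lr = torch.tensor(
        [optimizer.param_groups[0]["lr"] if is_main else 0.0],
        device=device, dtype=torch.float32,
    )
    dist.broadcast(lr, src=0)               # rank 0's learning rate wins
    if not is_main:
        for group in optimizer.param_groups:
            group["lr"] = lr.item()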
1352
|
-
|
|
1353
|
-
# DDP-safe early stopping
|
|
1354
|
-
should_stop = (
|
|
1355
|
-
patience_ctr >= args.patience if accelerator.is_main_process else False
|
|
1356
|
-
)
|
|
1357
|
-
if broadcast_early_stop(should_stop, accelerator):
|
|
1358
|
-
if accelerator.is_main_process:
|
|
1359
|
-
logger.info(
|
|
1360
|
-
f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
|
|
1361
|
-
)
|
|
1362
|
-
# Create completion flag to prevent auto-resume
|
|
1363
|
-
with open(
|
|
1364
|
-
os.path.join(args.output_dir, "training_complete.flag"), "w"
|
|
1365
|
-
) as f:
|
|
1366
|
-
f.write(
|
|
1367
|
-
f"Training completed via early stopping at epoch {epoch + 1}\n"
|
|
1368
|
-
)
|
|
1369
|
-
break
|
|
1370
|
-
|
|
1371
|
-
except KeyboardInterrupt:
|
|
1372
|
-
logger.warning("Training interrupted. Saving emergency checkpoint...")
|
|
1373
|
-
with suppress_accelerate_logging():
|
|
1374
|
-
accelerator.save_state(
|
|
1375
|
-
os.path.join(args.output_dir, "interrupted_checkpoint"),
|
|
1376
|
-
safe_serialization=False,
|
|
1377
|
-
)
|
|
1378
|
-
|
|
1379
|
-
except Exception as e:
|
|
1380
|
-
logger.error(f"Critical error: {e}", exc_info=True)
|
|
1381
|
-
raise
|
|
1382
|
-
|
|
1383
|
-
else:
|
|
1384
|
-
# Training completed normally (reached max epochs without early stopping)
|
|
1385
|
-
# Create completion flag to prevent auto-resume on re-run
|
|
1386
|
-
if accelerator.is_main_process:
|
|
1387
|
-
if not os.path.exists(complete_flag_path):
|
|
1388
|
-
with open(complete_flag_path, "w") as f:
|
|
1389
|
-
f.write(f"Training completed normally after {args.epochs} epochs\n")
|
|
1390
|
-
logger.info(f"✅ Training completed after {args.epochs} epochs")
|
|
1391
|
-
|
|
1392
|
-
finally:
|
|
1393
|
-
# Final CSV write to capture all epochs (handles non-multiple-of-10 endings)
|
|
1394
|
-
if accelerator.is_main_process and len(history) > 0:
|
|
1395
|
-
pd.DataFrame(history).to_csv(
|
|
1396
|
-
os.path.join(args.output_dir, "training_history.csv"),
|
|
1397
|
-
index=False,
|
|
1398
|
-
)
|
|
1399
|
-
|
|
1400
|
-
# Generate training curves plot (PNG + SVG)
|
|
1401
|
-
if accelerator.is_main_process and len(history) > 0:
|
|
1402
|
-
try:
|
|
1403
|
-
fig = create_training_curves(history, show_lr=True)
|
|
1404
|
-
for fmt in ["png", "svg"]:
|
|
1405
|
-
fig.savefig(
|
|
1406
|
-
os.path.join(args.output_dir, f"training_curves.{fmt}"),
|
|
1407
|
-
dpi=FIGURE_DPI,
|
|
1408
|
-
bbox_inches="tight",
|
|
1409
|
-
)
|
|
1410
|
-
plt.close(fig)
|
|
1411
|
-
logger.info("✔ Saved: training_curves.png, training_curves.svg")
|
|
1412
|
-
except Exception as e:
|
|
1413
|
-
logger.warning(f"Could not generate training curves: {e}")
|
|
1414
|
-
|
|
1415
|
-
if args.wandb and WANDB_AVAILABLE:
|
|
1416
|
-
accelerator.end_training()
|
|
1417
|
-
|
|
1418
|
-
# Clean up distributed process group to prevent resource leak warning
|
|
1419
|
-
if torch.distributed.is_initialized():
|
|
1420
|
-
torch.distributed.destroy_process_group()
|
|
1421
|
-
|
|
1422
|
-
logger.info("Training completed.")
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
if __name__ == "__main__":
|
|
1426
|
-
try:
|
|
1427
|
-
torch.multiprocessing.set_start_method("spawn")
|
|
1428
|
-
except RuntimeError:
|
|
1429
|
-
pass
|
|
1430
|
-
main()
|
|
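The tail of the old main() above pairs an emergency checkpoint on Ctrl-C with a finally block that flushes the history CSV and tears down the process group, and the __main__ guard forces the "spawn" start method while swallowing the RuntimeError raised if it was already set. A stripped-down skeleton of that shutdown pattern; the three callables are placeholders for the real accelerate and pandas calls shown above:

import torch
import torch.distributed as dist

def run_with_graceful_shutdown(train_loop, save_emergency_ckpt, flush_history):
    """Wrap a training loop with interrupt handling and DDP cleanup."""
    try:
        train_loop()
    except KeyboardInterrupt:
        save_emergency_ckpt()   # e.g. accelerator.save_state(".../interrupted_checkpoint")
    finally:
        flush_history()         # e.g. pd.DataFrame(history).to_csv(...)
        if dist.is_initialized():
            dist.destroy_process_group()   # avoid the NCCL resource-leak warning

if __name__ == "__main__":
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        pass  # start method already set by the launcher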
1
|
+
"""
|
|
2
|
+
WaveDL - Deep Learning for Wave-based Inverse Problems
|
|
3
|
+
=======================================================
|
|
4
|
+
Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
|
|
5
|
+
|
|
6
|
+
A modular training framework for wave-based inverse problems and regression:
|
|
7
|
+
1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
|
|
8
|
+
2. Dynamic Model Selection: Use --model flag to select any registered architecture
|
|
9
|
+
3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
|
|
10
|
+
4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
|
|
11
|
+
5. Robust Checkpointing: Resume training, periodic saves, and training curves
|
|
12
|
+
6. Deep Observability: WandB integration with scatter analysis
|
|
13
|
+
|
|
14
|
+
Usage:
|
|
15
|
+
# Recommended: Using the HPC launcher (handles accelerate configuration)
|
|
16
|
+
wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
|
|
17
|
+
|
|
18
|
+
# Or direct training module (use --precision, not --mixed_precision)
|
|
19
|
+
accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
|
|
20
|
+
|
|
21
|
+
# Multi-GPU with explicit config
|
|
22
|
+
wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
|
|
23
|
+
|
|
24
|
+
# Resume from checkpoint
|
|
25
|
+
accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
|
|
26
|
+
|
|
27
|
+
# List available models
|
|
28
|
+
wavedl-train --list_models
|
|
29
|
+
|
|
30
|
+
Note:
|
|
31
|
+
- wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
|
|
32
|
+
- wavedl.train: Uses --precision (internal module flag)
|
|
33
|
+
Both control the same behavior; use the appropriate flag for your entry point.
|
|
34
|
+
|
|
35
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from __future__ import annotations
|
|
39
|
+
|
|
40
|
+
# =============================================================================
|
|
41
|
+
# HPC Environment Setup (MUST be before any library imports)
|
|
42
|
+
# =============================================================================
|
|
43
|
+
# Auto-configure writable cache directories when home is not writable.
|
|
44
|
+
# Uses current working directory as fallback - works on HPC and local machines.
|
|
45
|
+
import os
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
49
|
+
"""Set cache directory to CWD if home is not writable."""
|
|
50
|
+
if env_var in os.environ:
|
|
51
|
+
return # User already set, respect their choice
|
|
52
|
+
|
|
53
|
+
# Check if home is writable
|
|
54
|
+
home = os.path.expanduser("~")
|
|
55
|
+
if os.access(home, os.W_OK):
|
|
56
|
+
return # Home is writable, let library use defaults
|
|
57
|
+
|
|
58
|
+
# Home not writable - use current working directory
|
|
59
|
+
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
60
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
61
|
+
os.environ[env_var] = cache_path
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Configure cache directories (before any library imports)
|
|
65
|
+
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
66
|
+
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
67
|
+
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
68
|
+
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
69
|
+
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
|
+
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _setup_per_rank_compile_cache() -> None:
|
|
74
|
+
"""Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
|
|
75
|
+
|
|
76
|
+
When using torch.compile with multiple GPUs, all processes try to write to
|
|
77
|
+
the same cache directory, causing 'Directory is not empty - skipping!' warnings.
|
|
78
|
+
This gives each GPU rank its own isolated cache subdirectory.
|
|
79
|
+
"""
|
|
80
|
+
# Get local rank from environment (set by accelerate/torchrun)
|
|
81
|
+
local_rank = os.environ.get("LOCAL_RANK", "0")
|
|
82
|
+
|
|
83
|
+
# Get cache base from environment or use CWD
|
|
84
|
+
cache_base = os.environ.get(
|
|
85
|
+
"TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# Set per-rank cache directories
|
|
89
|
+
os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
|
|
90
|
+
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
|
|
91
|
+
os.environ.get(
|
|
92
|
+
"TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
|
|
93
|
+
),
|
|
94
|
+
f"rank_{local_rank}",
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
# Create directories
|
|
98
|
+
os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
|
|
99
|
+
os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Setup per-rank compile caches (before torch imports)
|
|
103
|
+
_setup_per_rank_compile_cache()
|
|
104
|
+
|
|
105
|
+
# =============================================================================
|
|
106
|
+
# Standard imports (after environment setup)
|
|
107
|
+
# =============================================================================
|
|
108
|
+
import argparse
|
|
109
|
+
import logging
|
|
110
|
+
import pickle
|
|
111
|
+
import shutil
|
|
112
|
+
import sys
|
|
113
|
+
import time
|
|
114
|
+
import warnings
|
|
115
|
+
from typing import Any
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# Suppress Pydantic warnings from accelerate's internal Field() usage
|
|
119
|
+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
|
120
|
+
|
|
121
|
+
import matplotlib.pyplot as plt
|
|
122
|
+
import numpy as np
|
|
123
|
+
import pandas as pd
|
|
124
|
+
import torch
|
|
125
|
+
import torch.distributed as dist
|
|
126
|
+
from accelerate import Accelerator
|
|
127
|
+
from accelerate.utils import set_seed
|
|
128
|
+
from sklearn.metrics import r2_score
|
|
129
|
+
from tqdm.auto import tqdm
|
|
130
|
+
|
|
131
|
+
from wavedl.models import build_model, get_model, list_models
|
|
132
|
+
from wavedl.utils import (
|
|
133
|
+
FIGURE_DPI,
|
|
134
|
+
MetricTracker,
|
|
135
|
+
broadcast_early_stop,
|
|
136
|
+
calc_pearson,
|
|
137
|
+
create_training_curves,
|
|
138
|
+
get_loss,
|
|
139
|
+
get_lr,
|
|
140
|
+
get_optimizer,
|
|
141
|
+
get_scheduler,
|
|
142
|
+
is_epoch_based,
|
|
143
|
+
list_losses,
|
|
144
|
+
list_optimizers,
|
|
145
|
+
list_schedulers,
|
|
146
|
+
plot_scientific_scatter,
|
|
147
|
+
prepare_data,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
try:
|
|
152
|
+
import wandb
|
|
153
|
+
|
|
154
|
+
WANDB_AVAILABLE = True
|
|
155
|
+
except ImportError:
|
|
156
|
+
WANDB_AVAILABLE = False
|
|
157
|
+
|
|
158
|
+
# ==============================================================================
|
|
159
|
+
# RUNTIME CONFIGURATION (post-import)
|
|
160
|
+
# ==============================================================================
|
|
161
|
+
# Configure matplotlib paths for HPC systems without writable home directories
|
|
162
|
+
os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
|
|
163
|
+
os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
|
|
164
|
+
|
|
165
|
+
# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
|
|
166
|
+
# from torch/numpy about NaN, dtype, and numerical issues.
|
|
167
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
168
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
169
|
+
# Pydantic v1/v2 compatibility warnings
|
|
170
|
+
warnings.filterwarnings("ignore", module="pydantic")
|
|
171
|
+
warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
|
|
172
|
+
# Transformer library warnings (loading configs, etc.)
|
|
173
|
+
warnings.filterwarnings("ignore", module="transformers")
|
|
174
|
+
# Accelerate verbose messages
|
|
175
|
+
warnings.filterwarnings("ignore", module="accelerate")
|
|
176
|
+
# torch.compile backend selection warnings
|
|
177
|
+
warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
|
|
178
|
+
warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
|
|
179
|
+
# Note: UserWarning from torch/numpy core is NOT suppressed to preserve
|
|
180
|
+
# legitimate warnings about NaN values, dtype mismatches, etc.
|
|
181
|
+
|
|
182
|
+
# ==============================================================================
|
|
183
|
+
# GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
|
|
184
|
+
# ==============================================================================
|
|
185
|
+
# Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
|
|
186
|
+
torch.backends.cuda.matmul.allow_tf32 = True
|
|
187
|
+
torch.backends.cudnn.allow_tf32 = True
|
|
188
|
+
torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
|
|
189
|
+
|
|
190
|
+
# Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
|
|
191
|
+
# Note: First few batches may be slower due to benchmarking
|
|
192
|
+
torch.backends.cudnn.benchmark = True
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ==============================================================================
|
|
196
|
+
# LOGGING UTILITIES
|
|
197
|
+
# ==============================================================================
|
|
198
|
+
from contextlib import contextmanager
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@contextmanager
|
|
202
|
+
def suppress_accelerate_logging():
|
|
203
|
+
"""Temporarily suppress accelerate's verbose checkpoint save messages."""
|
|
204
|
+
accelerate_logger = logging.getLogger("accelerate.checkpointing")
|
|
205
|
+
original_level = accelerate_logger.level
|
|
206
|
+
accelerate_logger.setLevel(logging.WARNING)
|
|
207
|
+
try:
|
|
208
|
+
yield
|
|
209
|
+
finally:
|
|
210
|
+
accelerate_logger.setLevel(original_level)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# ==============================================================================
|
|
214
|
+
# ARGUMENT PARSING
|
|
215
|
+
# ==============================================================================
|
|
216
|
+
def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
|
|
217
|
+
"""Parse command-line arguments with comprehensive options."""
|
|
218
|
+
parser = argparse.ArgumentParser(
|
|
219
|
+
description="Universal DDP Training Pipeline",
|
|
220
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
# Model Selection
|
|
224
|
+
parser.add_argument(
|
|
225
|
+
"--model",
|
|
226
|
+
type=str,
|
|
227
|
+
default="cnn",
|
|
228
|
+
help=f"Model architecture to train. Available: {list_models()}",
|
|
229
|
+
)
|
|
230
|
+
parser.add_argument(
|
|
231
|
+
"--list_models", action="store_true", help="List all available models and exit"
|
|
232
|
+
)
|
|
233
|
+
parser.add_argument(
|
|
234
|
+
"--import",
|
|
235
|
+
dest="import_modules",
|
|
236
|
+
type=str,
|
|
237
|
+
nargs="+",
|
|
238
|
+
default=[],
|
|
239
|
+
help="Python modules to import before training (for custom models)",
|
|
240
|
+
)
|
|
241
|
+
parser.add_argument(
|
|
242
|
+
"--no_pretrained",
|
|
243
|
+
dest="pretrained",
|
|
244
|
+
action="store_false",
|
|
245
|
+
help="Train from scratch without pretrained weights (default: use pretrained)",
|
|
246
|
+
)
|
|
247
|
+
parser.set_defaults(pretrained=True)
|
|
248
|
+
|
|
249
|
+
# Configuration File
|
|
250
|
+
parser.add_argument(
|
|
251
|
+
"--config",
|
|
252
|
+
type=str,
|
|
253
|
+
default=None,
|
|
254
|
+
help="Path to YAML config file. CLI args override config values.",
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# Hyperparameters
|
|
258
|
+
parser.add_argument(
|
|
259
|
+
"--batch_size", type=int, default=128, help="Batch size per GPU"
|
|
260
|
+
)
|
|
261
|
+
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
|
|
262
|
+
parser.add_argument(
|
|
263
|
+
"--epochs", type=int, default=1000, help="Maximum training epochs"
|
|
264
|
+
)
|
|
265
|
+
parser.add_argument(
|
|
266
|
+
"--patience", type=int, default=20, help="Early stopping patience"
|
|
267
|
+
)
|
|
268
|
+
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
|
269
|
+
parser.add_argument(
|
|
270
|
+
"--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
# Loss Function
|
|
274
|
+
parser.add_argument(
|
|
275
|
+
"--loss",
|
|
276
|
+
type=str,
|
|
277
|
+
default="mse",
|
|
278
|
+
choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
|
|
279
|
+
help=f"Loss function for training. Available: {list_losses()}",
|
|
280
|
+
)
|
|
281
|
+
parser.add_argument(
|
|
282
|
+
"--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
|
|
283
|
+
)
|
|
284
|
+
parser.add_argument(
|
|
285
|
+
"--loss_weights",
|
|
286
|
+
type=str,
|
|
287
|
+
default=None,
|
|
288
|
+
help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
# Optimizer
|
|
292
|
+
parser.add_argument(
|
|
293
|
+
"--optimizer",
|
|
294
|
+
type=str,
|
|
295
|
+
default="adamw",
|
|
296
|
+
choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
|
|
297
|
+
help=f"Optimizer for training. Available: {list_optimizers()}",
|
|
298
|
+
)
|
|
299
|
+
parser.add_argument(
|
|
300
|
+
"--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
|
|
301
|
+
)
|
|
302
|
+
parser.add_argument(
|
|
303
|
+
"--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
|
|
304
|
+
)
|
|
305
|
+
parser.add_argument(
|
|
306
|
+
"--betas",
|
|
307
|
+
type=str,
|
|
308
|
+
default="0.9,0.999",
|
|
309
|
+
help="Betas for Adam variants (comma-separated)",
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
# Learning Rate Scheduler
|
|
313
|
+
parser.add_argument(
|
|
314
|
+
"--scheduler",
|
|
315
|
+
type=str,
|
|
316
|
+
default="plateau",
|
|
317
|
+
choices=[
|
|
318
|
+
"plateau",
|
|
319
|
+
"cosine",
|
|
320
|
+
"cosine_restarts",
|
|
321
|
+
"onecycle",
|
|
322
|
+
"step",
|
|
323
|
+
"multistep",
|
|
324
|
+
"exponential",
|
|
325
|
+
"linear_warmup",
|
|
326
|
+
],
|
|
327
|
+
help=f"LR scheduler. Available: {list_schedulers()}",
|
|
328
|
+
)
|
|
329
|
+
parser.add_argument(
|
|
330
|
+
"--scheduler_patience",
|
|
331
|
+
type=int,
|
|
332
|
+
default=10,
|
|
333
|
+
help="Patience for ReduceLROnPlateau",
|
|
334
|
+
)
|
|
335
|
+
parser.add_argument(
|
|
336
|
+
"--min_lr", type=float, default=1e-6, help="Minimum learning rate"
|
|
337
|
+
)
|
|
338
|
+
parser.add_argument(
|
|
339
|
+
"--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
|
|
340
|
+
)
|
|
341
|
+
parser.add_argument(
|
|
342
|
+
"--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
|
|
343
|
+
)
|
|
344
|
+
parser.add_argument(
|
|
345
|
+
"--step_size", type=int, default=30, help="Step size for StepLR"
|
|
346
|
+
)
|
|
347
|
+
parser.add_argument(
|
|
348
|
+
"--milestones",
|
|
349
|
+
type=str,
|
|
350
|
+
default=None,
|
|
351
|
+
help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Data
|
|
355
|
+
parser.add_argument(
|
|
356
|
+
"--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
|
|
357
|
+
)
|
|
358
|
+
parser.add_argument(
|
|
359
|
+
"--workers",
|
|
360
|
+
type=int,
|
|
361
|
+
default=-1,
|
|
362
|
+
help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
|
|
363
|
+
)
|
|
364
|
+
parser.add_argument("--seed", type=int, default=2025, help="Random seed")
|
|
365
|
+
parser.add_argument(
|
|
366
|
+
"--deterministic",
|
|
367
|
+
action="store_true",
|
|
368
|
+
help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
|
|
369
|
+
)
|
|
370
|
+
parser.add_argument(
|
|
371
|
+
"--cache_validate",
|
|
372
|
+
type=str,
|
|
373
|
+
default="sha256",
|
|
374
|
+
choices=["sha256", "fast", "size"],
|
|
375
|
+
help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
|
|
376
|
+
)
|
|
377
|
+
parser.add_argument(
|
|
378
|
+
"--single_channel",
|
|
379
|
+
action="store_true",
|
|
380
|
+
help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
# Cross-Validation
|
|
384
|
+
parser.add_argument(
|
|
385
|
+
"--cv",
|
|
386
|
+
type=int,
|
|
387
|
+
default=0,
|
|
388
|
+
help="Enable K-fold cross-validation with K folds (0=disabled)",
|
|
389
|
+
)
|
|
390
|
+
parser.add_argument(
|
|
391
|
+
"--cv_stratify",
|
|
392
|
+
action="store_true",
|
|
393
|
+
help="Use stratified splitting for cross-validation",
|
|
394
|
+
)
|
|
395
|
+
parser.add_argument(
|
|
396
|
+
"--cv_bins",
|
|
397
|
+
type=int,
|
|
398
|
+
default=10,
|
|
399
|
+
help="Number of bins for stratified CV (only with --cv_stratify)",
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Checkpointing & Resume
|
|
403
|
+
parser.add_argument(
|
|
404
|
+
"--resume", type=str, default=None, help="Checkpoint directory to resume from"
|
|
405
|
+
)
|
|
406
|
+
parser.add_argument(
|
|
407
|
+
"--save_every",
|
|
408
|
+
type=int,
|
|
409
|
+
default=50,
|
|
410
|
+
help="Save checkpoint every N epochs (0=disable)",
|
|
411
|
+
)
|
|
412
|
+
parser.add_argument(
|
|
413
|
+
"--output_dir", type=str, default=".", help="Output directory for checkpoints"
|
|
414
|
+
)
|
|
415
|
+
parser.add_argument(
|
|
416
|
+
"--fresh",
|
|
417
|
+
action="store_true",
|
|
418
|
+
help="Force fresh training, ignore existing checkpoints",
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
# Performance
|
|
422
|
+
parser.add_argument(
|
|
423
|
+
"--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
|
|
424
|
+
)
|
|
425
|
+
parser.add_argument(
|
|
426
|
+
"--precision",
|
|
427
|
+
type=str,
|
|
428
|
+
default="bf16",
|
|
429
|
+
choices=["bf16", "fp16", "no"],
|
|
430
|
+
help="Mixed precision mode",
|
|
431
|
+
)
|
|
432
|
+
# Alias for consistency with wavedl-hpc (--mixed_precision)
|
|
433
|
+
parser.add_argument(
|
|
434
|
+
"--mixed_precision",
|
|
435
|
+
dest="precision",
|
|
436
|
+
type=str,
|
|
437
|
+
choices=["bf16", "fp16", "no"],
|
|
438
|
+
help=argparse.SUPPRESS, # Hidden: use --precision instead
|
|
439
|
+
)
|
|
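The hidden --mixed_precision flag defined above is only an alias that writes into the same args.precision attribute, so a wavedl-hpc-style command line keeps working when passed straight to the training module. A tiny self-contained demonstration of that argparse pattern (the demo parser is ours, not part of wavedl):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", type=str, default="bf16",
                    choices=["bf16", "fp16", "no"])
parser.add_argument("--mixed_precision", dest="precision", type=str,
                    choices=["bf16", "fp16", "no"],
                    help=argparse.SUPPRESS)   # hidden alias, same destination

print(parser.parse_args(["--mixed_precision", "fp16"]).precision)  # -> "fp16"
print(parser.parse_args(["--precision", "no"]).precision)          # -> "no"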
440
|
+
|
|
441
|
+
# Physical Constraints
|
|
442
|
+
parser.add_argument(
|
|
443
|
+
"--constraint",
|
|
444
|
+
type=str,
|
|
445
|
+
nargs="+",
|
|
446
|
+
default=[],
|
|
447
|
+
help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
|
|
448
|
+
)
|
|
449
|
+
|
|
450
|
+
parser.add_argument(
|
|
451
|
+
"--constraint_file",
|
|
452
|
+
type=str,
|
|
453
|
+
default=None,
|
|
454
|
+
help="Python file with constraint(pred, inputs) function",
|
|
455
|
+
)
|
|
456
|
+
parser.add_argument(
|
|
457
|
+
"--constraint_weight",
|
|
458
|
+
type=float,
|
|
459
|
+
nargs="+",
|
|
460
|
+
default=[0.1],
|
|
461
|
+
help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
|
|
462
|
+
)
|
|
463
|
+
parser.add_argument(
|
|
464
|
+
"--constraint_reduction",
|
|
465
|
+
type=str,
|
|
466
|
+
default="mse",
|
|
467
|
+
choices=["mse", "mae"],
|
|
468
|
+
help="Reduction mode for constraint penalties",
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
# Logging
|
|
472
|
+
parser.add_argument(
|
|
473
|
+
"--wandb", action="store_true", help="Enable Weights & Biases logging"
|
|
474
|
+
)
|
|
475
|
+
parser.add_argument(
|
|
476
|
+
"--wandb_watch",
|
|
477
|
+
action="store_true",
|
|
478
|
+
help="Enable WandB gradient watching (adds overhead, useful for debugging)",
|
|
479
|
+
)
|
|
480
|
+
parser.add_argument(
|
|
481
|
+
"--project_name", type=str, default="DL-Training", help="WandB project name"
|
|
482
|
+
)
|
|
483
|
+
parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
|
|
484
|
+
|
|
485
|
+
args = parser.parse_args()
|
|
486
|
+
return args, parser # Returns (Namespace, ArgumentParser)
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
# ==============================================================================
|
|
490
|
+
# MAIN TRAINING FUNCTION
|
|
491
|
+
# ==============================================================================
|
|
492
|
+
def main():
|
|
493
|
+
args, parser = parse_args()
|
|
494
|
+
|
|
495
|
+
# Import custom model modules if specified
|
|
496
|
+
if args.import_modules:
|
|
497
|
+
import importlib
|
|
498
|
+
|
|
499
|
+
for module_name in args.import_modules:
|
|
500
|
+
try:
|
|
501
|
+
# Handle both module names (my_model) and file paths (./my_model.py)
|
|
502
|
+
if module_name.endswith(".py"):
|
|
503
|
+
# Import from file path with unique module name
|
|
504
|
+
import importlib.util
|
|
505
|
+
|
|
506
|
+
# Derive unique module name from filename to avoid collisions
|
|
507
|
+
base_name = os.path.splitext(os.path.basename(module_name))[0]
|
|
508
|
+
unique_name = f"wavedl_custom_{base_name}"
|
|
509
|
+
|
|
510
|
+
spec = importlib.util.spec_from_file_location(
|
|
511
|
+
unique_name, module_name
|
|
512
|
+
)
|
|
513
|
+
if spec and spec.loader:
|
|
514
|
+
module = importlib.util.module_from_spec(spec)
|
|
515
|
+
sys.modules[unique_name] = module
|
|
516
|
+
spec.loader.exec_module(module)
|
|
517
|
+
print(f"✓ Imported custom module from: {module_name}")
|
|
518
|
+
else:
|
|
519
|
+
# Import as regular module
|
|
520
|
+
importlib.import_module(module_name)
|
|
521
|
+
print(f"✓ Imported module: {module_name}")
|
|
522
|
+
except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
|
|
523
|
+
print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
|
|
524
|
+
if isinstance(e, FileNotFoundError):
|
|
525
|
+
print(" File does not exist. Check the path.", file=sys.stderr)
|
|
526
|
+
elif isinstance(e, SyntaxError):
|
|
527
|
+
print(
|
|
528
|
+
f" Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
|
|
529
|
+
)
|
|
530
|
+
elif isinstance(e, PermissionError):
|
|
531
|
+
print(
|
|
532
|
+
" Permission denied. Check file permissions.", file=sys.stderr
|
|
533
|
+
)
|
|
534
|
+
else:
|
|
535
|
+
print(
|
|
536
|
+
" Make sure the module is in your Python path or current directory.",
|
|
537
|
+
file=sys.stderr,
|
|
538
|
+
)
|
|
539
|
+
sys.exit(1)
|
|
540
|
+
|
|
541
|
+
# Handle --list_models flag
|
|
542
|
+
if args.list_models:
|
|
543
|
+
print("Available models:")
|
|
544
|
+
for name in list_models():
|
|
545
|
+
ModelClass = get_model(name)
|
|
546
|
+
# Get first non-empty docstring line
|
|
547
|
+
if ModelClass.__doc__:
|
|
548
|
+
lines = [
|
|
549
|
+
l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
|
|
550
|
+
]
|
|
551
|
+
doc_first_line = lines[0] if lines else "No description"
|
|
552
|
+
else:
|
|
553
|
+
doc_first_line = "No description"
|
|
554
|
+
print(f" - {name}: {doc_first_line}")
|
|
555
|
+
sys.exit(0)
|
|
556
|
+
|
|
557
|
+
# Load and merge config file if provided
|
|
558
|
+
if args.config:
|
|
559
|
+
from wavedl.utils.config import (
|
|
560
|
+
load_config,
|
|
561
|
+
merge_config_with_args,
|
|
562
|
+
validate_config,
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
print(f"📄 Loading config from: {args.config}")
|
|
566
|
+
config = load_config(args.config)
|
|
567
|
+
|
|
568
|
+
# Validate config values
|
|
569
|
+
warnings_list = validate_config(config)
|
|
570
|
+
for w in warnings_list:
|
|
571
|
+
print(f" ⚠ {w}")
|
|
572
|
+
|
|
573
|
+
# Merge config with CLI args (CLI takes precedence via parser defaults detection)
|
|
574
|
+
args = merge_config_with_args(config, args, parser=parser)
|
|
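merge_config_with_args lives in wavedl.utils.config and is not shown in this diff; the comment above says it gives CLI flags precedence by detecting which args still sit at their parser defaults. The helper below is a hedged illustration of that precedence rule only, not wavedl's actual implementation, and its name is ours:

import argparse

def merge_yaml_into_args(config: dict, args: argparse.Namespace,
                         parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Fill args from a config dict, but never overwrite a value the user typed on the CLI."""
    for key, value in config.items():
        if not hasattr(args, key):
            continue                        # ignore keys the parser does not know
        if getattr(args, key) == parser.get_default(key):
            setattr(args, key, value)       # still at default -> config may override
    return args

# usage sketch:
#   config = {"batch_size": 256, "lr": 3e-4}
#   args = merge_yaml_into_args(config, args, parser)   # an explicit --lr on the CLI wins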
575
|
+
|
|
576
|
+
# Handle --cv flag (cross-validation mode)
|
|
577
|
+
if args.cv > 0:
|
|
578
|
+
print(f"🔄 Cross-Validation Mode: {args.cv} folds")
|
|
579
|
+
from wavedl.utils.cross_validation import run_cross_validation
|
|
580
|
+
|
|
581
|
+
# Load data for CV using memory-efficient loader
|
|
582
|
+
from wavedl.utils.data import DataSource, get_data_source
|
|
583
|
+
|
|
584
|
+
data_format = DataSource.detect_format(args.data_path)
|
|
585
|
+
source = get_data_source(data_format)
|
|
586
|
+
|
|
587
|
+
# Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
|
|
588
|
+
_cv_handle = None
|
|
589
|
+
if hasattr(source, "load_mmap"):
|
|
590
|
+
_cv_handle = source.load_mmap(args.data_path)
|
|
591
|
+
X, y = _cv_handle.inputs, _cv_handle.outputs
|
|
592
|
+
else:
|
|
593
|
+
X, y = source.load(args.data_path)
|
|
594
|
+
|
|
595
|
+
# Handle sparse matrices (must materialize for CV shuffling)
|
|
596
|
+
if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
|
|
597
|
+
X = np.stack([x.toarray() for x in X])
|
|
598
|
+
|
|
599
|
+
# Normalize target shape: (N,) -> (N, 1) for consistency
|
|
600
|
+
if y.ndim == 1:
|
|
601
|
+
y = y.reshape(-1, 1)
|
|
602
|
+
|
|
603
|
+
# Validate and determine input shape (consistent with prepare_data)
|
|
604
|
+
# Check for ambiguous shapes that could be multi-channel or shallow 3D volume
|
|
605
|
+
sample_shape = X.shape[1:] # Per-sample shape
|
|
606
|
+
|
|
607
|
+
# Same heuristic as prepare_data: detect ambiguous 3D shapes
|
|
608
|
+
is_ambiguous_shape = (
|
|
609
|
+
len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
|
|
610
|
+
and sample_shape[0] <= 16 # First dim looks like channels
|
|
611
|
+
and sample_shape[1] > 16
|
|
612
|
+
and sample_shape[2] > 16 # Both spatial dims are large
|
|
613
|
+
)
|
|
614
|
+
|
|
615
|
+
if is_ambiguous_shape and not args.single_channel:
|
|
616
|
+
raise ValueError(
|
|
617
|
+
f"Ambiguous input shape detected: sample shape {sample_shape}. "
|
|
618
|
+
f"This could be either:\n"
|
|
619
|
+
f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
|
|
620
|
+
f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
|
|
621
|
+
f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
|
|
622
|
+
f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
|
|
623
|
+
)
|
|
624
|
+
|
|
625
|
+
# in_shape = spatial dimensions for model registry (channel added during training)
|
|
626
|
+
in_shape = sample_shape
|
|
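The guard above flags per-sample shapes that could equally be channels-first 2D data or a shallow single-channel 3D volume. Pulled out as a function with a few concrete shapes, the same rule behaves like this (a sketch of the heuristic, not an import from wavedl):

def is_ambiguous_3d_shape(sample_shape: tuple) -> bool:
    """True when a 3D per-sample shape could be either (C, H, W) or (D, H, W)."""
    return (
        len(sample_shape) == 3
        and sample_shape[0] <= 16     # small leading dim looks like channels
        and sample_shape[1] > 16
        and sample_shape[2] > 16      # both trailing dims look spatial
    )

print(is_ambiguous_3d_shape((3, 224, 224)))    # True  -> needs --single_channel or a reshape
print(is_ambiguous_3d_shape((32, 128, 128)))   # False -> treated as a 3D volume
print(is_ambiguous_3d_shape((128, 128)))       # False -> plain 2D samples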
627
|
+
|
|
628
|
+
# Run cross-validation
|
|
629
|
+
try:
|
|
630
|
+
run_cross_validation(
|
|
631
|
+
X=X,
|
|
632
|
+
y=y,
|
|
633
|
+
model_name=args.model,
|
|
634
|
+
in_shape=in_shape,
|
|
635
|
+
out_size=y.shape[1],
|
|
636
|
+
folds=args.cv,
|
|
637
|
+
stratify=args.cv_stratify,
|
|
638
|
+
stratify_bins=args.cv_bins,
|
|
639
|
+
batch_size=args.batch_size,
|
|
640
|
+
lr=args.lr,
|
|
641
|
+
epochs=args.epochs,
|
|
642
|
+
patience=args.patience,
|
|
643
|
+
weight_decay=args.weight_decay,
|
|
644
|
+
loss_name=args.loss,
|
|
645
|
+
optimizer_name=args.optimizer,
|
|
646
|
+
scheduler_name=args.scheduler,
|
|
647
|
+
output_dir=args.output_dir,
|
|
648
|
+
workers=args.workers,
|
|
649
|
+
seed=args.seed,
|
|
650
|
+
)
|
|
651
|
+
finally:
|
|
652
|
+
# Clean up file handle if HDF5/MAT
|
|
653
|
+
if _cv_handle is not None and hasattr(_cv_handle, "close"):
|
|
654
|
+
try:
|
|
655
|
+
_cv_handle.close()
|
|
656
|
+
except Exception as e:
|
|
657
|
+
logging.debug(f"Failed to close CV data handle: {e}")
|
|
658
|
+
return
|
|
659
|
+
|
|
660
|
+
# ==========================================================================
|
|
661
|
+
# SYSTEM INITIALIZATION
|
|
662
|
+
# ==========================================================================
|
|
663
|
+
# Initialize Accelerator for DDP and mixed precision
|
|
664
|
+
accelerator = Accelerator(
|
|
665
|
+
mixed_precision=args.precision,
|
|
666
|
+
log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
|
|
667
|
+
)
|
|
668
|
+
set_seed(args.seed)
|
|
669
|
+
|
|
670
|
+
# Deterministic mode for scientific reproducibility
|
|
671
|
+
# Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
|
|
672
|
+
if args.deterministic:
|
|
673
|
+
torch.backends.cudnn.benchmark = False
|
|
674
|
+
torch.backends.cudnn.deterministic = True
|
|
675
|
+
torch.backends.cuda.matmul.allow_tf32 = False
|
|
676
|
+
torch.backends.cudnn.allow_tf32 = False
|
|
677
|
+
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
678
|
+
if accelerator.is_main_process:
|
|
679
|
+
print("🔒 Deterministic mode enabled (slower but reproducible)")
|
|
680
|
+
|
|
681
|
+
# Configure logging (rank 0 only prints to console)
|
|
682
|
+
logging.basicConfig(
|
|
683
|
+
level=logging.INFO if accelerator.is_main_process else logging.ERROR,
|
|
684
|
+
format="%(asctime)s | %(levelname)s | %(message)s",
|
|
685
|
+
datefmt="%H:%M:%S",
|
|
686
|
+
)
|
|
687
|
+
logger = logging.getLogger("Trainer")
|
|
688
|
+
|
|
689
|
+
# Ensure output directory exists (critical for cache files, checkpoints, etc.)
|
|
690
|
+
os.makedirs(args.output_dir, exist_ok=True)
|
|
691
|
+
|
|
692
|
+
# Auto-detect optimal DataLoader workers if not specified
|
|
693
|
+
if args.workers < 0:
|
|
694
|
+
cpu_count = os.cpu_count() or 4
|
|
695
|
+
num_gpus = accelerator.num_processes
|
|
696
|
+
# Heuristic: 4-16 workers per GPU, bounded by available CPU cores
|
|
697
|
+
# Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
|
|
698
|
+
args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
|
|
699
|
+
if accelerator.is_main_process:
|
|
700
|
+
logger.info(
|
|
701
|
+
f"⚙️ Auto-detected workers: {args.workers} per GPU "
|
|
702
|
+
f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
|
|
703
|
+
)
|
|
704
|
+
|
|
705
|
+
if accelerator.is_main_process:
|
|
706
|
+
logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
|
|
707
|
+
logger.info(
|
|
708
|
+
f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
|
|
709
|
+
)
|
|
710
|
+
logger.info(
|
|
711
|
+
f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
|
|
712
|
+
)
|
|
713
|
+
logger.info(f" Early Stopping Patience: {args.patience} epochs")
|
|
714
|
+
if args.save_every > 0:
|
|
715
|
+
logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
|
|
716
|
+
if args.resume:
|
|
717
|
+
logger.info(f" 📂 Resuming from: {args.resume}")
|
|
718
|
+
|
|
719
|
+
# Initialize WandB
|
|
720
|
+
if args.wandb and WANDB_AVAILABLE:
|
|
721
|
+
accelerator.init_trackers(
|
|
722
|
+
project_name=args.project_name,
|
|
723
|
+
config=vars(args),
|
|
724
|
+
init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
|
|
725
|
+
)
|
|
726
|
+
|
|
727
|
+
# ==========================================================================
|
|
728
|
+
# DATA & MODEL LOADING
|
|
729
|
+
# ==========================================================================
|
|
730
|
+
train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
|
|
731
|
+
args, logger, accelerator, cache_dir=args.output_dir
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
# Build model using registry
|
|
735
|
+
model = build_model(
|
|
736
|
+
args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
|
|
737
|
+
)
|
|
738
|
+
|
|
739
|
+
if accelerator.is_main_process:
|
|
740
|
+
param_info = model.parameter_summary()
|
|
741
|
+
logger.info(
|
|
742
|
+
f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
|
|
743
|
+
)
|
|
744
|
+
logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
|
|
745
|
+
|
|
746
|
+
# Optional WandB model watching (opt-in due to overhead on large models)
|
|
747
|
+
if (
|
|
748
|
+
args.wandb
|
|
749
|
+
and args.wandb_watch
|
|
750
|
+
and WANDB_AVAILABLE
|
|
751
|
+
and accelerator.is_main_process
|
|
752
|
+
):
|
|
753
|
+
wandb.watch(model, log="gradients", log_freq=100)
|
|
754
|
+
logger.info(" 📊 WandB gradient watching enabled")
|
|
755
|
+
|
|
756
|
+
# Torch 2.0 compilation (requires compatible Triton on GPU)
|
|
757
|
+
if args.compile:
|
|
758
|
+
try:
|
|
759
|
+
# Test if Triton is available - just import the package
|
|
760
|
+
# Different Triton versions have different internal APIs, so just check base import
|
|
761
|
+
import triton
|
|
762
|
+
|
|
763
|
+
model = torch.compile(model)
|
|
764
|
+
if accelerator.is_main_process:
|
|
765
|
+
logger.info(" ✔ torch.compile enabled (Triton backend)")
|
|
766
|
+
except ImportError as e:
|
|
767
|
+
if accelerator.is_main_process:
|
|
768
|
+
if "triton" in str(e).lower():
|
|
769
|
+
logger.warning(
|
|
770
|
+
" ⚠ Triton not installed or incompatible version - torch.compile disabled. "
|
|
771
|
+
"Training will proceed without compilation."
|
|
772
|
+
)
|
|
773
|
+
else:
|
|
774
|
+
logger.warning(
|
|
775
|
+
f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
|
|
776
|
+
)
|
|
777
|
+
except Exception as e:
|
|
778
|
+
if accelerator.is_main_process:
|
|
779
|
+
logger.warning(
|
|
780
|
+
f" ⚠ torch.compile failed: {e}. Continuing without compilation."
|
|
781
|
+
)
|
|
782
|
+
|
|
783
|
+
# ==========================================================================
|
|
784
|
+
# OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
|
|
785
|
+
# ==========================================================================
|
|
786
|
+
# Parse comma-separated arguments with validation
|
|
787
|
+
try:
|
|
788
|
+
betas_list = [float(x.strip()) for x in args.betas.split(",")]
|
|
789
|
+
if len(betas_list) != 2:
|
|
790
|
+
raise ValueError(
|
|
791
|
+
f"--betas must have exactly 2 values, got {len(betas_list)}"
|
|
792
|
+
)
|
|
793
|
+
if not all(0.0 <= b < 1.0 for b in betas_list):
|
|
794
|
+
raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
|
|
795
|
+
betas = tuple(betas_list)
|
|
796
|
+
except ValueError as e:
|
|
797
|
+
raise ValueError(
|
|
798
|
+
f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
|
|
799
|
+
)
|
|
800
|
+
|
|
801
|
+
loss_weights = None
|
|
802
|
+
if args.loss_weights:
|
|
803
|
+
loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
|
|
804
|
+
milestones = None
|
|
805
|
+
if args.milestones:
|
|
806
|
+
milestones = [int(x.strip()) for x in args.milestones.split(",")]
|
|
807
|
+
|
|
808
|
+
# Create optimizer using factory
|
|
809
|
+
optimizer = get_optimizer(
|
|
810
|
+
name=args.optimizer,
|
|
811
|
+
params=model.get_optimizer_groups(args.lr, args.weight_decay),
|
|
812
|
+
lr=args.lr,
|
|
813
|
+
weight_decay=args.weight_decay,
|
|
814
|
+
momentum=args.momentum,
|
|
815
|
+
nesterov=args.nesterov,
|
|
816
|
+
betas=betas,
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
# Create loss function using factory
|
|
820
|
+
criterion = get_loss(
|
|
821
|
+
name=args.loss,
|
|
822
|
+
weights=loss_weights,
|
|
823
|
+
delta=args.huber_delta,
|
|
824
|
+
)
|
|
825
|
+
# Move criterion to device (important for WeightedMSELoss buffer)
|
|
826
|
+
criterion = criterion.to(accelerator.device)
|
|
827
|
+
|
|
828
|
+
# ==========================================================================
|
|
829
|
+
# PHYSICAL CONSTRAINTS INTEGRATION
|
|
830
|
+
# ==========================================================================
|
|
831
|
+
from wavedl.utils.constraints import (
|
|
832
|
+
PhysicsConstrainedLoss,
|
|
833
|
+
build_constraints,
|
|
834
|
+
)
|
|
835
|
+
|
|
836
|
+
# Build soft constraints
|
|
837
|
+
soft_constraints = build_constraints(
|
|
838
|
+
expressions=args.constraint,
|
|
839
|
+
file_path=args.constraint_file,
|
|
840
|
+
reduction=args.constraint_reduction,
|
|
841
|
+
)
|
|
842
|
+
|
|
843
|
+
# Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
|
|
844
|
+
if soft_constraints:
|
|
845
|
+
# Pass output scaler so constraints can be evaluated in physical space
|
|
846
|
+
output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
|
|
847
|
+
output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
|
|
848
|
+
criterion = PhysicsConstrainedLoss(
|
|
849
|
+
criterion,
|
|
850
|
+
soft_constraints,
|
|
851
|
+
weights=args.constraint_weight,
|
|
852
|
+
output_mean=output_mean,
|
|
853
|
+
output_std=output_std,
|
|
854
|
+
)
|
|
855
|
+
if accelerator.is_main_process:
|
|
856
|
+
logger.info(
|
|
857
|
+
f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
|
|
858
|
+
f"with weight(s) {args.constraint_weight}"
|
|
859
|
+
)
|
|
860
|
+
if output_mean is not None:
|
|
861
|
+
logger.info(
|
|
862
|
+
" 📐 Constraints evaluated in physical space (denormalized)"
|
|
863
|
+
)
|
|
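build_constraints and PhysicsConstrainedLoss come from wavedl.utils.constraints and their bodies are not part of this diff; the idea is that an expression such as 'y0 - y1*y2' is treated as a residual that should be zero, and its squared (or absolute) value is added to the data loss with a weight. A hand-rolled version of that idea for a single constraint, assuming predictions are already in physical units; this is purely illustrative and not the wavedl API:

import torch

def soft_constraint_loss(base_loss: torch.Tensor, pred: torch.Tensor,
                         weight: float = 0.1, reduction: str = "mse") -> torch.Tensor:
    """Penalize violations of the residual y0 - y1*y2 on top of the data loss."""
    residual = pred[:, 0] - pred[:, 1] * pred[:, 2]     # 'y0 - y1*y2'
    if reduction == "mse":
        penalty = residual.pow(2).mean()
    else:                                               # "mae"
        penalty = residual.abs().mean()
    return base_loss + weight * penalty

# usage sketch inside a training step:
#   loss = criterion(pred, y)
#   loss = soft_constraint_loss(loss, pred_physical, weight=0.1)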
864
|
+
|
|
865
|
+
# Track if scheduler should step per batch (OneCycleLR) or per epoch
|
|
866
|
+
scheduler_step_per_batch = not is_epoch_based(args.scheduler)
|
|
867
|
+
|
|
868
|
+
# ==========================================================================
|
|
869
|
+
# DDP Preparation Strategy:
|
|
870
|
+
# - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
|
|
871
|
+
# the correct sharded batch count, then create scheduler
|
|
872
|
+
# - For epoch-based schedulers: create scheduler before prepare (no issue)
|
|
873
|
+
# ==========================================================================
|
|
874
|
+
if scheduler_step_per_batch:
|
|
875
|
+
# BATCH-BASED SCHEDULER (e.g., OneCycleLR)
|
|
876
|
+
# Prepare model, optimizer, dataloaders FIRST to get sharded loader length
|
|
877
|
+
model, optimizer, train_dl, val_dl = accelerator.prepare(
|
|
878
|
+
model, optimizer, train_dl, val_dl
|
|
879
|
+
)
|
|
880
|
+
|
|
881
|
+
# Now create scheduler with the CORRECT sharded steps_per_epoch
|
|
882
|
+
steps_per_epoch = len(train_dl) # Post-DDP sharded length
|
|
883
|
+
scheduler = get_scheduler(
|
|
884
|
+
name=args.scheduler,
|
|
885
|
+
optimizer=optimizer,
|
|
886
|
+
epochs=args.epochs,
|
|
887
|
+
steps_per_epoch=steps_per_epoch,
|
|
888
|
+
min_lr=args.min_lr,
|
|
889
|
+
patience=args.scheduler_patience,
|
|
890
|
+
factor=args.scheduler_factor,
|
|
891
|
+
gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
|
|
892
|
+
step_size=args.step_size,
|
|
893
|
+
milestones=milestones,
|
|
894
|
+
warmup_epochs=args.warmup_epochs,
|
|
895
|
+
)
|
|
896
|
+
# Prepare scheduler separately (Accelerator wraps it for state saving)
|
|
897
|
+
scheduler = accelerator.prepare(scheduler)
|
|
898
|
+
else:
|
|
899
|
+
# EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
|
|
900
|
+
# No batch count dependency - create scheduler before prepare
|
|
901
|
+
scheduler = get_scheduler(
|
|
902
|
+
name=args.scheduler,
|
|
903
|
+
optimizer=optimizer,
|
|
904
|
+
epochs=args.epochs,
|
|
905
|
+
steps_per_epoch=None,
|
|
906
|
+
min_lr=args.min_lr,
|
|
907
|
+
patience=args.scheduler_patience,
|
|
908
|
+
factor=args.scheduler_factor,
|
|
909
|
+
gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
|
|
910
|
+
step_size=args.step_size,
|
|
911
|
+
milestones=milestones,
|
|
912
|
+
warmup_epochs=args.warmup_epochs,
|
|
913
|
+
)
|
|
914
|
+
|
|
915
|
+
# For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
|
|
916
|
+
# because accelerator wraps scheduler.step() to sync across processes,
|
|
917
|
+
# which defeats our rank-0-only stepping for correct patience counting.
|
|
918
|
+
# Other schedulers are safe to prepare (no internal state affected by multi-call).
|
|
919
|
+
if args.scheduler == "plateau":
|
|
920
|
+
model, optimizer, train_dl, val_dl = accelerator.prepare(
|
|
921
|
+
model, optimizer, train_dl, val_dl
|
|
922
|
+
)
|
|
923
|
+
# Scheduler stays unwrapped - we handle sync manually in training loop
|
|
924
|
+
# But register it for checkpointing so state is saved/loaded on resume
|
|
925
|
+
accelerator.register_for_checkpointing(scheduler)
|
|
926
|
+
else:
|
|
927
|
+
model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
|
|
928
|
+
model, optimizer, train_dl, val_dl, scheduler
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
# ==========================================================================
|
|
932
|
+
# AUTO-RESUME / RESUME FROM CHECKPOINT
|
|
933
|
+
# ==========================================================================
|
|
934
|
+
start_epoch = 0
|
|
935
|
+
best_val_loss = float("inf")
|
|
936
|
+
patience_ctr = 0
|
|
937
|
+
history: list[dict[str, Any]] = []
|
|
938
|
+
|
|
939
|
+
# Define checkpoint paths
|
|
940
|
+
best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
|
|
941
|
+
complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
|
|
942
|
+
|
|
943
|
+
# Auto-resume logic (if not --fresh and no explicit --resume)
|
|
944
|
+
if not args.fresh and args.resume is None:
|
|
945
|
+
if os.path.exists(complete_flag_path):
|
|
946
|
+
# Training already completed
|
|
947
|
+
if accelerator.is_main_process:
|
|
948
|
+
logger.info(
|
|
949
|
+
"✅ Training already completed (early stopping). Use --fresh to retrain."
|
|
950
|
+
)
|
|
951
|
+
return # Exit gracefully
|
|
952
|
+
elif os.path.exists(best_ckpt_path):
|
|
953
|
+
# Incomplete training found - auto-resume
|
|
954
|
+
args.resume = best_ckpt_path
|
|
955
|
+
if accelerator.is_main_process:
|
|
956
|
+
logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
|
|
957
|
+
|
|
958
|
+
if args.resume:
|
|
959
|
+
if os.path.exists(args.resume):
|
|
960
|
+
logger.info(f"🔄 Loading checkpoint from: {args.resume}")
|
|
961
|
+
accelerator.load_state(args.resume)
|
|
962
|
+
|
|
963
|
+
# Restore training metadata
|
|
964
|
+
meta_path = os.path.join(args.resume, "training_meta.pkl")
|
|
965
|
+
if os.path.exists(meta_path):
|
|
966
|
+
with open(meta_path, "rb") as f:
|
|
967
|
+
meta = pickle.load(f)
|
|
968
|
+
start_epoch = meta.get("epoch", 0)
|
|
969
|
+
best_val_loss = meta.get("best_val_loss", float("inf"))
|
|
970
|
+
patience_ctr = meta.get("patience_ctr", 0)
|
|
971
|
+
logger.info(
|
|
972
|
+
f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
|
|
973
|
+
)
|
|
974
|
+
else:
|
|
975
|
+
logger.warning(
|
|
976
|
+
" ⚠️ training_meta.pkl not found, starting from epoch 0"
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
# Restore history
|
|
980
|
+
history_path = os.path.join(args.output_dir, "training_history.csv")
|
|
981
|
+
if os.path.exists(history_path):
|
|
982
|
+
history = pd.read_csv(history_path).to_dict("records")
|
|
983
|
+
logger.info(f" ✅ Loaded {len(history)} epochs from history")
|
|
984
|
+
else:
|
|
985
|
+
raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
|
|
986
|
+
|
|
987
|
+
# ==========================================================================
|
|
988
|
+
# PHYSICAL METRIC SETUP
|
|
989
|
+
# ==========================================================================
|
|
990
|
+
# Physical MAE = normalized MAE * scaler.scale_
|
|
991
|
+
phys_scale = torch.tensor(
|
|
992
|
+
scaler.scale_, device=accelerator.device, dtype=torch.float32
|
|
993
|
+
)
|
|
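The phys_scale vector above converts an MAE computed on standardized targets back into physical units: for a StandardScaler, |pred − y| in normalized space times scale_ equals the absolute error after inverse_transform. A small numeric check of that identity with sklearn and synthetic targets (the data here is made up for the check):

import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
targets = rng.normal(loc=5.0, scale=2.0, size=(256, 3))        # fake physical targets

scaler = StandardScaler().fit(targets)
y_norm = scaler.transform(targets)
pred_norm = y_norm + rng.normal(scale=0.05, size=y_norm.shape)  # fake predictions

mae_physical_direct = np.abs(
    scaler.inverse_transform(pred_norm) - scaler.inverse_transform(y_norm)
).mean(axis=0)
mae_physical_rescaled = (np.abs(pred_norm - y_norm) * scaler.scale_).mean(axis=0)

assert np.allclose(mae_physical_direct, mae_physical_rescaled)  # same per-parameter MAE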
994
|
+
|
|
995
|
+
# ==========================================================================
|
|
996
|
+
# TRAINING LOOP
|
|
997
|
+
# ==========================================================================
|
|
998
|
+
# Dynamic console header
|
|
999
|
+
if accelerator.is_main_process:
|
|
1000
|
+
base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
|
|
1001
|
+
param_cols = [f"MAE_P{i}" for i in range(out_dim)]
|
|
1002
|
+
header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
|
|
1003
|
+
*base_cols
|
|
1004
|
+
)
|
|
1005
|
+
header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
|
|
1006
|
+
logger.info("=" * len(header))
|
|
1007
|
+
logger.info(header)
|
|
1008
|
+
logger.info("=" * len(header))
|
|
1009
|
+
|
|
1010
|
+
try:
|
|
1011
|
+
total_training_time = 0.0
|
|
1012
|
+
|
|
1013
|
+
for epoch in range(start_epoch, args.epochs):
|
|
1014
|
+
epoch_start_time = time.time()
|
|
1015
|
+
|
|
1016
|
+
# ==================== TRAINING PHASE ====================
|
|
1017
|
+
model.train()
|
|
1018
|
+
# Use GPU tensor for loss accumulation to avoid .item() sync per batch
|
|
1019
|
+
train_loss_sum = torch.tensor(0.0, device=accelerator.device)
|
|
1020
|
+
train_samples = 0
|
|
1021
|
+
grad_norm_tracker = MetricTracker()
|
|
1022
|
+
|
|
1023
|
+
pbar = tqdm(
|
|
1024
|
+
train_dl,
|
|
1025
|
+
disable=not accelerator.is_main_process,
|
|
1026
|
+
leave=False,
|
|
1027
|
+
desc=f"Epoch {epoch + 1}",
|
|
1028
|
+
)
|
|
1029
|
+
|
|
1030
|
+
for x, y in pbar:
|
|
1031
|
+
with accelerator.accumulate(model):
|
|
1032
|
+
# Use mixed precision for forward pass (respects --precision flag)
|
|
1033
|
+
with accelerator.autocast():
|
|
1034
|
+
pred = model(x)
|
|
1035
|
+
# Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
|
|
1036
|
+
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
1037
|
+
loss = criterion(pred, y, x)
|
|
1038
|
+
else:
|
|
1039
|
+
loss = criterion(pred, y)
|
|
1040
|
+
|
|
1041
|
+
accelerator.backward(loss)
|
|
1042
|
+
|
|
1043
|
+
if accelerator.sync_gradients:
|
|
1044
|
+
grad_norm = accelerator.clip_grad_norm_(
|
|
1045
|
+
model.parameters(), args.grad_clip
|
|
1046
|
+
)
|
|
1047
|
+
if grad_norm is not None:
|
|
1048
|
+
grad_norm_tracker.update(grad_norm.item())
|
|
1049
|
+
|
|
1050
|
+
optimizer.step()
|
|
1051
|
+
optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
|
|
1052
|
+
|
|
1053
|
+
# Per-batch LR scheduling (e.g., OneCycleLR)
|
|
1054
|
+
if scheduler_step_per_batch:
|
|
1055
|
+
scheduler.step()
|
|
1056
|
+
|
|
1057
|
+
# Accumulate as tensors to avoid .item() sync per batch
|
|
1058
|
+
train_loss_sum += loss.detach() * x.size(0)
|
|
1059
|
+
train_samples += x.size(0)
|
|
1060
|
+
|
|
1061
|
+
# Single .item() call at end of epoch (reduces GPU sync overhead)
|
|
1062
|
+
train_loss_scalar = train_loss_sum.item()
|
|
1063
|
+
|
|
1064
|
+
# Synchronize training metrics across GPUs
|
|
1065
|
+
global_loss = accelerator.reduce(
|
|
1066
|
+
torch.tensor([train_loss_scalar], device=accelerator.device),
|
|
1067
|
+
reduction="sum",
|
|
1068
|
+
).item()
|
|
1069
|
+
global_samples = accelerator.reduce(
|
|
1070
|
+
torch.tensor([train_samples], device=accelerator.device),
|
|
1071
|
+
reduction="sum",
|
|
1072
|
+
).item()
|
|
1073
|
+
avg_train_loss = global_loss / global_samples
|
|
1074
|
+
|
|
1075
|
+
# ==================== VALIDATION PHASE ====================
|
|
1076
|
+
model.eval()
|
|
1077
|
+
# Use GPU tensor for loss accumulation (consistent with training phase)
|
|
1078
|
+
val_loss_sum = torch.tensor(0.0, device=accelerator.device)
|
|
1079
|
+
val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
|
|
1080
|
+
val_samples = 0
|
|
1081
|
+
|
|
1082
|
+
# Accumulate predictions locally ON CPU to prevent GPU OOM
|
|
1083
|
+
local_preds = []
|
|
1084
|
+
local_targets = []
|
|
1085
|
+
|
|
1086
|
+
with torch.inference_mode():
|
|
1087
|
+
for x, y in val_dl:
|
|
1088
|
+
# Use mixed precision for validation (consistent with training)
|
|
1089
|
+
with accelerator.autocast():
|
|
1090
|
+
pred = model(x)
|
|
1091
|
+
# Pass inputs for input-dependent constraints
|
|
1092
|
+
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
1093
|
+
loss = criterion(pred, y, x)
|
|
1094
|
+
else:
|
|
1095
|
+
loss = criterion(pred, y)
|
|
1096
|
+
|
|
1097
|
+
val_loss_sum += loss.detach() * x.size(0)
|
|
1098
|
+
val_samples += x.size(0)
|
|
1099
|
+
|
|
1100
|
+
# Physical MAE
|
|
1101
|
+
mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
|
|
1102
|
+
val_mae_sum += mae_batch
|
|
1103
|
+
|
|
1104
|
+
# Store on CPU (critical for large val sets)
|
|
1105
|
+
local_preds.append(pred.detach().cpu())
|
|
1106
|
+
local_targets.append(y.detach().cpu())
|
|
1107
|
+
|
|
1108
|
+
# Concatenate locally (keep on GPU for gather_for_metrics compatibility)
|
|
1109
|
+
local_preds_cat = torch.cat(local_preds)
|
|
1110
|
+
local_targets_cat = torch.cat(local_targets)
|
|
1111
|
+
|
|
1112
|
+
# Gather predictions and targets using Accelerate's CPU-efficient utility
|
|
1113
|
+
# gather_for_metrics handles:
|
|
1114
|
+
# - DDP padding removal (no need to trim manually)
|
|
1115
|
+
# - Efficient cross-rank gathering without GPU memory spike
|
|
1116
|
+
# - Returns concatenated tensors on CPU for metric computation
|
|
1117
|
+
if accelerator.num_processes > 1:
|
|
1118
|
+
# Move to GPU for gather (required by NCCL), then back to CPU
|
|
1119
|
+
# gather_for_metrics is more memory-efficient than manual gather
|
|
1120
|
+
# as it processes in chunks internally
|
|
1121
|
+
gathered_preds = accelerator.gather_for_metrics(
|
|
1122
|
+
local_preds_cat.to(accelerator.device)
|
|
1123
|
+
).cpu()
|
|
1124
|
+
gathered_targets = accelerator.gather_for_metrics(
|
|
1125
|
+
local_targets_cat.to(accelerator.device)
|
|
1126
|
+
).cpu()
|
|
1127
|
+
else:
|
|
1128
|
+
# Single-GPU mode: no gathering needed
|
|
1129
|
+
gathered_preds = local_preds_cat
|
|
1130
|
+
gathered_targets = local_targets_cat
|
|
1131
|
+
|
|
1132
|
+
# Synchronize validation metrics (scalars only - efficient)
|
|
1133
|
+
val_loss_scalar = val_loss_sum.item()
|
|
1134
|
+
val_metrics = torch.cat(
|
|
1135
|
+
[
|
|
1136
|
+
torch.tensor([val_loss_scalar], device=accelerator.device),
|
|
1137
|
+
val_mae_sum,
|
|
1138
|
+
]
|
|
1139
|
+
)
|
|
1140
|
+
val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
|
|
1141
|
+
|
|
1142
|
+
total_val_samples = accelerator.reduce(
|
|
1143
|
+
torch.tensor([val_samples], device=accelerator.device), reduction="sum"
|
|
1144
|
+
).item()
|
|
1145
|
+
|
|
1146
|
+
avg_val_loss = val_metrics_sync[0].item() / total_val_samples
|
|
1147
|
+
# Cast to float32 before numpy (bf16 tensors can't convert directly)
|
|
1148
|
+
avg_mae_per_param = (
|
|
1149
|
+
(val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
|
|
1150
|
+
)
|
|
1151
|
+
avg_mae = avg_mae_per_param.mean()

            # ==================== LOGGING & CHECKPOINTING ====================
            if accelerator.is_main_process:
                # Scientific metrics - cast to float32 before numpy
                # gather_for_metrics already handles DDP padding removal
                y_pred = gathered_preds.float().numpy()
                y_true = gathered_targets.float().numpy()

                # Guard against tiny validation sets (R² undefined for <2 samples)
                if len(y_true) >= 2:
                    r2 = r2_score(y_true, y_pred)
                else:
                    r2 = float("nan")
                pcc = calc_pearson(y_true, y_pred)
                current_lr = get_lr(optimizer)

                # Update history
                epoch_end_time = time.time()
                epoch_time = epoch_end_time - epoch_start_time
                total_training_time += epoch_time

                epoch_stats = {
                    "epoch": epoch + 1,
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss,
                    "val_r2": r2,
                    "val_pearson": pcc,
                    "val_mae_avg": avg_mae,
                    "grad_norm": grad_norm_tracker.avg,
                    "lr": current_lr,
                    "epoch_time": round(epoch_time, 2),
                    "total_time": round(total_training_time, 2),
                }
                for i, mae in enumerate(avg_mae_per_param):
                    epoch_stats[f"MAE_Phys_P{i}"] = mae

                history.append(epoch_stats)
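
                # Each epoch therefore appends one row whose keys become the columns of
                # training_history.csv: epoch, train_loss, val_loss, val_r2, val_pearson,
                # val_mae_avg, grad_norm, lr, epoch_time, total_time, MAE_Phys_P0, ...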

                # Console display
                base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
                param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
                logger.info(f"{base_str} | {param_str}")

                # WandB logging
                if args.wandb and WANDB_AVAILABLE:
                    log_dict = {
                        "main/train_loss": avg_train_loss,
                        "main/val_loss": avg_val_loss,
                        "metrics/r2_score": r2,
                        "metrics/pearson_corr": pcc,
                        "metrics/mae_avg": avg_mae,
                        "system/grad_norm": grad_norm_tracker.avg,
                        "hyper/lr": current_lr,
                    }
                    for i, mae in enumerate(avg_mae_per_param):
                        log_dict[f"mae_detailed/P{i}"] = mae

                    # Periodic scatter plots
                    if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
                        real_true = scaler.inverse_transform(y_true)
                        real_pred = scaler.inverse_transform(y_pred)
                        fig = plot_scientific_scatter(real_true, real_pred)
                        log_dict["plots/scatter_analysis"] = wandb.Image(fig)
                        plt.close(fig)

                    accelerator.log(log_dict)
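
                    # Keys are namespaced ("main/", "metrics/", "system/", "hyper/",
                    # "mae_detailed/", "plots/") so related charts are grouped together
                    # in the WandB dashboard; scatter plots are attached only every
                    # 5 epochs or when the validation loss improves.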

            # ==========================================================================
            # DDP-SAFE CHECKPOINT LOGIC
            # ==========================================================================
            # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
            is_best_epoch = False
            if accelerator.is_main_process:
                if avg_val_loss < best_val_loss:
                    is_best_epoch = True

            # Step 2: Broadcast decision to all ranks (required for save_state)
            is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
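            # save_state is a collective operation: either every rank calls it or none
            # may, otherwise the ranks that entered it would block waiting on the others.
            # Broadcasting rank 0's decision keeps all ranks on the same branch.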

            # Step 3: Save checkpoint with all ranks participating
            if is_best_epoch:
                ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
                with suppress_accelerate_logging():
                    accelerator.save_state(ckpt_dir, safe_serialization=False)

                # Step 4: Rank 0 handles metadata and updates tracking variables
                if accelerator.is_main_process:
                    best_val_loss = avg_val_loss  # Update AFTER checkpoint saved
                    patience_ctr = 0

                    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
                        pickle.dump(
                            {
                                "epoch": epoch + 1,
                                "best_val_loss": best_val_loss,
                                "patience_ctr": patience_ctr,
                                # Model info for auto-detection during inference
                                "model_name": args.model,
                                "in_shape": in_shape,
                                "out_dim": out_dim,
                            },
                            f,
                        )

                    # Unwrap model for saving (handle torch.compile compatibility)
                    try:
                        unwrapped = accelerator.unwrap_model(model)
                    except KeyError:
                        # torch.compile model may not have _orig_mod in expected location
                        # Fall back to getting the module directly
                        unwrapped = model.module if hasattr(model, "module") else model
                        # If still compiled, try to get the underlying model
                        if hasattr(unwrapped, "_orig_mod"):
                            unwrapped = unwrapped._orig_mod

                    torch.save(
                        unwrapped.state_dict(),
                        os.path.join(args.output_dir, "best_model_weights.pth"),
                    )
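
                    # Saving the unwrapped state_dict keeps plain parameter names: a
                    # torch.compile'd model is an OptimizedModule whose keys carry an
                    # "_orig_mod." prefix, and a DDP wrapper adds "module.", either of
                    # which would make the weights awkward to load for inference.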

                    # Copy scaler to checkpoint for portability
                    scaler_src = os.path.join(args.output_dir, "scaler.pkl")
                    scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
                    if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
                        shutil.copy2(scaler_src, scaler_dst)

                    logger.info(
                        f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
                    )

                    # Also save CSV on best model (ensures progress is saved)
                    pd.DataFrame(history).to_csv(
                        os.path.join(args.output_dir, "training_history.csv"),
                        index=False,
                    )
            else:
                if accelerator.is_main_process:
                    patience_ctr += 1

            # Periodic checkpoint (all ranks participate in save_state)
            periodic_checkpoint_needed = (
                args.save_every > 0 and (epoch + 1) % args.save_every == 0
            )
            if periodic_checkpoint_needed:
                ckpt_name = f"epoch_{epoch + 1}_checkpoint"
                ckpt_dir = os.path.join(args.output_dir, ckpt_name)
                with suppress_accelerate_logging():
                    accelerator.save_state(ckpt_dir, safe_serialization=False)

                if accelerator.is_main_process:
                    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
                        pickle.dump(
                            {
                                "epoch": epoch + 1,
                                "best_val_loss": best_val_loss,
                                "patience_ctr": patience_ctr,
                                # Model info for auto-detection during inference
                                "model_name": args.model,
                                "in_shape": in_shape,
                                "out_dim": out_dim,
                            },
                            f,
                        )
                    logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")

                    # Save CSV with each checkpoint (keeps logs in sync with model state)
                    pd.DataFrame(history).to_csv(
                        os.path.join(args.output_dir, "training_history.csv"),
                        index=False,
                    )
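
            # training_meta.pkl carries the bookkeeping needed to resume (epoch,
            # best_val_loss, patience_ctr) alongside model_name / in_shape / out_dim,
            # which is what the inference-time auto-detection mentioned above relies on.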

            # Learning rate scheduling (epoch-based schedulers only)
            # NOTE: For ReduceLROnPlateau with DDP, we must step only on the main process
            # to avoid the patience counter being incremented by all GPU processes.
            # Then we sync the new LR to all processes to keep them consistent.
            if not scheduler_step_per_batch:
                if args.scheduler == "plateau":
                    # Step only on main process to avoid multi-GPU patience bug
                    if accelerator.is_main_process:
                        scheduler.step(avg_val_loss)

                    # Sync LR across all processes after main process updates it
                    accelerator.wait_for_everyone()

                    # Broadcast new LR from rank 0 to all processes
                    if dist.is_initialized():
                        if accelerator.is_main_process:
                            new_lr = optimizer.param_groups[0]["lr"]
                        else:
                            new_lr = 0.0
                        new_lr_tensor = torch.tensor(
                            new_lr, device=accelerator.device, dtype=torch.float32
                        )
                        dist.broadcast(new_lr_tensor, src=0)
                        # Update LR on non-main processes
                        if not accelerator.is_main_process:
                            for param_group in optimizer.param_groups:
                                param_group["lr"] = new_lr_tensor.item()
                else:
                    scheduler.step()
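
            # dist.broadcast only transfers tensors, hence the float32 tensor wrapper;
            # after the broadcast every rank writes the same value into its optimizer's
            # param_groups, so all ranks start the next epoch with an identical LR.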

            # DDP-safe early stopping
            should_stop = (
                patience_ctr >= args.patience if accelerator.is_main_process else False
            )
            if broadcast_early_stop(should_stop, accelerator):
                if accelerator.is_main_process:
                    logger.info(
                        f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
                    )
                    # Create completion flag to prevent auto-resume
                    with open(
                        os.path.join(args.output_dir, "training_complete.flag"), "w"
                    ) as f:
                        f.write(
                            f"Training completed via early stopping at epoch {epoch + 1}\n"
                        )
                break
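            # The stop decision is made on rank 0 only (its patience_ctr is the source
            # of truth) and broadcast so every rank breaks out of the epoch loop in the
            # same iteration; a lone rank continuing would hang at the next collective.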

    except KeyboardInterrupt:
        logger.warning("Training interrupted. Saving emergency checkpoint...")
        with suppress_accelerate_logging():
            accelerator.save_state(
                os.path.join(args.output_dir, "interrupted_checkpoint"),
                safe_serialization=False,
            )

    except Exception as e:
        logger.error(f"Critical error: {e}", exc_info=True)
        raise

    else:
        # Training completed normally (reached max epochs without early stopping)
        # Create completion flag to prevent auto-resume on re-run
        if accelerator.is_main_process:
            if not os.path.exists(complete_flag_path):
                with open(complete_flag_path, "w") as f:
                    f.write(f"Training completed normally after {args.epochs} epochs\n")
                logger.info(f"✅ Training completed after {args.epochs} epochs")

    finally:
        # Final CSV write to capture all epochs (handles non-multiple-of-10 endings)
        if accelerator.is_main_process and len(history) > 0:
            pd.DataFrame(history).to_csv(
                os.path.join(args.output_dir, "training_history.csv"),
                index=False,
            )

        # Generate training curves plot (PNG + SVG)
        if accelerator.is_main_process and len(history) > 0:
            try:
                fig = create_training_curves(history, show_lr=True)
                for fmt in ["png", "svg"]:
                    fig.savefig(
                        os.path.join(args.output_dir, f"training_curves.{fmt}"),
                        dpi=FIGURE_DPI,
                        bbox_inches="tight",
                    )
                plt.close(fig)
                logger.info("✔ Saved: training_curves.png, training_curves.svg")
            except Exception as e:
                logger.warning(f"Could not generate training curves: {e}")

        if args.wandb and WANDB_AVAILABLE:
            accelerator.end_training()

        # Clean up distributed process group to prevent resource leak warning
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        logger.info("Training completed.")


if __name__ == "__main__":
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        pass
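        # set_start_method raises RuntimeError when the start method has already been
        # configured (e.g. by the launcher), so the failure is safely ignored here.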
    main()