wavedl 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +28 -26
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +0 -1
- wavedl/models/base.py +0 -1
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1113 -1117
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/METADATA +111 -93
- wavedl-1.4.0.dist-info/RECORD +37 -0
- wavedl-1.3.0.dist-info/RECORD +0 -31
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/LICENSE +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/WHEEL +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.0.dist-info → wavedl-1.4.0.dist-info}/top_level.txt +0 -0
wavedl/train.py CHANGED
@@ -1,1117 +1,1113 @@
"""
WaveDL - Deep Learning for Wave-based Inverse Problems
=======================================================
Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+

A modular training framework for wave-based inverse problems and regression:
1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
2. Dynamic Model Selection: Use --model flag to select any registered architecture
3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
5. Robust Checkpointing: Resume training, periodic saves, and training curves
6. Deep Observability: WandB integration with scatter analysis

Usage:
    # Recommended: Using the HPC launcher
    wavedl-hpc --model cnn --batch_size 128 --wandb

    # Or with direct accelerate launch
    accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb

    # Multi-GPU with explicit config
    wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb

    # Resume from checkpoint
    accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb

    # List available models
    wavedl-train --list_models

Note:
    For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
    environment configuration automatically. Mixed precision is controlled via
    --mixed_precision flag (default: bf16).

Author: Ductho Le (ductho.le@outlook.com)
"""

from __future__ import annotations

import argparse
import logging
import os
import pickle
import shutil
import sys
import time
import warnings
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from sklearn.metrics import r2_score
from tqdm.auto import tqdm

from wavedl.models import build_model, get_model, list_models
from wavedl.utils import (
    FIGURE_DPI,
    MetricTracker,
    broadcast_early_stop,
    calc_pearson,
    create_training_curves,
    get_loss,
    get_lr,
    get_optimizer,
    get_scheduler,
    is_epoch_based,
    list_losses,
    list_optimizers,
    list_schedulers,
    plot_scientific_scatter,
    prepare_data,
)


try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# ==============================================================================
# RUNTIME CONFIGURATION (post-import)
# ==============================================================================
# Configure matplotlib paths for HPC systems without writable home directories
os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")

# Suppress non-critical warnings for cleaner training logs
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", module="pydantic")
warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")


# ==============================================================================
# ARGUMENT PARSING
# ==============================================================================
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments with comprehensive options."""
    parser = argparse.ArgumentParser(
        description="Universal DDP Training Pipeline",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Selection
    parser.add_argument(
        "--model",
        type=str,
        default="cnn",
        help=f"Model architecture to train. Available: {list_models()}",
    )
    parser.add_argument(
        "--list_models", action="store_true", help="List all available models and exit"
    )
    parser.add_argument(
        "--import",
        dest="import_modules",
        type=str,
        nargs="+",
        default=[],
        help="Python modules to import before training (for custom models)",
    )

    # Configuration File
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to YAML config file. CLI args override config values.",
    )

    # Hyperparameters
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size per GPU"
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument(
        "--epochs", type=int, default=1000, help="Maximum training epochs"
    )
    parser.add_argument(
        "--patience", type=int, default=20, help="Early stopping patience"
    )
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
    )

    # Loss Function
    parser.add_argument(
        "--loss",
        type=str,
        default="mse",
        choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
        help=f"Loss function for training. Available: {list_losses()}",
    )
    parser.add_argument(
        "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
    )
    parser.add_argument(
        "--loss_weights",
        type=str,
        default=None,
        help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
    )

    # Optimizer
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
        choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
        help=f"Optimizer for training. Available: {list_optimizers()}",
    )
    parser.add_argument(
        "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
    )
    parser.add_argument(
        "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
    )
    parser.add_argument(
        "--betas",
        type=str,
        default="0.9,0.999",
        help="Betas for Adam variants (comma-separated)",
    )

    # Learning Rate Scheduler
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=[
            "plateau",
            "cosine",
            "cosine_restarts",
            "onecycle",
            "step",
            "multistep",
            "exponential",
            "linear_warmup",
        ],
        help=f"LR scheduler. Available: {list_schedulers()}",
    )
    parser.add_argument(
        "--scheduler_patience",
        type=int,
        default=10,
        help="Patience for ReduceLROnPlateau",
    )
    parser.add_argument(
        "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
    )
    parser.add_argument(
        "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
    )
    parser.add_argument(
        "--step_size", type=int, default=30, help="Step size for StepLR"
    )
    parser.add_argument(
        "--milestones",
        type=str,
        default=None,
        help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
    )

    # Data
    parser.add_argument(
        "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=-1,
        help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
    )
    parser.add_argument("--seed", type=int, default=2025, help="Random seed")
    parser.add_argument(
        "--single_channel",
        action="store_true",
        help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
    )

    # Cross-Validation
    parser.add_argument(
        "--cv",
        type=int,
        default=0,
        help="Enable K-fold cross-validation with K folds (0=disabled)",
    )
    parser.add_argument(
        "--cv_stratify",
        action="store_true",
        help="Use stratified splitting for cross-validation",
    )
    parser.add_argument(
        "--cv_bins",
        type=int,
        default=10,
        help="Number of bins for stratified CV (only with --cv_stratify)",
    )

    # Checkpointing & Resume
    parser.add_argument(
        "--resume", type=str, default=None, help="Checkpoint directory to resume from"
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=50,
        help="Save checkpoint every N epochs (0=disable)",
    )
    parser.add_argument(
        "--output_dir", type=str, default=".", help="Output directory for checkpoints"
    )
    parser.add_argument(
        "--fresh",
        action="store_true",
        help="Force fresh training, ignore existing checkpoints",
    )

    # Performance
    parser.add_argument(
        "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["bf16", "fp16", "no"],
        help="Mixed precision mode",
    )

    # Logging
    parser.add_argument(
        "--wandb", action="store_true", help="Enable Weights & Biases logging"
    )
    parser.add_argument(
        "--project_name", type=str, default="DL-Training", help="WandB project name"
    )
    parser.add_argument("--run_name", type=str, default=None, help="WandB run name")

    args = parser.parse_args()
    return args, parser  # Returns (Namespace, ArgumentParser)


# ==============================================================================
# MAIN TRAINING FUNCTION
# ==============================================================================
def main():
    args, parser = parse_args()

    # Import custom model modules if specified
    if args.import_modules:
        import importlib

        for module_name in args.import_modules:
            try:
                # Handle both module names (my_model) and file paths (./my_model.py)
                if module_name.endswith(".py"):
                    # Import from file path
                    import importlib.util

                    spec = importlib.util.spec_from_file_location(
                        "custom_module", module_name
                    )
                    if spec and spec.loader:
                        module = importlib.util.module_from_spec(spec)
                        sys.modules["custom_module"] = module
                        spec.loader.exec_module(module)
                        print(f"✓ Imported custom module from: {module_name}")
                else:
                    # Import as regular module
                    importlib.import_module(module_name)
                    print(f"✓ Imported module: {module_name}")
            except ImportError as e:
                print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
                print(
                    " Make sure the module is in your Python path or current directory."
                )
                sys.exit(1)

    # Handle --list_models flag
    if args.list_models:
        print("Available models:")
        for name in list_models():
            ModelClass = get_model(name)
            # Get first non-empty docstring line
            if ModelClass.__doc__:
                lines = [
                    l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
                ]
                doc_first_line = lines[0] if lines else "No description"
            else:
                doc_first_line = "No description"
            print(f" - {name}: {doc_first_line}")
        sys.exit(0)

    # Load and merge config file if provided
    if args.config:
        from wavedl.utils.config import (
            load_config,
            merge_config_with_args,
            validate_config,
        )

        print(f"📄 Loading config from: {args.config}")
        config = load_config(args.config)

        # Validate config values
        warnings_list = validate_config(config)
        for w in warnings_list:
            print(f" ⚠ {w}")

        # Merge config with CLI args (CLI takes precedence via parser defaults detection)
        args = merge_config_with_args(config, args, parser=parser)

    # Handle --cv flag (cross-validation mode)
    if args.cv > 0:
        print(f"🔄 Cross-Validation Mode: {args.cv} folds")
        from wavedl.utils.cross_validation import run_cross_validation

        # Load data for CV using memory-efficient loader
        from wavedl.utils.data import DataSource, get_data_source

        data_format = DataSource.detect_format(args.data_path)
        source = get_data_source(data_format)

        # Use memory-mapped loading when available
        _cv_handle = None
        if hasattr(source, "load_mmap"):
            result = source.load_mmap(args.data_path)
            if hasattr(result, "inputs"):
                _cv_handle = result
                X, y = result.inputs, result.outputs
            else:
                X, y = result  # NPZ returns tuple directly
        else:
            X, y = source.load(args.data_path)

        # Handle sparse matrices (must materialize for CV shuffling)
        if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
            X = np.stack([x.toarray() for x in X])

        # Normalize target shape: (N,) -> (N, 1) for consistency
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        # Validate and determine input shape (consistent with prepare_data)
        # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
        sample_shape = X.shape[1:]  # Per-sample shape

        # Same heuristic as prepare_data: detect ambiguous 3D shapes
        is_ambiguous_shape = (
            len(sample_shape) == 3  # Exactly 3D: could be (C, H, W) or (D, H, W)
            and sample_shape[0] <= 16  # First dim looks like channels
            and sample_shape[1] > 16
            and sample_shape[2] > 16  # Both spatial dims are large
        )

        if is_ambiguous_shape and not args.single_channel:
            raise ValueError(
                f"Ambiguous input shape detected: sample shape {sample_shape}. "
                f"This could be either:\n"
                f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
                f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
                f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
                f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
            )

        # in_shape = spatial dimensions for model registry (channel added during training)
        in_shape = sample_shape

        # Run cross-validation
        try:
            run_cross_validation(
                X=X,
                y=y,
                model_name=args.model,
                in_shape=in_shape,
                out_size=y.shape[1],
                folds=args.cv,
                stratify=args.cv_stratify,
                stratify_bins=args.cv_bins,
                batch_size=args.batch_size,
                lr=args.lr,
                epochs=args.epochs,
                patience=args.patience,
                weight_decay=args.weight_decay,
                loss_name=args.loss,
                optimizer_name=args.optimizer,
                scheduler_name=args.scheduler,
                output_dir=args.output_dir,
                workers=args.workers,
                seed=args.seed,
            )
        finally:
            # Clean up file handle if HDF5/MAT
            if _cv_handle is not None and hasattr(_cv_handle, "close"):
                try:
                    _cv_handle.close()
                except Exception:
                    pass
        return

    # ==========================================================================
    # 1. SYSTEM INITIALIZATION
    # ==========================================================================
    # Initialize Accelerator for DDP and mixed precision
    accelerator = Accelerator(
        mixed_precision=args.precision,
        log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
    )
    set_seed(args.seed)

    # Configure logging (rank 0 only prints to console)
    logging.basicConfig(
        level=logging.INFO if accelerator.is_main_process else logging.ERROR,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger("Trainer")

    # Ensure output directory exists (critical for cache files, checkpoints, etc.)
    os.makedirs(args.output_dir, exist_ok=True)

    # Auto-detect optimal DataLoader workers if not specified
    if args.workers < 0:
        cpu_count = os.cpu_count() or 4
        num_gpus = accelerator.num_processes
        # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
        # Leave some cores for main process and system overhead
        args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
        if accelerator.is_main_process:
            logger.info(
                f"⚙️ Auto-detected workers: {args.workers} per GPU "
                f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
            )

    if accelerator.is_main_process:
        logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
        logger.info(
            f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
        )
        logger.info(
            f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
        )
        logger.info(f" Early Stopping Patience: {args.patience} epochs")
        if args.save_every > 0:
            logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
        if args.resume:
            logger.info(f" 📂 Resuming from: {args.resume}")

    # Initialize WandB
    if args.wandb and WANDB_AVAILABLE:
        accelerator.init_trackers(
            project_name=args.project_name,
            config=vars(args),
            init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
        )

    # ==========================================================================
    # 2. DATA & MODEL LOADING
    # ==========================================================================
    train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
        args, logger, accelerator, cache_dir=args.output_dir
    )

    # Build model using registry
    model = build_model(args.model, in_shape=in_shape, out_size=out_dim)

    if accelerator.is_main_process:
        param_info = model.parameter_summary()
        logger.info(
            f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
        )
        logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")

    # Optional WandB model watching
    if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
        wandb.watch(model, log="gradients", log_freq=100)

    # Torch 2.0 compilation (requires compatible Triton on GPU)
    if args.compile:
        try:
            # Test if Triton is available AND compatible with this PyTorch version
            # PyTorch needs triton_key from triton.compiler.compiler
            from triton.compiler.compiler import triton_key

            model = torch.compile(model)
            if accelerator.is_main_process:
                logger.info(" ✔ torch.compile enabled (Triton backend)")
        except ImportError as e:
            if accelerator.is_main_process:
                if "triton" in str(e).lower():
                    logger.warning(
                        " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
                        "Training will proceed without compilation."
                    )
                else:
                    logger.warning(
                        f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
                    )
        except Exception as e:
            if accelerator.is_main_process:
                logger.warning(
                    f" ⚠ torch.compile failed: {e}. Continuing without compilation."
                )

    # ==========================================================================
    # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
    # ==========================================================================
    # Parse comma-separated arguments with validation
    try:
        betas_list = [float(x.strip()) for x in args.betas.split(",")]
        if len(betas_list) != 2:
            raise ValueError(
                f"--betas must have exactly 2 values, got {len(betas_list)}"
            )
        if not all(0.0 <= b < 1.0 for b in betas_list):
            raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
        betas = tuple(betas_list)
    except ValueError as e:
        raise ValueError(
            f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
        )

    loss_weights = None
    if args.loss_weights:
        loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
    milestones = None
    if args.milestones:
        milestones = [int(x.strip()) for x in args.milestones.split(",")]

    # Create optimizer using factory
    optimizer = get_optimizer(
        name=args.optimizer,
        params=model.get_optimizer_groups(args.lr, args.weight_decay),
        lr=args.lr,
        weight_decay=args.weight_decay,
        momentum=args.momentum,
        nesterov=args.nesterov,
        betas=betas,
    )

    # Create loss function using factory
    criterion = get_loss(
        name=args.loss,
        weights=loss_weights,
        delta=args.huber_delta,
    )
    # Move criterion to device (important for WeightedMSELoss buffer)
    criterion = criterion.to(accelerator.device)

    # Track if scheduler should step per batch (OneCycleLR) or per epoch
    scheduler_step_per_batch = not is_epoch_based(args.scheduler)

    # ==========================================================================
    # DDP Preparation Strategy:
    # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
    #   the correct sharded batch count, then create scheduler
    # - For epoch-based schedulers: create scheduler before prepare (no issue)
    # ==========================================================================
    if scheduler_step_per_batch:
        # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
        # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
        model, optimizer, train_dl, val_dl = accelerator.prepare(
            model, optimizer, train_dl, val_dl
        )

        # Now create scheduler with the CORRECT sharded steps_per_epoch
        steps_per_epoch = len(train_dl)  # Post-DDP sharded length
        scheduler = get_scheduler(
            name=args.scheduler,
            optimizer=optimizer,
            epochs=args.epochs,
            steps_per_epoch=steps_per_epoch,
            min_lr=args.min_lr,
            patience=args.scheduler_patience,
            factor=args.scheduler_factor,
            gamma=args.scheduler_factor,  # For Step/MultiStep/Exponential schedulers
            step_size=args.step_size,
            milestones=milestones,
            warmup_epochs=args.warmup_epochs,
        )
        # Prepare scheduler separately (Accelerator wraps it for state saving)
        scheduler = accelerator.prepare(scheduler)
    else:
        # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
        # No batch count dependency - create scheduler before prepare
        scheduler = get_scheduler(
            name=args.scheduler,
            optimizer=optimizer,
            epochs=args.epochs,
            steps_per_epoch=None,
            min_lr=args.min_lr,
            patience=args.scheduler_patience,
            factor=args.scheduler_factor,
            gamma=args.scheduler_factor,  # For Step/MultiStep/Exponential schedulers
            step_size=args.step_size,
            milestones=milestones,
            warmup_epochs=args.warmup_epochs,
        )
        # Prepare everything together
        model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
            model, optimizer, train_dl, val_dl, scheduler
        )

    # ==========================================================================
    # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
    # ==========================================================================
    start_epoch = 0
    best_val_loss = float("inf")
    patience_ctr = 0
    history: list[dict[str, Any]] = []

    # Define checkpoint paths
    best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
    complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")

    # Auto-resume logic (if not --fresh and no explicit --resume)
    if not args.fresh and args.resume is None:
        if os.path.exists(complete_flag_path):
            # Training already completed
            if accelerator.is_main_process:
                logger.info(
                    "✅ Training already completed (early stopping). Use --fresh to retrain."
                )
            return  # Exit gracefully
        elif os.path.exists(best_ckpt_path):
            # Incomplete training found - auto-resume
            args.resume = best_ckpt_path
            if accelerator.is_main_process:
                logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")

    if args.resume:
        if os.path.exists(args.resume):
            logger.info(f"🔄 Loading checkpoint from: {args.resume}")
            accelerator.load_state(args.resume)

            # Restore training metadata
            meta_path = os.path.join(args.resume, "training_meta.pkl")
            if os.path.exists(meta_path):
                with open(meta_path, "rb") as f:
                    meta = pickle.load(f)
                start_epoch = meta.get("epoch", 0)
                best_val_loss = meta.get("best_val_loss", float("inf"))
                patience_ctr = meta.get("patience_ctr", 0)
                logger.info(
                    f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
                )
            else:
                logger.warning(
                    " ⚠️ training_meta.pkl not found, starting from epoch 0"
                )

            # Restore history
            history_path = os.path.join(args.output_dir, "training_history.csv")
            if os.path.exists(history_path):
                history = pd.read_csv(history_path).to_dict("records")
                logger.info(f" ✅ Loaded {len(history)} epochs from history")
        else:
            raise FileNotFoundError(f"Checkpoint not found: {args.resume}")

    # ==========================================================================
    # 4. PHYSICAL METRIC SETUP
    # ==========================================================================
    # Physical MAE = normalized MAE * scaler.scale_
    phys_scale = torch.tensor(
        scaler.scale_, device=accelerator.device, dtype=torch.float32
    )

    # ==========================================================================
    # 5. TRAINING LOOP
    # ==========================================================================
    # Dynamic console header
    if accelerator.is_main_process:
        base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
        param_cols = [f"MAE_P{i}" for i in range(out_dim)]
        header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
            *base_cols
        )
        header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
        logger.info("=" * len(header))
        logger.info(header)
        logger.info("=" * len(header))

    try:
        time.time()
        total_training_time = 0.0

        for epoch in range(start_epoch, args.epochs):
            epoch_start_time = time.time()

            # ==================== TRAINING PHASE ====================
            model.train()
            # Use GPU tensor for loss accumulation to avoid .item() sync per batch
            train_loss_sum = torch.tensor(0.0, device=accelerator.device)
            train_samples = 0
            grad_norm_tracker = MetricTracker()

            pbar = tqdm(
                train_dl,
                disable=not accelerator.is_main_process,
                leave=False,
                desc=f"Epoch {epoch + 1}",
            )

            for x, y in pbar:
                with accelerator.accumulate(model):
                    pred = model(x)
                    loss = criterion(pred, y)

                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        grad_norm = accelerator.clip_grad_norm_(
                            model.parameters(), args.grad_clip
                        )
                        if grad_norm is not None:
                            grad_norm_tracker.update(grad_norm.item())

                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)  # Faster than zero_grad()

                    # Per-batch LR scheduling (e.g., OneCycleLR)
                    if scheduler_step_per_batch:
                        scheduler.step()

                # Accumulate as tensors to avoid .item() sync per batch
                train_loss_sum += loss.detach() * x.size(0)
                train_samples += x.size(0)

            # Single .item() call at end of epoch (reduces GPU sync overhead)
            train_loss_scalar = train_loss_sum.item()

            # Synchronize training metrics across GPUs
            global_loss = accelerator.reduce(
                torch.tensor([train_loss_scalar], device=accelerator.device),
                reduction="sum",
            ).item()
            global_samples = accelerator.reduce(
                torch.tensor([train_samples], device=accelerator.device),
                reduction="sum",
            ).item()
            avg_train_loss = global_loss / global_samples

            # ==================== VALIDATION PHASE ====================
            model.eval()
            # Use GPU tensor for loss accumulation (consistent with training phase)
            val_loss_sum = torch.tensor(0.0, device=accelerator.device)
            val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
            val_samples = 0

            # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
            local_preds = []
            local_targets = []

            with torch.inference_mode():
                for x, y in val_dl:
                    pred = model(x)
                    loss = criterion(pred, y)

                    val_loss_sum += loss.detach() * x.size(0)
                    val_samples += x.size(0)

                    # Physical MAE
                    mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
                    val_mae_sum += mae_batch

                    # Store locally (no GPU sync per batch)
                    local_preds.append(pred)
                    local_targets.append(y)

            # Single gather at end of validation (2 syncs instead of 2×num_batches)
            all_local_preds = torch.cat(local_preds)
            all_local_targets = torch.cat(local_targets)
            all_preds = accelerator.gather_for_metrics(all_local_preds)
            all_targets = accelerator.gather_for_metrics(all_local_targets)

            # Synchronize validation metrics
            val_loss_scalar = val_loss_sum.item()
            val_metrics = torch.cat(
                [
                    torch.tensor([val_loss_scalar], device=accelerator.device),
                    val_mae_sum,
                ]
            )
            val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")

            total_val_samples = accelerator.reduce(
                torch.tensor([val_samples], device=accelerator.device), reduction="sum"
            ).item()

            avg_val_loss = val_metrics_sync[0].item() / total_val_samples
            # Cast to float32 before numpy (bf16 tensors can't convert directly)
            avg_mae_per_param = (
                (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
            )
            avg_mae = avg_mae_per_param.mean()

            # ==================== LOGGING & CHECKPOINTING ====================
            if accelerator.is_main_process:
                # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
                y_pred = all_preds.float().cpu().numpy()
                y_true = all_targets.float().cpu().numpy()

                # Trim DDP padding
                real_len = len(val_dl.dataset)
                if len(y_pred) > real_len:
                    y_pred = y_pred[:real_len]
                    y_true = y_true[:real_len]

                # Guard against tiny validation sets (R² undefined for <2 samples)
                if len(y_true) >= 2:
                    r2 = r2_score(y_true, y_pred)
                else:
                    r2 = float("nan")
                pcc = calc_pearson(y_true, y_pred)
                current_lr = get_lr(optimizer)

                # Update history
                epoch_end_time = time.time()
                epoch_time = epoch_end_time - epoch_start_time
                total_training_time += epoch_time

                epoch_stats = {
                    "epoch": epoch + 1,
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss,
                    "val_r2": r2,
                    "val_pearson": pcc,
                    "val_mae_avg": avg_mae,
                    "grad_norm": grad_norm_tracker.avg,
                    "lr": current_lr,
                    "epoch_time": round(epoch_time, 2),
                    "total_time": round(total_training_time, 2),
                }
                for i, mae in enumerate(avg_mae_per_param):
                    epoch_stats[f"MAE_Phys_P{i}"] = mae

                history.append(epoch_stats)

                # Console display
                base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
                param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
                logger.info(f"{base_str} | {param_str}")

                # WandB logging
                if args.wandb and WANDB_AVAILABLE:
                    log_dict = {
                        "main/train_loss": avg_train_loss,
                        "main/val_loss": avg_val_loss,
                        "metrics/r2_score": r2,
                        "metrics/pearson_corr": pcc,
                        "metrics/mae_avg": avg_mae,
                        "system/grad_norm": grad_norm_tracker.avg,
                        "hyper/lr": current_lr,
                    }
                    for i, mae in enumerate(avg_mae_per_param):
                        log_dict[f"mae_detailed/P{i}"] = mae

                    # Periodic scatter plots
                    if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
                        real_true = scaler.inverse_transform(y_true)
                        real_pred = scaler.inverse_transform(y_pred)
                        fig = plot_scientific_scatter(real_true, real_pred)
                        log_dict["plots/scatter_analysis"] = wandb.Image(fig)
                        plt.close(fig)

                    accelerator.log(log_dict)

            # ==========================================================================
            # DDP-SAFE CHECKPOINT LOGIC
            # ==========================================================================
            # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
            is_best_epoch = False
            if accelerator.is_main_process:
                if avg_val_loss < best_val_loss:
                    is_best_epoch = True

            # Step 2: Broadcast decision to all ranks (required for save_state)
            is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)

            # Step 3: Save checkpoint with all ranks participating
            if is_best_epoch:
                ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
                accelerator.save_state(ckpt_dir)  # All ranks must call this

                # Step 4: Rank 0 handles metadata and updates tracking variables
                if accelerator.is_main_process:
                    best_val_loss = avg_val_loss  # Update AFTER checkpoint saved
                    patience_ctr = 0

                    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
                        pickle.dump(
                            {
                                "epoch": epoch + 1,
                                "best_val_loss": best_val_loss,
                                "patience_ctr": patience_ctr,
                                # Model info for auto-detection during inference
                                "model_name": args.model,
                                "in_shape": in_shape,
                                "out_dim": out_dim,
                            },
                            f,
                        )

                    unwrapped = accelerator.unwrap_model(model)
                    torch.save(
                        unwrapped.state_dict(),
                        os.path.join(args.output_dir, "best_model_weights.pth"),
                    )

                    # Copy scaler to checkpoint for portability
                    scaler_src = os.path.join(args.output_dir, "scaler.pkl")
                    scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
                    if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
                        shutil.copy2(scaler_src, scaler_dst)

                    logger.info(
                        f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
                    )

                    # Also save CSV on best model (ensures progress is saved)
                    pd.DataFrame(history).to_csv(
                        os.path.join(args.output_dir, "training_history.csv"),
                        index=False,
                    )
            else:
                if accelerator.is_main_process:
                    patience_ctr += 1

            # Periodic checkpoint (all ranks participate in save_state)
            periodic_checkpoint_needed = (
                args.save_every > 0 and (epoch + 1) % args.save_every == 0
            )
            if periodic_checkpoint_needed:
                ckpt_name = f"epoch_{epoch + 1}_checkpoint"
                ckpt_dir = os.path.join(args.output_dir, ckpt_name)
                accelerator.save_state(ckpt_dir)  # All ranks participate

                if accelerator.is_main_process:
                    with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
                        pickle.dump(
                            {
                                "epoch": epoch + 1,
                                "best_val_loss": best_val_loss,
                                "patience_ctr": patience_ctr,
                                # Model info for auto-detection during inference
                                "model_name": args.model,
                                "in_shape": in_shape,
                                "out_dim": out_dim,
                            },
                            f,
                        )
                    logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")

                    # Save CSV with each checkpoint (keeps logs in sync with model state)
                    pd.DataFrame(history).to_csv(
                        os.path.join(args.output_dir, "training_history.csv"),
                        index=False,
                    )

            # Learning rate scheduling (epoch-based schedulers only)
            if not scheduler_step_per_batch:
                if args.scheduler == "plateau":
                    scheduler.step(avg_val_loss)
                else:
                    scheduler.step()

            # DDP-safe early stopping
            should_stop = (
                patience_ctr >= args.patience if accelerator.is_main_process else False
            )
            if broadcast_early_stop(should_stop, accelerator):
                if accelerator.is_main_process:
                    logger.info(
                        f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
                    )
                    # Create completion flag to prevent auto-resume
                    with open(
                        os.path.join(args.output_dir, "training_complete.flag"), "w"
                    ) as f:
                        f.write(
                            f"Training completed via early stopping at epoch {epoch + 1}\n"
                        )
                break

    except KeyboardInterrupt:
        logger.warning("Training interrupted. Saving emergency checkpoint...")
        accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))

    except Exception as e:
        logger.error(f"Critical error: {e}", exc_info=True)
        raise

    else:
        # Training completed normally (reached max epochs without early stopping)
        # Create completion flag to prevent auto-resume on re-run
        if accelerator.is_main_process:
            if not os.path.exists(complete_flag_path):
                with open(complete_flag_path, "w") as f:
                    f.write(f"Training completed normally after {args.epochs} epochs\n")
                logger.info(f"✅ Training completed after {args.epochs} epochs")

    finally:
        # Final CSV write to capture all epochs (handles non-multiple-of-10 endings)
        if accelerator.is_main_process and len(history) > 0:
            pd.DataFrame(history).to_csv(
                os.path.join(args.output_dir, "training_history.csv"),
                index=False,
            )

        # Generate training curves plot (PNG + SVG)
        if accelerator.is_main_process and len(history) > 0:
            try:
                fig = create_training_curves(history, show_lr=True)
                for fmt in ["png", "svg"]:
                    fig.savefig(
                        os.path.join(args.output_dir, f"training_curves.{fmt}"),
                        dpi=FIGURE_DPI,
                        bbox_inches="tight",
                    )
                plt.close(fig)
                logger.info("✔ Saved: training_curves.png, training_curves.svg")
            except Exception as e:
                logger.warning(f"Could not generate training curves: {e}")

        if args.wandb and WANDB_AVAILABLE:
            accelerator.end_training()

        # Clean up distributed process group to prevent resource leak warning
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        logger.info("Training completed.")


if __name__ == "__main__":
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        pass
    main()
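
The new --import flag loads user modules before the model registry is consulted (handling both module names and .py file paths), so a custom architecture only needs to register itself with the wavedl model registry at import time to become selectable via --model. A minimal sketch of such a module, assuming a register_model decorator is exported by wavedl.models; the actual registration API lives in wavedl/models/registry.py and may be named differently:

import torch.nn as nn

from wavedl.models import register_model  # assumed name for the registry hook


@register_model("tiny_mlp")  # hypothetical model name
class TinyMLP(nn.Module):
    """Toy MLP regressor, used only to illustrate the --import workflow."""

    def __init__(self, in_shape, out_size):
        super().__init__()
        in_features = 1
        for d in in_shape:  # flatten spatial dims into one feature vector
            in_features *= d
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Linear(256, out_size),
        )

    def forward(self, x):
        return self.net(x)

It would then be launched with, e.g., accelerate launch -m wavedl.train --import ./my_model.py --model tiny_mlp.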
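
The scheduler-ordering comment in section 2.5 matters because OneCycleLR freezes its step budget at construction: torch.optim.lr_scheduler.OneCycleLR derives total_steps = epochs * steps_per_epoch, and under DDP each rank only steps once per sharded batch. Building the scheduler from the unsharded loader length would make the one-cycle ramp finish roughly world_size times too late. A small sketch of the arithmetic, with made-up numbers:

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# hypothetical setup: 10_000 samples, batch 128 per GPU, 4 GPUs
sharded_steps = 10_000 // (128 * 4)  # ~19 optimizer steps per rank per epoch
sched = OneCycleLR(opt, max_lr=1e-3, epochs=100, steps_per_epoch=sharded_steps)
# budget is now fixed: exactly 100 * sharded_steps scheduler.step() calls span the cycle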
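
The "physical MAE" in section 4 relies on StandardScaler being affine: when targets are normalized as z = (y - mean_) / scale_, the quantity |z_hat - z| * scale_ is exactly the absolute error in original units, so the per-parameter MAE can be unscaled without inverse-transforming every batch. A quick NumPy check of that identity:

import numpy as np
from sklearn.preprocessing import StandardScaler

y = np.array([[1.0], [3.0], [5.0]])  # targets in physical units
scaler = StandardScaler().fit(y)
z = scaler.transform(y)              # normalized targets
z_hat = z + 0.1                      # fake predictions, uniformly off by 0.1

mae_norm = np.abs(z_hat - z).mean(axis=0)
mae_phys = np.abs(scaler.inverse_transform(z_hat) - y).mean(axis=0)
assert np.allclose(mae_norm * scaler.scale_, mae_phys)  # identity holds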
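
Both the best-checkpoint decision and early stopping funnel rank-0 booleans through broadcast_early_stop before any collective call, because accelerator.save_state must be entered by every rank; a rank that saved (or broke out of the loop) alone would deadlock the process group. A minimal sketch of what such a helper can look like, assuming a plain torch.distributed broadcast; the real wavedl.utils implementation may differ:

import torch
import torch.distributed as dist


def broadcast_early_stop(flag: bool, accelerator) -> bool:
    """Propagate a rank-0 boolean decision to every rank (sketch)."""
    if accelerator.num_processes == 1:
        return flag  # single-process runs need no synchronization
    t = torch.tensor([1 if flag else 0], device=accelerator.device)
    dist.broadcast(t, src=0)  # all ranks adopt rank 0's value
    return bool(t.item())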