ins_pricing-0.4.5-py3-none-any.whl → ins_pricing-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +52 -50
- ins_pricing/cli/BayesOpt_incremental.py +39 -105
- ins_pricing/cli/Explain_Run.py +31 -23
- ins_pricing/cli/Explain_entry.py +532 -579
- ins_pricing/cli/Pricing_Run.py +31 -23
- ins_pricing/cli/bayesopt_entry_runner.py +11 -9
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +375 -375
- ins_pricing/cli/utils/import_resolver.py +382 -365
- ins_pricing/cli/utils/notebook_utils.py +340 -340
- ins_pricing/cli/watchdog_run.py +209 -201
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +116 -83
- ins_pricing/utils/device.py +255 -255
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
- ins_pricing-0.5.0.dist-info/RECORD +131 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
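
The bulk of this release is a layout reshuffle: `modelling/core/bayesopt` is promoted to `modelling/bayesopt`, `production/predict.py` becomes `production/inference.py`, `modelling/core/evaluation.py` moves to `modelling/evaluation.py`, and the loss helpers move from `modelling/core/bayesopt/utils/losses.py` to `utils/losses.py`. For downstream code that imports the old paths, a shim along the following lines may ease the upgrade; the 0.5.0 import paths here are inferred from the rename list above and are an assumption, not documented API.

```python
# Hypothetical upgrade shim inferred from the file moves in this diff.
# Whether each module is importable at exactly these paths, and what it
# re-exports, is an assumption -- check the 0.5.0 RECORD before relying on it.
try:
    # 0.5.0 layout: bayesopt promoted out of modelling.core,
    # predict.py renamed to inference.py, loss helpers moved to utils.
    from ins_pricing.modelling import bayesopt
    from ins_pricing.production import inference
    from ins_pricing.utils import losses
except ImportError:
    # 0.4.5 layout
    from ins_pricing.modelling.core import bayesopt
    from ins_pricing.production import predict as inference
    from ins_pricing.modelling.core.bayesopt.utils import losses
```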
ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py

@@ -1,623 +1,623 @@

The removed (0.4.5) side of this hunk did not survive extraction: most deleted lines come through blank or truncated to fragments. Where legible, it matches the 0.5.0 file below line for line, except that the import block used relative intra-package paths, which 0.5.0 replaces with absolute ones:

```python
from .losses import (
    infer_loss_name_from_model_name,
    loss_requires_positive,
    normalize_loss_name,
    resolve_tweedie_power,
)
from .distributed_utils import DistributedUtils
```
The added (0.5.0) side of the hunk is intact. Full content of ins_pricing/modelling/bayesopt/utils/torch_trainer_mixin.py (623 lines):

```python
"""PyTorch training mixin with resource management and training loops.

This module provides the TorchTrainerMixin class which is used by
PyTorch-based trainers (ResNet, FT, GNN) for:
- Resource profiling and memory management
- Batch size computation and optimization
- DataLoader creation with DDP support
- Generic training and validation loops with AMP
- Early stopping and loss curve plotting
"""

from __future__ import annotations

import copy
import ctypes
import gc
import math
import os
import time
from contextlib import nullcontext
from typing import Dict, List, Optional

import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Try to import plotting functions
try:
    import matplotlib
    if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
        matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    _MPL_IMPORT_ERROR: Optional[BaseException] = None
except Exception as exc:
    matplotlib = None
    plt = None
    _MPL_IMPORT_ERROR = exc

try:
    from ins_pricing.modelling.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
except Exception:
    try:
        from ins_pricing.plotting.diagnostics import plot_loss_curve as plot_loss_curve_common
    except Exception:
        plot_loss_curve_common = None

# Import from other utils modules
from ins_pricing.utils import EPS, compute_batch_size, tweedie_loss, ensure_parent_dir
from ins_pricing.utils.losses import (
    infer_loss_name_from_model_name,
    loss_requires_positive,
    normalize_loss_name,
    resolve_tweedie_power,
)
from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils


def _plot_skip(label: str) -> None:
    """Print message when plot is skipped due to missing matplotlib."""
    if _MPL_IMPORT_ERROR is not None:
        print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
    else:
        print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)


class TorchTrainerMixin:
    """Shared helpers for PyTorch tabular trainers.

    Provides resource profiling, memory management, batch size optimization,
    and standardized training loops with mixed precision and DDP support.

    This mixin is used by ResNetTrainer, FTTrainer, and GNNTrainer.
    """

    def _resolve_device(self) -> torch.device:
        """Resolve device to a torch.device instance."""
        device = getattr(self, "device", None)
        if device is None:
            return torch.device("cpu")
        return device if isinstance(device, torch.device) else torch.device(device)

    def _device_type(self) -> str:
        """Get device type (cpu/cuda/mps)."""
        return self._resolve_device().type

    def _resolve_resource_profile(self) -> str:
        """Determine resource usage profile.

        Returns:
            One of: 'throughput', 'memory_saving', or 'auto'
        """
        profile = getattr(self, "resource_profile", None)
        if not profile:
            profile = os.environ.get("BAYESOPT_RESOURCE_PROFILE", "auto")
        profile = str(profile).strip().lower()
        if profile in {"cpu", "mps", "cuda"}:
            profile = "auto"
        if profile not in {"auto", "throughput", "memory_saving"}:
            profile = "auto"
        if profile == "auto" and self._device_type() == "cuda":
            profile = "throughput"
        return profile

    def _log_resource_summary_once(self, profile: str) -> None:
        """Log resource configuration summary once."""
        if getattr(self, "_resource_summary_logged", False):
            return
        if dist.is_initialized() and not DistributedUtils.is_main_process():
            return
        self._resource_summary_logged = True
        device = self._resolve_device()
        device_type = self._device_type()
        cpu_count = os.cpu_count() or 1
        cuda_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
        mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
        ddp_enabled = bool(getattr(self, "is_ddp_enabled", False))
        data_parallel = bool(getattr(self, "use_data_parallel", False))
        print(
            f">>> Resource summary: device={device}, device_type={device_type}, "
            f"cpu_count={cpu_count}, cuda_count={cuda_count}, mps={mps_available}, "
            f"ddp={ddp_enabled}, data_parallel={data_parallel}, profile={profile}"
        )

    def _available_system_memory(self) -> Optional[int]:
        """Get available system RAM in bytes."""
        if os.name == "nt":
            class _MemStatus(ctypes.Structure):
                _fields_ = [
                    ("dwLength", ctypes.c_ulong),
                    ("dwMemoryLoad", ctypes.c_ulong),
                    ("ullTotalPhys", ctypes.c_ulonglong),
                    ("ullAvailPhys", ctypes.c_ulonglong),
                    ("ullTotalPageFile", ctypes.c_ulonglong),
                    ("ullAvailPageFile", ctypes.c_ulonglong),
                    ("ullTotalVirtual", ctypes.c_ulonglong),
                    ("ullAvailVirtual", ctypes.c_ulonglong),
                    ("sullAvailExtendedVirtual", ctypes.c_ulonglong),
                ]
            status = _MemStatus()
            status.dwLength = ctypes.sizeof(_MemStatus)
            if ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status)):
                return int(status.ullAvailPhys)
            return None
        try:
            pages = os.sysconf("SC_AVPHYS_PAGES")
            page_size = os.sysconf("SC_PAGE_SIZE")
            return int(pages * page_size)
        except Exception:
            return None

    def _available_cuda_memory(self) -> Optional[int]:
        """Get available CUDA memory in bytes."""
        if not torch.cuda.is_available():
            return None
        try:
            free_mem, _total_mem = torch.cuda.mem_get_info()
        except Exception:
            return None
        return int(free_mem)

    def _estimate_sample_bytes(self, dataset) -> Optional[int]:
        """Estimate memory per sample in bytes."""
        try:
            if len(dataset) == 0:
                return None
            sample = dataset[0]
        except Exception:
            return None

        def _bytes(obj) -> int:
            if obj is None:
                return 0
            if torch.is_tensor(obj):
                return int(obj.element_size() * obj.nelement())
            if isinstance(obj, np.ndarray):
                return int(obj.nbytes)
            if isinstance(obj, (list, tuple)):
                return int(sum(_bytes(item) for item in obj))
            if isinstance(obj, dict):
                return int(sum(_bytes(item) for item in obj.values()))
            return 0

        sample_bytes = _bytes(sample)
        return int(sample_bytes) if sample_bytes > 0 else None

    def _cap_batch_size_by_memory(self, dataset, batch_size: int, profile: str) -> int:
        """Cap batch size based on available memory."""
        if batch_size <= 1:
            return batch_size
        sample_bytes = self._estimate_sample_bytes(dataset)
        if sample_bytes is None:
            return batch_size
        device_type = self._device_type()
        if device_type == "cuda":
            available = self._available_cuda_memory()
            if available is None:
                return batch_size
            if profile == "throughput":
                budget_ratio = 0.8
                overhead = 8.0
            elif profile == "memory_saving":
                budget_ratio = 0.5
                overhead = 14.0
            else:
                budget_ratio = 0.6
                overhead = 12.0
        else:
            available = self._available_system_memory()
            if available is None:
                return batch_size
            if profile == "throughput":
                budget_ratio = 0.4
                overhead = 1.8
            elif profile == "memory_saving":
                budget_ratio = 0.25
                overhead = 3.0
            else:
                budget_ratio = 0.3
                overhead = 2.6
        budget = int(available * budget_ratio)
        per_sample = int(sample_bytes * overhead)
        if per_sample <= 0:
            return batch_size
        max_batch = max(1, int(budget // per_sample))
        if max_batch < batch_size:
            print(
                f">>> Memory cap: batch_size {batch_size} -> {max_batch} "
                f"(per_sample~{sample_bytes}B, budget~{budget // (1024**2)}MB)"
            )
        return min(batch_size, max_batch)

    def _resolve_num_workers(self, max_workers: int, profile: Optional[str] = None) -> int:
        """Determine number of DataLoader workers."""
        if os.name == 'nt':
            return 0
        override = getattr(self, "dataloader_workers", None)
        if override is None:
            override = os.environ.get("BAYESOPT_DATALOADER_WORKERS")
        if override is not None:
            try:
                return max(0, int(override))
            except (TypeError, ValueError):
                pass
        if getattr(self, "is_ddp_enabled", False):
            return 0
        profile = profile or self._resolve_resource_profile()
        if profile == "memory_saving":
            return 0
        worker_cap = min(int(max_workers), os.cpu_count() or 1)
        if self._device_type() == "mps":
            worker_cap = min(worker_cap, 2)
        return worker_cap

    def _build_dataloader(self,
                          dataset,
                          N: int,
                          base_bs_gpu: tuple,
                          base_bs_cpu: tuple,
                          min_bs: int = 64,
                          target_effective_cuda: int = 1024,
                          target_effective_cpu: int = 512,
                          large_threshold: int = 200_000,
                          mid_threshold: int = 50_000):
        """Build DataLoader with adaptive batch size and worker configuration.

        Returns:
            Tuple of (dataloader, accum_steps)
        """
        profile = self._resolve_resource_profile()
        self._log_resource_summary_once(profile)
        data_size = int(N) if N is not None else len(dataset)
        gpu_large, gpu_mid, gpu_small = base_bs_gpu
        cpu_mid, cpu_small = base_bs_cpu

        device_type = self._device_type()
        is_ddp = bool(getattr(self, "is_ddp_enabled", False))
        if device_type == 'cuda':
            # Only scale batch size by GPU count when DDP is enabled.
            # In single-process (non-DDP) mode, large multi-GPU nodes can
            # still OOM on RAM/VRAM if we scale by device_count.
            device_count = 1
            if is_ddp:
                device_count = torch.cuda.device_count()
                if device_count > 1:
                    min_bs = min_bs * device_count
                    print(
                        f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")

            if data_size > large_threshold:
                base_bs = gpu_large * device_count
            elif data_size > mid_threshold:
                base_bs = gpu_mid * device_count
            else:
                base_bs = gpu_small * device_count
        else:
            base_bs = cpu_mid if data_size > mid_threshold else cpu_small

        batch_size = compute_batch_size(
            data_size=data_size,
            learning_rate=self.learning_rate,
            batch_num=self.batch_num,
            minimum=min_bs
        )
        batch_size = min(batch_size, base_bs, data_size)
        batch_size = self._cap_batch_size_by_memory(
            dataset, batch_size, profile)

        target_effective_bs = target_effective_cuda if device_type == 'cuda' else target_effective_cpu
        world_size = 1
        if is_ddp:
            world_size = getattr(self, "world_size", None)
            world_size = max(1, world_size or DistributedUtils.world_size())
            target_effective_bs = max(1, target_effective_bs // world_size)
        samples_per_rank = math.ceil(
            data_size / max(1, world_size)) if world_size > 1 else data_size
        steps_per_epoch = max(
            1, math.ceil(samples_per_rank / max(1, batch_size)))
        desired_accum = max(1, target_effective_bs // max(1, batch_size))
        accum_steps = max(1, min(desired_accum, steps_per_epoch))

        workers = self._resolve_num_workers(8, profile=profile)
        prefetch_factor = None
        if workers > 0:
            prefetch_factor = 4 if profile == "throughput" else 2
        persistent = workers > 0 and profile != "memory_saving"
        print(
            f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, "
            f"Workers={workers}, Prefetch={prefetch_factor or 'off'}, Profile={profile}")
        sampler = None
        use_distributed_sampler = bool(
            dist.is_initialized() and getattr(self, "is_ddp_enabled", False)
        )
        if use_distributed_sampler:
            sampler = DistributedSampler(dataset, shuffle=True)
            shuffle = False
        else:
            shuffle = True

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=workers,
            pin_memory=(device_type == 'cuda'),
            persistent_workers=persistent,
            **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
        )
        self.dataloader_sampler = sampler
        return dataloader, accum_steps

    def _build_val_dataloader(self, dataset, train_dataloader, accum_steps):
        """Build validation DataLoader."""
        profile = self._resolve_resource_profile()
        val_bs = accum_steps * train_dataloader.batch_size
        val_workers = self._resolve_num_workers(4, profile=profile)
        prefetch_factor = None
        if val_workers > 0:
            prefetch_factor = 2
        return DataLoader(
            dataset,
            batch_size=val_bs,
            shuffle=False,
            num_workers=val_workers,
            pin_memory=(self._device_type() == 'cuda'),
            persistent_workers=(val_workers > 0 and profile != "memory_saving"),
            **({"prefetch_factor": prefetch_factor} if prefetch_factor is not None else {}),
        )

    def _compute_losses(self, y_pred, y_true, apply_softplus: bool = False):
        """Compute per-sample losses based on task type."""
        task = getattr(self, "task_type", "regression")
        if task == 'classification':
            loss_fn = nn.BCEWithLogitsLoss(reduction='none')
            return loss_fn(y_pred, y_true).view(-1)
        loss_name = normalize_loss_name(
            getattr(self, "loss_name", None), task_type="regression"
        )
        if loss_name == "auto":
            model_name = getattr(self, "model_name", None) or getattr(self, "model_nme", "")
            loss_name = infer_loss_name_from_model_name(model_name)
        if apply_softplus:
            y_pred = F.softplus(y_pred)
        if loss_requires_positive(loss_name):
            y_pred = torch.clamp(y_pred, min=1e-6)
            power = resolve_tweedie_power(
                loss_name, default=float(getattr(self, "tw_power", 1.5) or 1.5)
            )
            if power is None:
                power = float(getattr(self, "tw_power", 1.5) or 1.5)
            return tweedie_loss(y_pred, y_true, p=power).view(-1)
        if loss_name == "mse":
            return (y_pred - y_true).pow(2).view(-1)
        if loss_name == "mae":
            return (y_pred - y_true).abs().view(-1)
        raise ValueError(f"Unsupported loss_name '{loss_name}' for regression.")

    def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
        """Compute weighted loss."""
        losses = self._compute_losses(
            y_pred, y_true, apply_softplus=apply_softplus)
        weighted_loss = (losses * weights.view(-1)).sum() / \
            torch.clamp(weights.sum(), min=EPS)
        return weighted_loss

    def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
                           ignore_keys: Optional[List[str]] = None):
        """Update early stopping state."""
        if val_loss < best_loss:
            ignore_keys = ignore_keys or []
            base_module = model.module if hasattr(model, "module") else model
            state_dict = {
                k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
                for k, v in base_module.state_dict().items()
                if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
            }
            return val_loss, state_dict, 0, False
        patience_counter += 1
        should_stop = best_state is not None and patience_counter >= getattr(
            self, "patience", 0)
        return best_loss, best_state, patience_counter, should_stop

    def _train_model(self,
                     model,
                     dataloader,
                     accum_steps,
                     optimizer,
                     scaler,
                     forward_fn,
                     val_forward_fn=None,
                     apply_softplus: bool = False,
                     clip_fn=None,
                     trial: Optional[optuna.trial.Trial] = None,
                     loss_curve_path: Optional[str] = None):
        """Generic training loop with AMP, DDP, and early stopping support.

        Returns:
            Tuple of (best_state_dict, history)
        """
        device_type = self._device_type()
        best_loss = float('inf')
        best_state = None
        patience_counter = 0
        stop_training = False
        train_history: List[float] = []
        val_history: List[float] = []

        is_ddp_model = isinstance(model, DDP)
        use_collectives = dist.is_initialized() and is_ddp_model

        for epoch in range(1, getattr(self, "epochs", 1) + 1):
            epoch_start_ts = time.time()
            val_weighted_loss = None
            if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
                self.dataloader_sampler.set_epoch(epoch)

            model.train()
            optimizer.zero_grad()

            epoch_loss_sum = None
            epoch_weight_sum = None
            for step, batch in enumerate(dataloader):
                is_update_step = ((step + 1) % accum_steps == 0) or \
                    ((step + 1) == len(dataloader))
                sync_cm = model.no_sync if (
                    is_ddp_model and not is_update_step) else nullcontext

                with sync_cm():
                    with autocast(enabled=(device_type == 'cuda')):
                        y_pred, y_true, w = forward_fn(batch)
                        weighted_loss = self._compute_weighted_loss(
                            y_pred, y_true, w, apply_softplus=apply_softplus)
                        loss_for_backward = weighted_loss / accum_steps

                    batch_weight = torch.clamp(
                        w.detach().sum(), min=EPS).to(dtype=torch.float32)
                    loss_val = weighted_loss.detach().to(dtype=torch.float32)
                    if epoch_loss_sum is None:
                        epoch_loss_sum = torch.zeros(
                            (), device=batch_weight.device, dtype=torch.float32)
                        epoch_weight_sum = torch.zeros(
                            (), device=batch_weight.device, dtype=torch.float32)
                    epoch_loss_sum = epoch_loss_sum + loss_val * batch_weight
                    epoch_weight_sum = epoch_weight_sum + batch_weight
                    scaler.scale(loss_for_backward).backward()

                if is_update_step:
                    if clip_fn is not None:
                        clip_fn()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

            if epoch_loss_sum is None or epoch_weight_sum is None:
                train_epoch_loss = 0.0
            else:
                train_epoch_loss = (
                    epoch_loss_sum / torch.clamp(epoch_weight_sum, min=EPS)
                ).item()
            train_history.append(float(train_epoch_loss))

            if val_forward_fn is not None:
                should_compute_val = (not dist.is_initialized()
                                      or DistributedUtils.is_main_process())
                val_device = self._resolve_device()
                loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
                    "cpu")
                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)

                if should_compute_val:
                    model.eval()
                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
                        val_result = val_forward_fn()
                        if isinstance(val_result, tuple) and len(val_result) == 3:
                            y_val_pred, y_val_true, w_val = val_result
                            val_weighted_loss = self._compute_weighted_loss(
                                y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
                        else:
                            val_weighted_loss = val_result
                    val_loss_tensor[0] = float(val_weighted_loss)

                if use_collectives:
                    dist.broadcast(val_loss_tensor, src=0)
                    val_weighted_loss = float(val_loss_tensor.item())

                val_history.append(val_weighted_loss)

                best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
                    val_weighted_loss, best_loss, best_state, patience_counter, model)

                prune_flag = False
                is_main_rank = DistributedUtils.is_main_process()
                if trial is not None and is_main_rank:
                    trial.report(val_weighted_loss, epoch)
                    prune_flag = trial.should_prune()

                if use_collectives:
                    prune_device = self._resolve_device()
                    prune_tensor = torch.zeros(1, device=prune_device)
                    if is_main_rank:
                        prune_tensor.fill_(1 if prune_flag else 0)
                    dist.broadcast(prune_tensor, src=0)
                    prune_flag = bool(prune_tensor.item())

                if prune_flag:
                    raise optuna.TrialPruned()

            if stop_training:
                break

            should_log_epoch = (not dist.is_initialized()
                                or DistributedUtils.is_main_process())
            if should_log_epoch:
                elapsed = int(time.time() - epoch_start_ts)
                if val_weighted_loss is None:
                    print(
                        f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
                        f"train_loss={float(train_epoch_loss):.6f} elapsed={elapsed}s",
                        flush=True,
                    )
                else:
                    print(
                        f"[Training] Epoch {epoch}/{getattr(self, 'epochs', 1)} "
                        f"train_loss={float(train_epoch_loss):.6f} "
                        f"val_loss={float(val_weighted_loss):.6f} elapsed={elapsed}s",
                        flush=True,
                    )

            if epoch % 10 == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

        history = {"train": train_history, "val": val_history}
        self._plot_loss_curve(history, loss_curve_path)
        return best_state, history

    def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
        """Plot training and validation loss curves."""
        if not save_path:
            return
        if dist.is_initialized() and not DistributedUtils.is_main_process():
            return
        train_hist = history.get("train", []) if history else []
        val_hist = history.get("val", []) if history else []
        if not train_hist and not val_hist:
            return
        if plot_loss_curve_common is not None:
            plot_loss_curve_common(
                history=history,
                title="Loss vs. Epoch",
                save_path=save_path,
                show=False,
            )
        else:
            if plt is None:
                _plot_skip("loss curve")
                return
            ensure_parent_dir(save_path)
            epochs = range(1, max(len(train_hist), len(val_hist)) + 1)
            fig = plt.figure(figsize=(8, 4))
            ax = fig.add_subplot(111)
            if train_hist:
                ax.plot(range(1, len(train_hist) + 1), train_hist,
                        label='Train Loss', color='tab:blue')
            if val_hist:
                ax.plot(range(1, len(val_hist) + 1), val_hist,
                        label='Validation Loss', color='tab:orange')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Weighted Loss')
            ax.set_title('Loss vs. Epoch')
            ax.grid(True, linestyle='--', alpha=0.3)
            ax.legend()
            plt.tight_layout()
            plt.savefig(save_path, dpi=300)
            plt.close(fig)
            print(f"[Training] Loss curve saved to {save_path}")
```