ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -1,913 +1,921 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import copy
|
|
4
|
-
from contextlib import nullcontext
|
|
5
|
-
from typing import Any, Dict, List, Optional
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import optuna
|
|
9
|
-
import pandas as pd
|
|
10
|
-
import torch
|
|
11
|
-
import torch.distributed as dist
|
|
12
|
-
import torch.nn as nn
|
|
13
|
-
import torch.nn.functional as F
|
|
14
|
-
from torch.cuda.amp import autocast, GradScaler
|
|
15
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
16
|
-
from torch.nn.utils import clip_grad_norm_
|
|
17
|
-
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
self.
|
|
181
|
-
self.
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
self.
|
|
187
|
-
self.
|
|
188
|
-
self.
|
|
189
|
-
self.
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
self.
|
|
195
|
-
self.
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
self.
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
self.
|
|
226
|
-
self.
|
|
227
|
-
self.
|
|
228
|
-
self.
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
self.
|
|
254
|
-
self.
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
self.
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
core =
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
self.use_data_parallel =
|
|
288
|
-
|
|
289
|
-
self.
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
if unmapped.any():
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
raise ValueError(
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
if
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
self.
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
if
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
if
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
geo_np
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
else
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
scaler,
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
)
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
self.device, non_blocking=True)
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
self.device, non_blocking=True)
|
|
768
|
-
|
|
769
|
-
self.device, non_blocking=True)
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
device=
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
if val_loss_value
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
)
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from contextlib import nullcontext
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import optuna
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
import torch.distributed as dist
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
import torch.nn.functional as F
|
|
14
|
+
from torch.cuda.amp import autocast, GradScaler
|
|
15
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
16
|
+
from torch.nn.utils import clip_grad_norm_
|
|
17
|
+
|
|
18
|
+
from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
|
|
19
|
+
from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
|
|
20
|
+
from ins_pricing.utils import EPS, get_logger, log_print
|
|
21
|
+
from ins_pricing.utils.losses import (
|
|
22
|
+
infer_loss_name_from_model_name,
|
|
23
|
+
normalize_loss_name,
|
|
24
|
+
resolve_tweedie_power,
|
|
25
|
+
)
|
|
26
|
+
from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
|
|
27
|
+
|
|
28
|
+
_logger = get_logger("ins_pricing.modelling.bayesopt.models.model_ft_trainer")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _log(*args, **kwargs) -> None:
|
|
32
|
+
log_print(_logger, *args, **kwargs)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# --- Helper functions for reconstruction loss computation ---
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _compute_numeric_reconstruction_loss(
|
|
39
|
+
num_pred: Optional[torch.Tensor],
|
|
40
|
+
num_true: Optional[torch.Tensor],
|
|
41
|
+
num_mask: Optional[torch.Tensor],
|
|
42
|
+
loss_weight: float,
|
|
43
|
+
device: torch.device,
|
|
44
|
+
) -> torch.Tensor:
|
|
45
|
+
"""Compute MSE loss for numeric feature reconstruction.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
num_pred: Predicted numeric values (N, num_features)
|
|
49
|
+
num_true: Ground truth numeric values (N, num_features)
|
|
50
|
+
num_mask: Boolean mask indicating which values were masked (N, num_features)
|
|
51
|
+
loss_weight: Weight to apply to the loss
|
|
52
|
+
device: Target device for computation
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Weighted MSE loss for masked numeric features
|
|
56
|
+
"""
|
|
57
|
+
if num_pred is None or num_true is None or num_mask is None:
|
|
58
|
+
return torch.zeros((), device=device, dtype=torch.float32)
|
|
59
|
+
|
|
60
|
+
num_mask = num_mask.to(dtype=torch.bool)
|
|
61
|
+
if not num_mask.any():
|
|
62
|
+
return torch.zeros((), device=device, dtype=torch.float32)
|
|
63
|
+
|
|
64
|
+
diff = num_pred - num_true
|
|
65
|
+
mse = diff * diff
|
|
66
|
+
return float(loss_weight) * mse[num_mask].mean()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _compute_categorical_reconstruction_loss(
|
|
70
|
+
cat_logits: Optional[List[torch.Tensor]],
|
|
71
|
+
cat_true: Optional[torch.Tensor],
|
|
72
|
+
cat_mask: Optional[torch.Tensor],
|
|
73
|
+
loss_weight: float,
|
|
74
|
+
device: torch.device,
|
|
75
|
+
) -> torch.Tensor:
|
|
76
|
+
"""Compute cross-entropy loss for categorical feature reconstruction.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
cat_logits: List of logits for each categorical feature
|
|
80
|
+
cat_true: Ground truth categorical indices (N, num_cat_features)
|
|
81
|
+
cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
|
|
82
|
+
loss_weight: Weight to apply to the loss
|
|
83
|
+
device: Target device for computation
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
Weighted cross-entropy loss for masked categorical features
|
|
87
|
+
"""
|
|
88
|
+
if not cat_logits or cat_true is None or cat_mask is None:
|
|
89
|
+
return torch.zeros((), device=device, dtype=torch.float32)
|
|
90
|
+
|
|
91
|
+
cat_mask = cat_mask.to(dtype=torch.bool)
|
|
92
|
+
cat_losses: List[torch.Tensor] = []
|
|
93
|
+
|
|
94
|
+
for j, logits in enumerate(cat_logits):
|
|
95
|
+
mask_j = cat_mask[:, j]
|
|
96
|
+
if not mask_j.any():
|
|
97
|
+
continue
|
|
98
|
+
targets = cat_true[:, j]
|
|
99
|
+
cat_losses.append(
|
|
100
|
+
F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
if not cat_losses:
|
|
104
|
+
return torch.zeros((), device=device, dtype=torch.float32)
|
|
105
|
+
|
|
106
|
+
return float(loss_weight) * torch.stack(cat_losses).mean()
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _compute_reconstruction_loss(
|
|
110
|
+
num_pred: Optional[torch.Tensor],
|
|
111
|
+
cat_logits: Optional[List[torch.Tensor]],
|
|
112
|
+
num_true: Optional[torch.Tensor],
|
|
113
|
+
num_mask: Optional[torch.Tensor],
|
|
114
|
+
cat_true: Optional[torch.Tensor],
|
|
115
|
+
cat_mask: Optional[torch.Tensor],
|
|
116
|
+
num_loss_weight: float,
|
|
117
|
+
cat_loss_weight: float,
|
|
118
|
+
device: torch.device,
|
|
119
|
+
) -> torch.Tensor:
|
|
120
|
+
"""Compute combined reconstruction loss for masked tabular data.
|
|
121
|
+
|
|
122
|
+
This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
num_pred: Predicted numeric values
|
|
126
|
+
cat_logits: List of logits for categorical features
|
|
127
|
+
num_true: Ground truth numeric values
|
|
128
|
+
num_mask: Mask for numeric features
|
|
129
|
+
cat_true: Ground truth categorical indices
|
|
130
|
+
cat_mask: Mask for categorical features
|
|
131
|
+
num_loss_weight: Weight for numeric loss
|
|
132
|
+
cat_loss_weight: Weight for categorical loss
|
|
133
|
+
device: Target device for computation
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
Combined weighted reconstruction loss
|
|
137
|
+
"""
|
|
138
|
+
num_loss = _compute_numeric_reconstruction_loss(
|
|
139
|
+
num_pred, num_true, num_mask, num_loss_weight, device
|
|
140
|
+
)
|
|
141
|
+
cat_loss = _compute_categorical_reconstruction_loss(
|
|
142
|
+
cat_logits, cat_true, cat_mask, cat_loss_weight, device
|
|
143
|
+
)
|
|
144
|
+
return num_loss + cat_loss
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Scikit-Learn style wrapper for FTTransformer.
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
|
|
151
|
+
|
|
152
|
+
# sklearn-style wrapper:
|
|
153
|
+
# - num_cols: numeric feature column names
|
|
154
|
+
# - cat_cols: categorical feature column names (label-encoded to [0, n_classes-1])
|
|
155
|
+
|
|
156
|
+
@staticmethod
|
|
157
|
+
def resolve_numeric_token_count(num_cols, cat_cols, requested: Optional[int]) -> int:
|
|
158
|
+
num_cols_count = len(num_cols or [])
|
|
159
|
+
if num_cols_count == 0:
|
|
160
|
+
return 0
|
|
161
|
+
if requested is not None:
|
|
162
|
+
count = int(requested)
|
|
163
|
+
if count <= 0:
|
|
164
|
+
raise ValueError("num_numeric_tokens must be >= 1 when numeric features exist.")
|
|
165
|
+
return count
|
|
166
|
+
return max(1, num_cols_count)
|
|
167
|
+
|
|
168
|
+
def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
|
|
169
|
+
n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
|
|
170
|
+
task_type: str = 'regression',
|
|
171
|
+
tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
|
|
172
|
+
weight_decay: float = 0.0,
|
|
173
|
+
use_data_parallel: bool = True,
|
|
174
|
+
use_ddp: bool = False,
|
|
175
|
+
num_numeric_tokens: Optional[int] = None,
|
|
176
|
+
loss_name: Optional[str] = None
|
|
177
|
+
):
|
|
178
|
+
super().__init__()
|
|
179
|
+
|
|
180
|
+
self.use_ddp = use_ddp
|
|
181
|
+
self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
|
|
182
|
+
False, 0, 0, 1)
|
|
183
|
+
if self.use_ddp:
|
|
184
|
+
self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()
|
|
185
|
+
|
|
186
|
+
self.model_nme = model_nme
|
|
187
|
+
self.num_cols = list(num_cols)
|
|
188
|
+
self.cat_cols = list(cat_cols)
|
|
189
|
+
self.num_numeric_tokens = self.resolve_numeric_token_count(
|
|
190
|
+
self.num_cols,
|
|
191
|
+
self.cat_cols,
|
|
192
|
+
num_numeric_tokens,
|
|
193
|
+
)
|
|
194
|
+
self.d_model = d_model
|
|
195
|
+
self.n_heads = n_heads
|
|
196
|
+
self.n_layers = n_layers
|
|
197
|
+
self.dropout = dropout
|
|
198
|
+
self.batch_num = batch_num
|
|
199
|
+
self.epochs = epochs
|
|
200
|
+
self.learning_rate = learning_rate
|
|
201
|
+
self.weight_decay = weight_decay
|
|
202
|
+
self.task_type = task_type
|
|
203
|
+
self.patience = patience
|
|
204
|
+
resolved_loss = normalize_loss_name(loss_name, self.task_type)
|
|
205
|
+
if self.task_type == 'classification':
|
|
206
|
+
self.loss_name = "logloss"
|
|
207
|
+
self.tw_power = None # No Tweedie power for classification.
|
|
208
|
+
else:
|
|
209
|
+
if resolved_loss == "auto":
|
|
210
|
+
resolved_loss = infer_loss_name_from_model_name(self.model_nme)
|
|
211
|
+
self.loss_name = resolved_loss
|
|
212
|
+
if self.loss_name == "tweedie":
|
|
213
|
+
self.tw_power = float(tweedie_power) if tweedie_power is not None else 1.5
|
|
214
|
+
else:
|
|
215
|
+
self.tw_power = resolve_tweedie_power(self.loss_name, default=1.5)
|
|
216
|
+
|
|
217
|
+
if self.is_ddp_enabled:
|
|
218
|
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
|
219
|
+
elif torch.cuda.is_available():
|
|
220
|
+
self.device = torch.device("cuda")
|
|
221
|
+
elif torch.backends.mps.is_available():
|
|
222
|
+
self.device = torch.device("mps")
|
|
223
|
+
else:
|
|
224
|
+
self.device = torch.device("cpu")
|
|
225
|
+
self.cat_cardinalities = None
|
|
226
|
+
self.cat_categories = {}
|
|
227
|
+
self.cat_maps: Dict[str, Dict[Any, int]] = {}
|
|
228
|
+
self.cat_str_maps: Dict[str, Dict[str, int]] = {}
|
|
229
|
+
self._num_mean = None
|
|
230
|
+
self._num_std = None
|
|
231
|
+
self.ft = None
|
|
232
|
+
self.use_data_parallel = bool(use_data_parallel)
|
|
233
|
+
self.num_geo = 0
|
|
234
|
+
self._geo_params: Dict[str, Any] = {}
|
|
235
|
+
self.loss_curve_path: Optional[str] = None
|
|
236
|
+
self.training_history: Dict[str, List[float]] = {
|
|
237
|
+
"train": [], "val": []}
|
|
238
|
+
|
|
239
|
+
def _build_model(self, X_train):
|
|
240
|
+
num_numeric = len(self.num_cols)
|
|
241
|
+
cat_cardinalities = []
|
|
242
|
+
|
|
243
|
+
if num_numeric > 0:
|
|
244
|
+
num_arr = X_train[self.num_cols].to_numpy(
|
|
245
|
+
dtype=np.float32, copy=False)
|
|
246
|
+
num_arr = np.nan_to_num(num_arr, nan=0.0, posinf=0.0, neginf=0.0)
|
|
247
|
+
mean = num_arr.mean(axis=0).astype(np.float32, copy=False)
|
|
248
|
+
std = num_arr.std(axis=0).astype(np.float32, copy=False)
|
|
249
|
+
std = np.where(std < 1e-6, 1.0, std).astype(np.float32, copy=False)
|
|
250
|
+
self._num_mean = mean
|
|
251
|
+
self._num_std = std
|
|
252
|
+
else:
|
|
253
|
+
self._num_mean = None
|
|
254
|
+
self._num_std = None
|
|
255
|
+
|
|
256
|
+
self.cat_maps = {}
|
|
257
|
+
self.cat_str_maps = {}
|
|
258
|
+
for col in self.cat_cols:
|
|
259
|
+
cats = X_train[col].astype('category')
|
|
260
|
+
categories = cats.cat.categories
|
|
261
|
+
self.cat_categories[col] = categories # Store full category list from training.
|
|
262
|
+
self.cat_maps[col] = {cat: i for i, cat in enumerate(categories)}
|
|
263
|
+
if categories.dtype == object or pd.api.types.is_string_dtype(categories.dtype):
|
|
264
|
+
self.cat_str_maps[col] = {str(cat): i for i, cat in enumerate(categories)}
|
|
265
|
+
|
|
266
|
+
card = len(categories) + 1 # Reserve one extra class for unknown/missing.
|
|
267
|
+
cat_cardinalities.append(card)
|
|
268
|
+
|
|
269
|
+
self.cat_cardinalities = cat_cardinalities
|
|
270
|
+
|
|
271
|
+
core = FTTransformerCore(
|
|
272
|
+
num_numeric=num_numeric,
|
|
273
|
+
cat_cardinalities=cat_cardinalities,
|
|
274
|
+
d_model=self.d_model,
|
|
275
|
+
n_heads=self.n_heads,
|
|
276
|
+
n_layers=self.n_layers,
|
|
277
|
+
dropout=self.dropout,
|
|
278
|
+
task_type=self.task_type,
|
|
279
|
+
num_geo=self.num_geo,
|
|
280
|
+
num_numeric_tokens=self.num_numeric_tokens
|
|
281
|
+
)
|
|
282
|
+
use_dp = self.use_data_parallel and (self.device.type == "cuda") and (torch.cuda.device_count() > 1)
|
|
283
|
+
if self.is_ddp_enabled:
|
|
284
|
+
core = core.to(self.device)
|
|
285
|
+
core = DDP(core, device_ids=[
|
|
286
|
+
self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
|
|
287
|
+
self.use_data_parallel = False
|
|
288
|
+
elif use_dp:
|
|
289
|
+
if self.use_ddp and not self.is_ddp_enabled:
|
|
290
|
+
_log(
|
|
291
|
+
">>> DDP requested but not initialized; falling back to DataParallel.")
|
|
292
|
+
core = nn.DataParallel(core, device_ids=list(
|
|
293
|
+
range(torch.cuda.device_count())))
|
|
294
|
+
self.device = torch.device("cuda")
|
|
295
|
+
self.use_data_parallel = True
|
|
296
|
+
else:
|
|
297
|
+
self.use_data_parallel = False
|
|
298
|
+
self.ft = core.to(self.device)
|
|
299
|
+
|
|
300
|
+
def _encode_cats(self, X):
|
|
301
|
+
# Input DataFrame must include all categorical feature columns.
|
|
302
|
+
# Return int64 array with shape (N, num_categorical_features).
|
|
303
|
+
|
|
304
|
+
if not self.cat_cols:
|
|
305
|
+
return np.zeros((len(X), 0), dtype='int64')
|
|
306
|
+
|
|
307
|
+
n_rows = len(X)
|
|
308
|
+
n_cols = len(self.cat_cols)
|
|
309
|
+
X_cat_np = np.empty((n_rows, n_cols), dtype='int64')
|
|
310
|
+
for idx, col in enumerate(self.cat_cols):
|
|
311
|
+
categories = self.cat_categories[col]
|
|
312
|
+
mapping = self.cat_maps.get(col)
|
|
313
|
+
if mapping is None:
|
|
314
|
+
mapping = {cat: i for i, cat in enumerate(categories)}
|
|
315
|
+
self.cat_maps[col] = mapping
|
|
316
|
+
unknown_idx = len(categories)
|
|
317
|
+
series = X[col]
|
|
318
|
+
codes = series.map(mapping)
|
|
319
|
+
unmapped = series.notna() & codes.isna()
|
|
320
|
+
if unmapped.any():
|
|
321
|
+
try:
|
|
322
|
+
series_cast = series.astype(categories.dtype)
|
|
323
|
+
except Exception:
|
|
324
|
+
series_cast = None
|
|
325
|
+
if series_cast is not None:
|
|
326
|
+
codes = series_cast.map(mapping)
|
|
327
|
+
unmapped = series_cast.notna() & codes.isna()
|
|
328
|
+
if unmapped.any():
|
|
329
|
+
str_map = self.cat_str_maps.get(col)
|
|
330
|
+
if str_map is None:
|
|
331
|
+
str_map = {str(cat): i for i, cat in enumerate(categories)}
|
|
332
|
+
self.cat_str_maps[col] = str_map
|
|
333
|
+
codes = series.astype(str).map(str_map)
|
|
334
|
+
if pd.api.types.is_categorical_dtype(codes):
|
|
335
|
+
codes = codes.astype("float")
|
|
336
|
+
codes = codes.fillna(unknown_idx).astype(
|
|
337
|
+
"int64", copy=False).to_numpy()
|
|
338
|
+
X_cat_np[:, idx] = codes
|
|
339
|
+
return X_cat_np
|
|
340
|
+
|
|
341
|
+
def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
|
|
342
|
+
return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)
|
|
343
|
+
|
|
344
|
+
def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
|
|
345
|
+
return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)
|
|
346
|
+
|
|
347
|
+
@staticmethod
|
|
348
|
+
def _validate_vector(arr, name: str, n_rows: int) -> None:
|
|
349
|
+
if arr is None:
|
|
350
|
+
return
|
|
351
|
+
if isinstance(arr, pd.DataFrame):
|
|
352
|
+
if arr.shape[1] != 1:
|
|
353
|
+
raise ValueError(f"{name} must be 1d (single column).")
|
|
354
|
+
length = len(arr)
|
|
355
|
+
else:
|
|
356
|
+
arr_np = np.asarray(arr)
|
|
357
|
+
if arr_np.ndim == 0:
|
|
358
|
+
raise ValueError(f"{name} must be 1d.")
|
|
359
|
+
if arr_np.ndim > 2 or (arr_np.ndim == 2 and arr_np.shape[1] != 1):
|
|
360
|
+
raise ValueError(f"{name} must be 1d or Nx1.")
|
|
361
|
+
length = arr_np.shape[0]
|
|
362
|
+
if length != n_rows:
|
|
363
|
+
raise ValueError(
|
|
364
|
+
f"{name} length {length} does not match X length {n_rows}."
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
    """Convert one data split into model-ready tensors.

    Returns a 6-tuple ``(X_num, X_cat, X_geo, y_tensor, w_tensor, has_y)``:
    float32 numeric features, int64 categorical codes, float32 geo tokens,
    optional Nx1 target / weight tensors, and whether a target was supplied.
    When ``X`` is None and ``allow_none`` is set, every slot is None
    (used for an absent validation split); otherwise a missing ``X`` raises.
    """
    if X is None:
        if allow_none:
            return None, None, None, None, None, False
        raise ValueError("Input features X must not be None.")
    if not isinstance(X, pd.DataFrame):
        raise ValueError("X must be a pandas DataFrame.")
    # All configured numeric and categorical columns must be present.
    missing_cols = [
        col for col in (self.num_cols + self.cat_cols) if col not in X.columns
    ]
    if missing_cols:
        raise ValueError(f"X is missing required columns: {missing_cols}")
    n_rows = len(X)
    if y is not None:
        self._validate_vector(y, "y", n_rows)
    if w is not None:
        self._validate_vector(w, "w", n_rows)

    num_np = X[self.num_cols].to_numpy(dtype=np.float32, copy=False)
    # to_numpy(copy=False) may return a view into the DataFrame; copy before
    # the in-place nan_to_num below so the caller's data is never mutated.
    if not num_np.flags["OWNDATA"]:
        num_np = num_np.copy()
    num_np = np.nan_to_num(num_np, nan=0.0,
                           posinf=0.0, neginf=0.0, copy=False)
    # Standardize only when fitted stats exist (set elsewhere on the object).
    if self._num_mean is not None and self._num_std is not None and num_np.size:
        num_np = (num_np - self._num_mean) / self._num_std
    X_num = torch.as_tensor(num_np)
    if self.cat_cols:
        X_cat = torch.as_tensor(self._encode_cats(X), dtype=torch.long)
    else:
        # Keep a consistent (N, 0) shape so downstream code never branches.
        X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)

    if geo_tokens is not None:
        geo_np = np.asarray(geo_tokens, dtype=np.float32)
        if geo_np.shape[0] != n_rows:
            raise ValueError(
                "geo_tokens length does not match X rows.")
        if geo_np.ndim == 1:
            # Promote a flat token vector to a single-column matrix.
            geo_np = geo_np.reshape(-1, 1)
    elif self.num_geo > 0:
        # The model was configured with geo inputs, so they are mandatory.
        raise RuntimeError("geo_tokens must not be empty; prepare geo tokens first.")
    else:
        geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
    X_geo = torch.as_tensor(geo_np)

    # Target/weight are reshaped to Nx1; pandas objects go through to_numpy,
    # anything else through np.asarray.
    y_tensor = torch.as_tensor(
        y.to_numpy(dtype=np.float32, copy=False) if hasattr(
            y, "to_numpy") else np.asarray(y, dtype=np.float32)
    ).view(-1, 1) if y is not None else None
    if y_tensor is None:
        # Weights are meaningless without a target.
        w_tensor = None
    elif w is not None:
        w_tensor = torch.as_tensor(
            w.to_numpy(dtype=np.float32, copy=False) if hasattr(
                w, "to_numpy") else np.asarray(w, dtype=np.float32)
        ).view(-1, 1)
    else:
        # Default to uniform unit weights.
        w_tensor = torch.ones_like(y_tensor)
    return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
|
|
425
|
+
|
|
426
|
+
def fit(self, X_train, y_train, w_train=None,
        X_val=None, y_val=None, w_val=None, trial=None,
        geo_train=None, geo_val=None):
    """Train the network on a (optionally weighted) supervised split.

    Args follow the sklearn-style convention: training features/target/
    weights, an optional validation split used for early stopping, an
    optional Optuna ``trial`` forwarded to ``_train_model`` for pruning,
    and optional pre-computed geo-token matrices for both splits.

    Side effects: lazily builds ``self.ft`` on first call, sets
    ``self.dataloader_sampler`` and ``self.training_history``, and restores
    the best validation weights into the model when available.
    """
    # Geo token width is (re)derived from the supplied training tokens;
    # the model is built lazily so column metadata comes from the data.
    self.num_geo = geo_train.shape[1] if geo_train is not None else 0
    if self.ft is None:
        self._build_model(X_train)

    X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
        X_train, y_train, w_train, geo_train=geo_train)
    X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
        X_val, y_val, w_val, geo_val=geo_val)

    # --- Build DataLoader ---
    dataset = TabularDataset(
        X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
    )

    dataloader, accum_steps = self._build_dataloader(
        dataset,
        N=X_num_train.shape[0],
        base_bs_gpu=(2048, 1024, 512),
        base_bs_cpu=(256, 128),
        min_bs=64,
        target_effective_cuda=2048,
        target_effective_cpu=1024
    )

    # Keep a handle on the distributed sampler (if any) so epochs can be
    # reshuffled via sampler.set_epoch by the training loop.
    if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
        self.dataloader_sampler = dataloader.sampler
    else:
        self.dataloader_sampler = None

    optimizer = torch.optim.Adam(
        self.ft.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    # AMP gradient scaling is only enabled on CUDA devices.
    scaler = GradScaler(enabled=(self.device.type == 'cuda'))

    val_dataloader = None
    if has_val:
        val_dataset = TabularDataset(
            X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
        )
        val_dataloader = self._build_val_dataloader(
            val_dataset, dataloader, accum_steps)

    def forward_fn(batch):
        # One training forward pass; returns (prediction, target, weight).
        X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch

        # For DataParallel, inputs are automatically scattered; for DDP
        # (or single-device), move each tensor to the local device.
        if not isinstance(self.ft, nn.DataParallel):
            X_num_b = X_num_b.to(self.device, non_blocking=True)
            X_cat_b = X_cat_b.to(self.device, non_blocking=True)
            X_geo_b = X_geo_b.to(self.device, non_blocking=True)
            y_b = y_b.to(self.device, non_blocking=True)
            w_b = w_b.to(self.device, non_blocking=True)

        y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
        return y_pred, y_b, w_b

    def val_forward_fn():
        # Weighted mean validation loss over the whole val loader.
        total_loss = 0.0
        total_weight = 0.0
        for batch in val_dataloader:
            X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
            if not isinstance(self.ft, nn.DataParallel):
                X_num_b = X_num_b.to(self.device, non_blocking=True)
                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                X_geo_b = X_geo_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

            y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)

            # Manually compute validation loss.
            losses = self._compute_losses(
                y_pred, y_b, apply_softplus=False)

            # Clamp the weight sum so an all-zero-weight batch cannot
            # divide by zero.
            batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
            batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

            total_loss += batch_weighted_loss_sum.item()
            total_weight += batch_weight_sum.item()

        return total_loss / max(total_weight, EPS)

    # Gradient clipping requires unscaling AMP gradients first; only wired
    # up on CUDA where the scaler is active.
    clip_fn = None
    if self.device.type == 'cuda':
        def clip_fn(): return (scaler.unscale_(optimizer),
                               clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

    best_state, history = self._train_model(
        self.ft,
        dataloader,
        accum_steps,
        optimizer,
        scaler,
        forward_fn,
        val_forward_fn if has_val else None,
        apply_softplus=False,
        clip_fn=clip_fn,
        trial=trial,
        loss_curve_path=getattr(self, "loss_curve_path", None)
    )

    if has_val and best_state is not None:
        # Load state into unwrapped module to match how it was saved
        base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
        base_module.load_state_dict(best_state)
    self.training_history = history
|
|
543
|
+
|
|
544
|
+
def fit_unsupervised(self,
                     X_train,
                     X_val=None,
                     trial: Optional[optuna.trial.Trial] = None,
                     geo_train=None,
                     geo_val=None,
                     mask_prob_num: float = 0.15,
                     mask_prob_cat: float = 0.15,
                     num_loss_weight: float = 1.0,
                     cat_loss_weight: float = 1.0) -> float:
    """Self-supervised pretraining via masked reconstruction (supports raw string categories).

    Random entries of the numeric/categorical inputs are replaced (numeric:
    column mean; categorical: the last embedding index, presumably the
    reserved "unknown" slot — see ``unknown_idx`` below) and the model is
    trained to reconstruct the originals. Returns the best validation loss
    when a val split is given, otherwise the last training loss (0.0 if no
    epochs ran). Raises ``optuna.TrialPruned`` on pruning or non-finite
    loss when a ``trial`` is supplied, ``RuntimeError`` otherwise.
    """
    self.num_geo = geo_train.shape[1] if geo_train is not None else 0
    if self.ft is None:
        self._build_model(X_train)

    # Targets/weights are not used in pretraining; only features are tensorized.
    X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
        X_train, None, None, geo_tokens=geo_train, allow_none=True)
    has_val = X_val is not None
    if has_val:
        X_num_val, X_cat_val, X_geo_val, _, _, _ = self._tensorize_split(
            X_val, None, None, geo_tokens=geo_val, allow_none=True)
    else:
        X_num_val = X_cat_val = X_geo_val = None

    N = int(X_num.shape[0])
    num_dim = int(X_num.shape[1])
    cat_dim = int(X_cat.shape[1])
    device_type = self._device_type()

    # Per-rank seed so DDP workers draw different masks.
    gen = torch.Generator()
    gen.manual_seed(13 + int(getattr(self, "rank", 0)))

    base_model = self.ft.module if hasattr(self.ft, "module") else self.ft
    cardinals = getattr(base_model, "cat_cardinalities", None) or []
    # Mask fill for categorical columns: the highest valid index per column
    # (cardinality - 1) — assumed to be the reserved "unknown" category.
    unknown_idx = torch.tensor(
        [int(c) - 1 for c in cardinals], dtype=torch.long).view(1, -1)

    means = None
    if num_dim > 0:
        # Keep masked fill values on the same scale as model inputs (may be normalized in _tensorize_split).
        means = X_num.to(dtype=torch.float32).mean(dim=0, keepdim=True)

    def _mask_inputs(X_num_in: torch.Tensor,
                     X_cat_in: torch.Tensor,
                     generator: torch.Generator):
        # Draw Bernoulli masks and replace masked entries with their fill
        # values. Returns (masked_num, masked_cat, num_mask, cat_mask);
        # masks are None for absent feature groups.
        n_rows = int(X_num_in.shape[0])
        num_mask_local = None
        cat_mask_local = None
        X_num_masked_local = X_num_in
        X_cat_masked_local = X_cat_in
        if num_dim > 0:
            num_mask_local = (torch.rand(
                (n_rows, num_dim), generator=generator) < float(mask_prob_num))
            X_num_masked_local = X_num_in.clone()
            if num_mask_local.any():
                X_num_masked_local[num_mask_local] = means.expand_as(
                    X_num_masked_local)[num_mask_local]
        if cat_dim > 0:
            cat_mask_local = (torch.rand(
                (n_rows, cat_dim), generator=generator) < float(mask_prob_cat))
            X_cat_masked_local = X_cat_in.clone()
            if cat_mask_local.any():
                X_cat_masked_local[cat_mask_local] = unknown_idx.expand_as(
                    X_cat_masked_local)[cat_mask_local]
        return X_num_masked_local, X_cat_masked_local, num_mask_local, cat_mask_local

    X_num_true = X_num if num_dim > 0 else None
    X_cat_true = X_cat if cat_dim > 0 else None
    # NOTE(review): the training mask is drawn once here, so every epoch
    # reuses the same corruption pattern — confirm this is intentional.
    X_num_masked, X_cat_masked, num_mask, cat_mask = _mask_inputs(
        X_num, X_cat, gen)

    dataset = MaskedTabularDataset(
        X_num_masked, X_cat_masked, X_geo,
        X_num_true, num_mask,
        X_cat_true, cat_mask
    )
    dataloader, accum_steps = self._build_dataloader(
        dataset,
        N=N,
        base_bs_gpu=(2048, 1024, 512),
        base_bs_cpu=(256, 128),
        min_bs=64,
        target_effective_cuda=2048,
        target_effective_cpu=1024
    )
    # Keep the distributed sampler (if any) so each epoch can reshuffle.
    if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
        self.dataloader_sampler = dataloader.sampler
    else:
        self.dataloader_sampler = None

    optimizer = torch.optim.Adam(
        self.ft.parameters(),
        lr=self.learning_rate,
        weight_decay=float(getattr(self, "weight_decay", 0.0)),
    )
    # AMP gradient scaling only applies on CUDA.
    scaler = GradScaler(enabled=(device_type == 'cuda'))

    train_history: List[float] = []
    val_history: List[float] = []
    best_loss = float("inf")
    best_state = None
    patience_counter = 0
    is_ddp_model = isinstance(self.ft, DDP)
    use_collectives = dist.is_initialized() and is_ddp_model

    # Gradient clipping must unscale AMP gradients first; CUDA only.
    clip_fn = None
    if self.device.type == 'cuda':
        def clip_fn(): return (scaler.unscale_(optimizer),
                               clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

    for epoch in range(1, int(self.epochs) + 1):
        if self.dataloader_sampler is not None:
            self.dataloader_sampler.set_epoch(epoch)

        self.ft.train()
        optimizer.zero_grad()
        epoch_loss_sum = 0.0
        epoch_count = 0.0

        for step, batch in enumerate(dataloader):
            # Optimizer steps only every `accum_steps` batches (gradient
            # accumulation), plus once at the end of the epoch.
            is_update_step = ((step + 1) % accum_steps == 0) or \
                ((step + 1) == len(dataloader))
            # Skip DDP gradient all-reduce on non-update steps.
            sync_cm = self.ft.no_sync if (
                is_ddp_model and not is_update_step) else nullcontext
            with sync_cm():
                with autocast(enabled=(device_type == 'cuda')):
                    X_num_b, X_cat_b, X_geo_b, num_true_b, num_mask_b, cat_true_b, cat_mask_b = batch
                    X_num_b = X_num_b.to(self.device, non_blocking=True)
                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                    X_geo_b = X_geo_b.to(self.device, non_blocking=True)
                    num_true_b = None if num_true_b is None else num_true_b.to(
                        self.device, non_blocking=True)
                    num_mask_b = None if num_mask_b is None else num_mask_b.to(
                        self.device, non_blocking=True)
                    cat_true_b = None if cat_true_b is None else cat_true_b.to(
                        self.device, non_blocking=True)
                    cat_mask_b = None if cat_mask_b is None else cat_mask_b.to(
                        self.device, non_blocking=True)

                    num_pred, cat_logits = self.ft(
                        X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
                    batch_loss = _compute_reconstruction_loss(
                        num_pred, cat_logits, num_true_b, num_mask_b,
                        cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
                        device=X_num_b.device)
                # Detect non-finite loss on any rank; under DDP all ranks
                # must agree (all_reduce MAX) so they fail together instead
                # of deadlocking in a later collective.
                local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                global_bad = local_bad
                if use_collectives:
                    bad = torch.tensor(
                        [local_bad],
                        device=batch_loss.device,
                        dtype=torch.int32,
                    )
                    dist.all_reduce(bad, op=dist.ReduceOp.MAX)
                    global_bad = int(bad.item())

                if global_bad:
                    msg = (
                        f"[FTTransformerSklearn.fit_unsupervised] non-finite loss "
                        f"(epoch={epoch}, step={step}, loss={batch_loss.detach().item()})"
                    )
                    # Only the main process logs diagnostics to avoid spam.
                    should_log = (not dist.is_initialized()
                                  or DistributedUtils.is_main_process())
                    if should_log:
                        _log(msg, flush=True)
                        _log(
                            f"  X_num: finite={bool(torch.isfinite(X_num_b).all())} "
                            f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
                            f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
                            flush=True,
                        )
                        if X_geo_b is not None:
                            _log(
                                f"  X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
                                f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
                                f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
                                flush=True,
                            )
                    if trial is not None:
                        raise optuna.TrialPruned(msg)
                    raise RuntimeError(msg)
                # Scale the loss so accumulated gradients average correctly.
                loss_for_backward = batch_loss / float(accum_steps)
                scaler.scale(loss_for_backward).backward()

            if is_update_step:
                if clip_fn is not None:
                    clip_fn()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Row-weighted running mean of the (unscaled) batch loss.
            epoch_loss_sum += float(batch_loss.detach().item()) * \
                float(X_num_b.shape[0])
            epoch_count += float(X_num_b.shape[0])

        train_history.append(epoch_loss_sum / max(epoch_count, 1.0))

        if has_val and X_num_val is not None and X_cat_val is not None and X_geo_val is not None:
            # Validation runs on the main process only; the result is
            # broadcast to the other ranks afterwards.
            should_compute_val = (not dist.is_initialized()
                                  or DistributedUtils.is_main_process())
            loss_tensor_device = self.device if device_type == 'cuda' else torch.device(
                "cpu")
            val_loss_tensor = torch.zeros(1, device=loss_tensor_device)

            if should_compute_val:
                self.ft.eval()
                with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
                    # Validate at the effective (accumulated) batch size.
                    val_bs = min(
                        int(dataloader.batch_size * max(1, accum_steps)), int(X_num_val.shape[0]))
                    total_val = 0.0
                    total_n = 0.0
                    for start in range(0, int(X_num_val.shape[0]), max(1, val_bs)):
                        end = min(
                            int(X_num_val.shape[0]), start + max(1, val_bs))
                        X_num_v_true_cpu = X_num_val[start:end]
                        X_cat_v_true_cpu = X_cat_val[start:end]
                        X_geo_v = X_geo_val[start:end].to(
                            self.device, non_blocking=True)
                        # Deterministic per-(epoch, chunk) masking so the
                        # val loss is reproducible.
                        gen_val = torch.Generator()
                        gen_val.manual_seed(10_000 + epoch + start)
                        X_num_v_cpu, X_cat_v_cpu, val_num_mask, val_cat_mask = _mask_inputs(
                            X_num_v_true_cpu, X_cat_v_true_cpu, gen_val)
                        X_num_v_true = X_num_v_true_cpu.to(
                            self.device, non_blocking=True)
                        X_cat_v_true = X_cat_v_true_cpu.to(
                            self.device, non_blocking=True)
                        X_num_v = X_num_v_cpu.to(
                            self.device, non_blocking=True)
                        X_cat_v = X_cat_v_cpu.to(
                            self.device, non_blocking=True)
                        val_num_mask = None if val_num_mask is None else val_num_mask.to(
                            self.device, non_blocking=True)
                        val_cat_mask = None if val_cat_mask is None else val_cat_mask.to(
                            self.device, non_blocking=True)
                        num_pred_v, cat_logits_v = self.ft(
                            X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
                        loss_v = _compute_reconstruction_loss(
                            num_pred_v, cat_logits_v,
                            X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                            X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
                            num_loss_weight, cat_loss_weight,
                            device=X_num_v.device
                        )
                        if not torch.isfinite(loss_v):
                            # Poison the mean so the epoch is reported as
                            # non-finite and pruned below.
                            total_val = float("inf")
                            total_n = 1.0
                            break
                        total_val += float(loss_v.detach().item()
                                           ) * float(end - start)
                        total_n += float(end - start)
                    val_loss_tensor[0] = total_val / max(total_n, 1.0)

            if use_collectives:
                dist.broadcast(val_loss_tensor, src=0)
            val_loss_value = float(val_loss_tensor.item())
            prune_now = False
            prune_msg = None
            if not np.isfinite(val_loss_value):
                prune_now = True
                prune_msg = (
                    f"[FTTransformerSklearn.fit_unsupervised] non-finite val loss "
                    f"(epoch={epoch}, val_loss={val_loss_value})"
                )
            val_history.append(val_loss_value)

            if val_loss_value < best_loss:
                best_loss = val_loss_value
                # Efficiently clone state_dict - only clone tensor data, not DDP metadata
                base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
                best_state = {
                    k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
                    for k, v in base_module.state_dict().items()
                }
                patience_counter = 0
            else:
                patience_counter += 1
                # Early stop once patience is exhausted (and we have a
                # checkpoint to fall back to).
                if best_state is not None and patience_counter >= int(self.patience):
                    break

            # Optuna reporting/pruning happens on the main process; the
            # decision is then broadcast so all ranks prune together.
            if trial is not None and (not dist.is_initialized() or DistributedUtils.is_main_process()):
                trial.report(val_loss_value, epoch)
                if trial.should_prune():
                    prune_now = True

            if use_collectives:
                flag = torch.tensor(
                    [1 if prune_now else 0],
                    device=loss_tensor_device,
                    dtype=torch.int32,
                )
                dist.broadcast(flag, src=0)
                prune_now = bool(flag.item())

            if prune_now:
                if prune_msg:
                    raise optuna.TrialPruned(prune_msg)
                raise optuna.TrialPruned()

    self.training_history = {"train": train_history, "val": val_history}
    self._plot_loss_curve(self.training_history, getattr(
        self, "loss_curve_path", None))
    if has_val and best_state is not None:
        # Load state into unwrapped module to match how it was saved
        base_module = self.ft.module if hasattr(self.ft, "module") else self.ft
        base_module.load_state_dict(best_state)
    return float(best_loss if has_val else (train_history[-1] if train_history else 0.0))
|
|
850
|
+
|
|
851
|
+
def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None, return_embedding: bool = False):
    """Run batched inference and return predictions as a NumPy array.

    X_test must include all numeric/categorical columns; geo_tokens is
    optional. When ``batch_size`` is None a heuristic batch size is chosen
    from the model's token count and width. With ``return_embedding`` the
    raw model output is returned unchanged; otherwise classification
    outputs are sigmoid-transformed and regression outputs clipped to be
    positive, then flattened to 1d.
    """
    self.ft.eval()
    X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
        X_test, None, None, geo_tokens=geo_tokens, allow_none=True)

    num_rows = X_num.shape[0]
    if num_rows == 0:
        # Nothing to score.
        return np.empty(0, dtype=np.float32)

    # self.device may be stored as a string; normalize to torch.device.
    device = self.device if isinstance(
        self.device, torch.device) else torch.device(self.device)

    def resolve_batch_size(n_rows: int) -> int:
        # Explicit batch_size wins (clamped to [1, n_rows]).
        if batch_size is not None:
            return max(1, min(int(batch_size), n_rows))
        # Estimate a safe batch size based on model size to avoid attention OOM.
        token_cnt = self.num_numeric_tokens + len(self.cat_cols)
        if self.num_geo > 0:
            token_cnt += 1
        approx_units = max(1, token_cnt * max(1, self.d_model))
        if device.type == 'cuda':
            # Bigger models get smaller batches.
            if approx_units >= 8192:
                base = 512
            elif approx_units >= 4096:
                base = 1024
            else:
                base = 2048
        else:
            base = 512
        return max(1, min(base, n_rows))

    eff_batch = resolve_batch_size(num_rows)
    preds: List[torch.Tensor] = []

    # Prefer torch.inference_mode when available; fall back to no_grad.
    inference_cm = getattr(torch, "inference_mode", torch.no_grad)
    with inference_cm():
        for start in range(0, num_rows, eff_batch):
            end = min(num_rows, start + eff_batch)
            X_num_b = X_num[start:end].to(device, non_blocking=True)
            X_cat_b = X_cat[start:end].to(device, non_blocking=True)
            X_geo_b = X_geo[start:end].to(device, non_blocking=True)
            pred_chunk = self.ft(
                X_num_b, X_cat_b, X_geo_b, return_embedding=return_embedding)
            # Collect on CPU to keep GPU memory flat across chunks.
            preds.append(pred_chunk.cpu())

    y_pred = torch.cat(preds, dim=0).numpy()

    if return_embedding:
        return y_pred

    if self.task_type == 'classification':
        # Convert logits to probabilities.
        y_pred = 1 / (1 + np.exp(-y_pred))
    else:
        # Model already has softplus; optionally apply log-exp smoothing: y_pred = log(1 + exp(y_pred)).
        y_pred = np.clip(y_pred, 1e-6, None)
    return y_pred.ravel()
|
|
910
|
+
|
|
911
|
+
def set_params(self, params: dict):
    """Assign known hyper-parameters in sklearn style and return ``self``.

    Raises ``ValueError`` for any key that is not an existing attribute.
    Note: changing structural params (e.g. d_model/n_heads) requires a
    refit to take effect.
    """
    for key, value in params.items():
        if not hasattr(self, key):
            raise ValueError(f"Parameter {key} not found in model.")
        setattr(self, key, value)
    return self
|