ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +74 -56
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +52 -50
- ins_pricing/cli/BayesOpt_incremental.py +832 -898
- ins_pricing/cli/Explain_Run.py +31 -23
- ins_pricing/cli/Explain_entry.py +532 -579
- ins_pricing/cli/Pricing_Run.py +31 -23
- ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +375 -375
- ins_pricing/cli/utils/import_resolver.py +382 -365
- ins_pricing/cli/utils/notebook_utils.py +340 -340
- ins_pricing/cli/watchdog_run.py +209 -201
- ins_pricing/frontend/README.md +573 -419
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/config_builder.py +1 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +67 -0
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/bayesopt/README.md +59 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +116 -83
- ins_pricing/utils/device.py +255 -255
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
- ins_pricing-0.5.0.dist-info/RECORD +131 -0
- ins_pricing/CHANGELOG.md +0 -272
- ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
- ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
- ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
- ins_pricing/docs/modelling/README.md +0 -34
- ins_pricing/frontend/QUICKSTART.md +0 -152
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.4.dist-info/RECORD +0 -137
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
- /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
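
Most of the churn above is a layout flattening: everything under ins_pricing/modelling/core/bayesopt moves to ins_pricing/modelling/bayesopt, shared utilities consolidate under ins_pricing/utils, and production/predict.py becomes production/inference.py. A minimal migration sketch for downstream imports follows; the exact public symbols are assumptions inferred from the renamed paths in this listing, not from package documentation.

# Hypothetical import migration for 0.4.4 -> 0.5.0; symbol names are
# illustrative, inferred only from the file moves shown above.

# 0.4.4 nested layout (removed in this release):
#   from ins_pricing.modelling.core.bayesopt.core import BayesOptModel
#   from ins_pricing.production import predict

# 0.5.0 flattened layout:
from ins_pricing.modelling.bayesopt.core import BayesOptModel
from ins_pricing.production import inference

The hunk below (matching the +965/-962 entry for core.py) shows the same shift inside the package itself: relative .utils imports are replaced by absolute ins_pricing.utils imports.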
ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py
@@ -1,962 +1,965 @@
-from __future__ import annotations
-
-from dataclasses import asdict
-from datetime import datetime
-import os
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-import numpy as np
-import pandas as pd
-import torch
-from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit
-from sklearn.preprocessing import StandardScaler
-
-from .config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
-from .model_explain_mixin import BayesOptExplainMixin
-from .model_plotting_mixin import BayesOptPlottingMixin
-from .models import GraphNeuralNetSklearn
-from .trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
-from .utils import EPS, infer_factor_and_cate_list, set_global_seed
-from .utils.
[removed lines 21-962 were truncated to fragments in the rendered diff and are not recoverable]
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import asdict
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit
|
|
12
|
+
from sklearn.preprocessing import StandardScaler
|
|
13
|
+
|
|
14
|
+
from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
|
|
15
|
+
from ins_pricing.modelling.bayesopt.model_explain_mixin import BayesOptExplainMixin
|
|
16
|
+
from ins_pricing.modelling.bayesopt.model_plotting_mixin import BayesOptPlottingMixin
|
|
17
|
+
from ins_pricing.modelling.bayesopt.models import GraphNeuralNetSklearn
|
|
18
|
+
from ins_pricing.modelling.bayesopt.trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
|
|
19
|
+
from ins_pricing.utils import EPS, infer_factor_and_cate_list, set_global_seed
|
|
20
|
+
from ins_pricing.utils.io import IOUtils
|
|
21
|
+
from ins_pricing.utils.losses import (
|
|
22
|
+
infer_loss_name_from_model_name,
|
|
23
|
+
normalize_loss_name,
|
|
24
|
+
resolve_tweedie_power,
|
|
25
|
+
resolve_xgb_objective,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _CVSplitter:
|
|
30
|
+
"""Wrapper to carry optional groups or time order for CV splits."""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
splitter,
|
|
35
|
+
*,
|
|
36
|
+
groups: Optional[pd.Series] = None,
|
|
37
|
+
order: Optional[np.ndarray] = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
self._splitter = splitter
|
|
40
|
+
self._groups = groups
|
|
41
|
+
self._order = order
|
|
42
|
+
|
|
43
|
+
def split(self, X, y=None, groups=None):
|
|
44
|
+
if self._order is not None:
|
|
45
|
+
order = np.asarray(self._order)
|
|
46
|
+
X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
|
|
47
|
+
for tr_idx, val_idx in self._splitter.split(X_ord, y=y):
|
|
48
|
+
yield order[tr_idx], order[val_idx]
|
|
49
|
+
return
|
|
50
|
+
use_groups = groups if groups is not None else self._groups
|
|
51
|
+
for tr_idx, val_idx in self._splitter.split(X, y=y, groups=use_groups):
|
|
52
|
+
yield tr_idx, val_idx
|
|
53
|
+
|
|
54
|
+
# BayesOpt orchestration and SHAP utilities
|
|
55
|
+
# =============================================================================
|
|
56
|
+
class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
|
|
57
|
+
def __init__(self, train_data, test_data,
|
|
58
|
+
config: Optional[BayesOptConfig] = None,
|
|
59
|
+
# Backward compatibility: individual parameters (DEPRECATED)
|
|
60
|
+
model_nme=None, resp_nme=None, weight_nme=None,
|
|
61
|
+
factor_nmes: Optional[List[str]] = None, task_type='regression',
|
|
62
|
+
binary_resp_nme=None,
|
|
63
|
+
cate_list=None, prop_test=0.25, rand_seed=None,
|
|
64
|
+
epochs=100, use_gpu=True,
|
|
65
|
+
use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
|
|
66
|
+
use_gnn_data_parallel: bool = False,
|
|
67
|
+
use_resn_ddp: bool = False, use_ft_ddp: bool = False,
|
|
68
|
+
use_gnn_ddp: bool = False,
|
|
69
|
+
output_dir: Optional[str] = None,
|
|
70
|
+
gnn_use_approx_knn: bool = True,
|
|
71
|
+
gnn_approx_knn_threshold: int = 50000,
|
|
72
|
+
gnn_graph_cache: Optional[str] = None,
|
|
73
|
+
gnn_max_gpu_knn_nodes: Optional[int] = 200000,
|
|
74
|
+
gnn_knn_gpu_mem_ratio: float = 0.9,
|
|
75
|
+
gnn_knn_gpu_mem_overhead: float = 2.0,
|
|
76
|
+
ft_role: str = "model",
|
|
77
|
+
ft_feature_prefix: str = "ft_emb",
|
|
78
|
+
ft_num_numeric_tokens: Optional[int] = None,
|
|
79
|
+
infer_categorical_max_unique: int = 50,
|
|
80
|
+
infer_categorical_max_ratio: float = 0.05,
|
|
81
|
+
reuse_best_params: bool = False,
|
|
82
|
+
xgb_max_depth_max: int = 25,
|
|
83
|
+
xgb_n_estimators_max: int = 500,
|
|
84
|
+
resn_weight_decay: Optional[float] = None,
|
|
85
|
+
final_ensemble: bool = False,
|
|
86
|
+
final_ensemble_k: int = 3,
|
|
87
|
+
final_refit: bool = True,
|
|
88
|
+
optuna_storage: Optional[str] = None,
|
|
89
|
+
optuna_study_prefix: Optional[str] = None,
|
|
90
|
+
best_params_files: Optional[Dict[str, str]] = None,
|
|
91
|
+
cv_strategy: Optional[str] = None,
|
|
92
|
+
cv_splits: Optional[int] = None,
|
|
93
|
+
cv_group_col: Optional[str] = None,
|
|
94
|
+
cv_time_col: Optional[str] = None,
|
|
95
|
+
cv_time_ascending: bool = True,
|
|
96
|
+
ft_oof_folds: Optional[int] = None,
|
|
97
|
+
ft_oof_strategy: Optional[str] = None,
|
|
98
|
+
ft_oof_shuffle: bool = True,
|
|
99
|
+
save_preprocess: bool = False,
|
|
100
|
+
preprocess_artifact_path: Optional[str] = None,
|
|
101
|
+
plot_path_style: Optional[str] = None,
|
|
102
|
+
bo_sample_limit: Optional[int] = None,
|
|
103
|
+
cache_predictions: bool = False,
|
|
104
|
+
prediction_cache_dir: Optional[str] = None,
|
|
105
|
+
prediction_cache_format: Optional[str] = None,
|
|
106
|
+
region_province_col: Optional[str] = None,
|
|
107
|
+
region_city_col: Optional[str] = None,
|
|
108
|
+
region_effect_alpha: Optional[float] = None,
|
|
109
|
+
geo_feature_nmes: Optional[List[str]] = None,
|
|
110
|
+
geo_token_hidden_dim: Optional[int] = None,
|
|
111
|
+
geo_token_layers: Optional[int] = None,
|
|
112
|
+
geo_token_dropout: Optional[float] = None,
|
|
113
|
+
geo_token_k_neighbors: Optional[int] = None,
|
|
114
|
+
geo_token_learning_rate: Optional[float] = None,
|
|
115
|
+
geo_token_epochs: Optional[int] = None):
|
|
116
|
+
"""Orchestrate BayesOpt training across multiple trainers.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
train_data: Training DataFrame.
|
|
120
|
+
test_data: Test DataFrame.
|
|
121
|
+
config: BayesOptConfig instance with all configuration (RECOMMENDED).
|
|
122
|
+
If provided, all other parameters are ignored.
|
|
123
|
+
|
|
124
|
+
# DEPRECATED: Individual parameters (use config instead)
|
|
125
|
+
model_nme: Model name prefix used in outputs.
|
|
126
|
+
resp_nme: Target column name.
|
|
127
|
+
weight_nme: Sample weight column name.
|
|
128
|
+
factor_nmes: Feature column list.
|
|
129
|
+
task_type: "regression" or "classification".
|
|
130
|
+
binary_resp_nme: Optional binary target for lift curves.
|
|
131
|
+
cate_list: Categorical feature list.
|
|
132
|
+
prop_test: Validation split ratio in CV.
|
|
133
|
+
rand_seed: Random seed.
|
|
134
|
+
epochs: NN training epochs.
|
|
135
|
+
use_gpu: Prefer GPU when available.
|
|
136
|
+
use_resn_data_parallel: Enable DataParallel for ResNet.
|
|
137
|
+
use_ft_data_parallel: Enable DataParallel for FTTransformer.
|
|
138
|
+
use_gnn_data_parallel: Enable DataParallel for GNN.
|
|
139
|
+
use_resn_ddp: Enable DDP for ResNet.
|
|
140
|
+
use_ft_ddp: Enable DDP for FTTransformer.
|
|
141
|
+
use_gnn_ddp: Enable DDP for GNN.
|
|
142
|
+
output_dir: Output root for models/results/plots.
|
|
143
|
+
gnn_use_approx_knn: Use approximate kNN when available.
|
|
144
|
+
gnn_approx_knn_threshold: Row threshold to switch to approximate kNN.
|
|
145
|
+
gnn_graph_cache: Optional adjacency cache path.
|
|
146
|
+
gnn_max_gpu_knn_nodes: Force CPU kNN above this node count to avoid OOM.
|
|
147
|
+
gnn_knn_gpu_mem_ratio: Fraction of free GPU memory for kNN.
|
|
148
|
+
gnn_knn_gpu_mem_overhead: Temporary memory multiplier for GPU kNN.
|
|
149
|
+
ft_num_numeric_tokens: Number of numeric tokens for FT (None = auto).
|
|
150
|
+
final_ensemble: Enable k-fold model averaging at the final stage.
|
|
151
|
+
final_ensemble_k: Number of folds for averaging.
|
|
152
|
+
final_refit: Refit on full data using best stopping point.
|
|
153
|
+
|
|
154
|
+
Examples:
|
|
155
|
+
# New style (recommended):
|
|
156
|
+
config = BayesOptConfig(
|
|
157
|
+
model_nme="my_model",
|
|
158
|
+
resp_nme="target",
|
|
159
|
+
weight_nme="weight",
|
|
160
|
+
factor_nmes=["feat1", "feat2"]
|
|
161
|
+
)
|
|
162
|
+
model = BayesOptModel(train_df, test_df, config=config)
|
|
163
|
+
|
|
164
|
+
# Old style (deprecated, for backward compatibility):
|
|
165
|
+
model = BayesOptModel(
|
|
166
|
+
train_df, test_df,
|
|
167
|
+
model_nme="my_model",
|
|
168
|
+
resp_nme="target",
|
|
169
|
+
weight_nme="weight",
|
|
170
|
+
factor_nmes=["feat1", "feat2"]
|
|
171
|
+
)
|
|
172
|
+
"""
|
|
173
|
+
# Detect which API is being used
|
|
174
|
+
if config is not None:
|
|
175
|
+
# New API: config object provided
|
|
176
|
+
if isinstance(config, BayesOptConfig):
|
|
177
|
+
cfg = config
|
|
178
|
+
else:
|
|
179
|
+
raise TypeError(
|
|
180
|
+
f"config must be a BayesOptConfig instance, got {type(config).__name__}"
|
|
181
|
+
)
|
|
182
|
+
else:
|
|
183
|
+
# Old API: individual parameters (backward compatibility)
|
|
184
|
+
# Show deprecation warning
|
|
185
|
+
import warnings
|
|
186
|
+
warnings.warn(
|
|
187
|
+
"Passing individual parameters to BayesOptModel.__init__ is deprecated. "
|
|
188
|
+
"Use the 'config' parameter with a BayesOptConfig instance instead:\n"
|
|
189
|
+
" config = BayesOptConfig(model_nme=..., resp_nme=..., ...)\n"
|
|
190
|
+
" model = BayesOptModel(train_data, test_data, config=config)\n"
|
|
191
|
+
"Individual parameters will be removed in v0.4.0.",
|
|
192
|
+
DeprecationWarning,
|
|
193
|
+
stacklevel=2
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
# Validate required parameters
|
|
197
|
+
if model_nme is None:
|
|
198
|
+
raise ValueError("model_nme is required when not using config parameter")
|
|
199
|
+
if resp_nme is None:
|
|
200
|
+
raise ValueError("resp_nme is required when not using config parameter")
|
|
201
|
+
if weight_nme is None:
|
|
202
|
+
raise ValueError("weight_nme is required when not using config parameter")
|
|
203
|
+
|
|
204
|
+
# Infer categorical features if needed
|
|
205
|
+
# Only use user-specified categorical list for one-hot; do not auto-infer.
|
|
206
|
+
user_cate_list = [] if cate_list is None else list(cate_list)
|
|
207
|
+
inferred_factors, inferred_cats = infer_factor_and_cate_list(
|
|
208
|
+
train_df=train_data,
|
|
209
|
+
test_df=test_data,
|
|
210
|
+
resp_nme=resp_nme,
|
|
211
|
+
weight_nme=weight_nme,
|
|
212
|
+
binary_resp_nme=binary_resp_nme,
|
|
213
|
+
factor_nmes=factor_nmes,
|
|
214
|
+
cate_list=user_cate_list,
|
|
215
|
+
infer_categorical_max_unique=int(infer_categorical_max_unique),
|
|
216
|
+
infer_categorical_max_ratio=float(infer_categorical_max_ratio),
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
# Construct config from individual parameters
|
|
220
|
+
cfg = BayesOptConfig(
|
|
221
|
+
model_nme=model_nme,
|
|
222
|
+
task_type=task_type,
|
|
223
|
+
resp_nme=resp_nme,
|
|
224
|
+
weight_nme=weight_nme,
|
|
225
|
+
factor_nmes=list(inferred_factors),
|
|
226
|
+
binary_resp_nme=binary_resp_nme,
|
|
227
|
+
cate_list=list(inferred_cats) if inferred_cats else None,
|
|
228
|
+
prop_test=prop_test,
|
|
229
|
+
rand_seed=rand_seed,
|
|
230
|
+
epochs=epochs,
|
|
231
|
+
use_gpu=use_gpu,
|
|
232
|
+
xgb_max_depth_max=int(xgb_max_depth_max),
|
|
233
|
+
xgb_n_estimators_max=int(xgb_n_estimators_max),
|
|
234
|
+
use_resn_data_parallel=use_resn_data_parallel,
|
|
235
|
+
use_ft_data_parallel=use_ft_data_parallel,
|
|
236
|
+
use_resn_ddp=use_resn_ddp,
|
|
237
|
+
use_gnn_data_parallel=use_gnn_data_parallel,
|
|
238
|
+
use_ft_ddp=use_ft_ddp,
|
|
239
|
+
use_gnn_ddp=use_gnn_ddp,
|
|
240
|
+
gnn_use_approx_knn=gnn_use_approx_knn,
|
|
241
|
+
gnn_approx_knn_threshold=gnn_approx_knn_threshold,
|
|
242
|
+
gnn_graph_cache=gnn_graph_cache,
|
|
243
|
+
gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
|
|
244
|
+
gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
|
|
245
|
+
gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
|
|
246
|
+
output_dir=output_dir,
|
|
247
|
+
optuna_storage=optuna_storage,
|
|
248
|
+
optuna_study_prefix=optuna_study_prefix,
|
|
249
|
+
best_params_files=best_params_files,
|
|
250
|
+
ft_role=str(ft_role or "model"),
|
|
251
|
+
ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
|
|
252
|
+
ft_num_numeric_tokens=ft_num_numeric_tokens,
|
|
253
|
+
reuse_best_params=bool(reuse_best_params),
|
|
254
|
+
resn_weight_decay=float(resn_weight_decay)
|
|
255
|
+
if resn_weight_decay is not None
|
|
256
|
+
else 1e-4,
|
|
257
|
+
final_ensemble=bool(final_ensemble),
|
|
258
|
+
final_ensemble_k=int(final_ensemble_k),
|
|
259
|
+
final_refit=bool(final_refit),
|
|
260
|
+
cv_strategy=str(cv_strategy or "random"),
|
|
261
|
+
cv_splits=cv_splits,
|
|
262
|
+
cv_group_col=cv_group_col,
|
|
263
|
+
cv_time_col=cv_time_col,
|
|
264
|
+
cv_time_ascending=bool(cv_time_ascending),
|
|
265
|
+
ft_oof_folds=ft_oof_folds,
|
|
266
|
+
ft_oof_strategy=ft_oof_strategy,
|
|
267
|
+
ft_oof_shuffle=bool(ft_oof_shuffle),
|
|
268
|
+
save_preprocess=bool(save_preprocess),
|
|
269
|
+
preprocess_artifact_path=preprocess_artifact_path,
|
|
270
|
+
plot_path_style=str(plot_path_style or "nested"),
|
|
271
|
+
bo_sample_limit=bo_sample_limit,
|
|
272
|
+
cache_predictions=bool(cache_predictions),
|
|
273
|
+
prediction_cache_dir=prediction_cache_dir,
|
|
274
|
+
prediction_cache_format=str(prediction_cache_format or "parquet"),
|
|
275
|
+
region_province_col=region_province_col,
|
|
276
|
+
region_city_col=region_city_col,
|
|
277
|
+
region_effect_alpha=float(region_effect_alpha)
|
|
278
|
+
if region_effect_alpha is not None
|
|
279
|
+
else 50.0,
|
|
280
|
+
geo_feature_nmes=list(geo_feature_nmes)
|
|
281
|
+
if geo_feature_nmes is not None
|
|
282
|
+
else None,
|
|
283
|
+
geo_token_hidden_dim=int(geo_token_hidden_dim)
|
|
284
|
+
if geo_token_hidden_dim is not None
|
|
285
|
+
else 32,
|
|
286
|
+
geo_token_layers=int(geo_token_layers)
|
|
287
|
+
if geo_token_layers is not None
|
|
288
|
+
else 2,
|
|
289
|
+
geo_token_dropout=float(geo_token_dropout)
|
|
290
|
+
if geo_token_dropout is not None
|
|
291
|
+
else 0.1,
|
|
292
|
+
geo_token_k_neighbors=int(geo_token_k_neighbors)
|
|
293
|
+
if geo_token_k_neighbors is not None
|
|
294
|
+
else 10,
|
|
295
|
+
geo_token_learning_rate=float(geo_token_learning_rate)
|
|
296
|
+
if geo_token_learning_rate is not None
|
|
297
|
+
else 1e-3,
|
|
298
|
+
geo_token_epochs=int(geo_token_epochs)
|
|
299
|
+
if geo_token_epochs is not None
|
|
300
|
+
else 50,
|
|
301
|
+
)
|
|
302
|
+
self.config = cfg
|
|
303
|
+
self.model_nme = cfg.model_nme
|
|
304
|
+
self.task_type = cfg.task_type
|
|
305
|
+
normalized_loss = normalize_loss_name(getattr(cfg, "loss_name", None), self.task_type)
|
|
306
|
+
if self.task_type == "classification":
|
|
307
|
+
self.loss_name = "logloss" if normalized_loss == "auto" else normalized_loss
|
|
308
|
+
else:
|
|
309
|
+
if normalized_loss == "auto":
|
|
310
|
+
self.loss_name = infer_loss_name_from_model_name(self.model_nme)
|
|
311
|
+
else:
|
|
312
|
+
self.loss_name = normalized_loss
|
|
313
|
+
self.resp_nme = cfg.resp_nme
|
|
314
|
+
self.weight_nme = cfg.weight_nme
|
|
315
|
+
self.factor_nmes = cfg.factor_nmes
|
|
316
|
+
self.binary_resp_nme = cfg.binary_resp_nme
|
|
317
|
+
self.cate_list = list(cfg.cate_list or [])
|
|
318
|
+
self.prop_test = cfg.prop_test
|
|
319
|
+
self.epochs = cfg.epochs
|
|
320
|
+
self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
|
|
321
|
+
1, 10000)
|
|
322
|
+
set_global_seed(int(self.rand_seed))
|
|
323
|
+
self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
|
|
324
|
+
self.output_manager = OutputManager(
|
|
325
|
+
cfg.output_dir or os.getcwd(), self.model_nme)
|
|
326
|
+
|
|
327
|
+
preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
|
|
328
|
+
self.train_data = preprocessor.train_data
|
|
329
|
+
self.test_data = preprocessor.test_data
|
|
330
|
+
self.train_oht_data = preprocessor.train_oht_data
|
|
331
|
+
self.test_oht_data = preprocessor.test_oht_data
|
|
332
|
+
self.train_oht_scl_data = preprocessor.train_oht_scl_data
|
|
333
|
+
self.test_oht_scl_data = preprocessor.test_oht_scl_data
|
|
334
|
+
self.var_nmes = preprocessor.var_nmes
|
|
335
|
+
self.num_features = preprocessor.num_features
|
|
336
|
+
self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
|
|
337
|
+
self.numeric_scalers = preprocessor.numeric_scalers
|
|
338
|
+
if getattr(self.config, "save_preprocess", False):
|
|
339
|
+
artifact_path = getattr(self.config, "preprocess_artifact_path", None)
|
|
340
|
+
if artifact_path:
|
|
341
|
+
target = Path(str(artifact_path))
|
|
342
|
+
if not target.is_absolute():
|
|
343
|
+
target = Path(self.output_manager.result_dir) / target
|
|
344
|
+
else:
|
|
345
|
+
target = Path(self.output_manager.result_path(
|
|
346
|
+
f"{self.model_nme}_preprocess.json"
|
|
347
|
+
))
|
|
348
|
+
preprocessor.save_artifacts(target)
|
|
349
|
+
self.geo_token_cols: List[str] = []
|
|
350
|
+
self.train_geo_tokens: Optional[pd.DataFrame] = None
|
|
351
|
+
self.test_geo_tokens: Optional[pd.DataFrame] = None
|
|
352
|
+
self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
|
|
353
|
+
self._add_region_effect()
|
|
354
|
+
|
|
355
|
+
self.cv = self._build_cv_splitter()
|
|
356
|
+
if self.task_type == 'classification':
|
|
357
|
+
self.obj = 'binary:logistic'
|
|
358
|
+
else: # regression task
|
|
359
|
+
self.obj = resolve_xgb_objective(self.loss_name)
|
|
360
|
+
self.fit_params = {
|
|
361
|
+
'sample_weight': self.train_data[self.weight_nme].values
|
|
362
|
+
}
|
|
363
|
+
self.model_label: List[str] = []
|
|
364
|
+
self.optuna_storage = cfg.optuna_storage
|
|
365
|
+
self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
|
|
366
|
+
|
|
367
|
+
# Keep trainers in a dict for unified access and easy extension.
|
|
368
|
+
self.trainers: Dict[str, TrainerBase] = {
|
|
369
|
+
'glm': GLMTrainer(self),
|
|
370
|
+
'xgb': XGBTrainer(self),
|
|
371
|
+
'resn': ResNetTrainer(self),
|
|
372
|
+
'ft': FTTrainer(self),
|
|
373
|
+
'gnn': GNNTrainer(self),
|
|
374
|
+
}
|
|
375
|
+
self._prepare_geo_tokens()
|
|
376
|
+
self.xgb_best = None
|
|
377
|
+
self.resn_best = None
|
|
378
|
+
self.gnn_best = None
|
|
379
|
+
self.glm_best = None
|
|
380
|
+
self.ft_best = None
|
|
381
|
+
self.best_xgb_params = None
|
|
382
|
+
self.best_resn_params = None
|
|
383
|
+
self.best_gnn_params = None
|
|
384
|
+
self.best_ft_params = None
|
|
385
|
+
self.best_xgb_trial = None
|
|
386
|
+
self.best_resn_trial = None
|
|
387
|
+
self.best_gnn_trial = None
|
|
388
|
+
self.best_ft_trial = None
|
|
389
|
+
self.best_glm_params = None
|
|
390
|
+
self.best_glm_trial = None
|
|
391
|
+
self.xgb_load = None
|
|
392
|
+
self.resn_load = None
|
|
393
|
+
self.gnn_load = None
|
|
394
|
+
self.ft_load = None
|
|
395
|
+
self.version_manager = VersionManager(self.output_manager)
|
|
396
|
+
|
|
397
|
+
def _build_cv_splitter(self) -> _CVSplitter:
|
|
398
|
+
strategy = str(getattr(self.config, "cv_strategy", "random") or "random").strip().lower()
|
|
399
|
+
val_ratio = float(self.prop_test) if self.prop_test is not None else 0.25
|
|
400
|
+
if not (0.0 < val_ratio < 1.0):
|
|
401
|
+
val_ratio = 0.25
|
|
402
|
+
cv_splits = getattr(self.config, "cv_splits", None)
|
|
403
|
+
if cv_splits is None:
|
|
404
|
+
cv_splits = max(2, int(round(1 / val_ratio)))
|
|
405
|
+
cv_splits = max(2, int(cv_splits))
|
|
406
|
+
|
|
407
|
+
if strategy in {"group", "grouped"}:
|
|
408
|
+
group_col = getattr(self.config, "cv_group_col", None)
|
|
409
|
+
if not group_col:
|
|
410
|
+
raise ValueError("cv_group_col is required for group cv_strategy.")
|
|
411
|
+
if group_col not in self.train_data.columns:
|
|
412
|
+
raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
|
|
413
|
+
groups = self.train_data[group_col]
|
|
414
|
+
splitter = GroupKFold(n_splits=cv_splits)
|
|
415
|
+
return _CVSplitter(splitter, groups=groups)
|
|
416
|
+
|
|
417
|
+
if strategy in {"time", "timeseries", "temporal"}:
|
|
418
|
+
time_col = getattr(self.config, "cv_time_col", None)
|
|
419
|
+
if not time_col:
|
|
420
|
+
raise ValueError("cv_time_col is required for time cv_strategy.")
|
|
421
|
+
if time_col not in self.train_data.columns:
|
|
422
|
+
raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
|
|
423
|
+
ascending = bool(getattr(self.config, "cv_time_ascending", True))
|
|
424
|
+
order_index = self.train_data[time_col].sort_values(ascending=ascending).index
|
|
425
|
+
order = self.train_data.index.get_indexer(order_index)
|
|
426
|
+
splitter = TimeSeriesSplit(n_splits=cv_splits)
|
|
427
|
+
return _CVSplitter(splitter, order=order)
|
|
428
|
+
|
|
429
|
+
splitter = ShuffleSplit(
|
|
430
|
+
n_splits=cv_splits,
|
|
431
|
+
test_size=val_ratio,
|
|
432
|
+
random_state=self.rand_seed,
|
|
433
|
+
)
|
|
434
|
+
return _CVSplitter(splitter)
|
|
435
|
+
|
|
436
|
+
def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
|
|
437
|
+
if self.task_type == 'classification':
|
|
438
|
+
return None
|
|
439
|
+
loss_name = getattr(self, "loss_name", None)
|
|
440
|
+
if loss_name:
|
|
441
|
+
resolved = resolve_tweedie_power(str(loss_name), default=1.5)
|
|
442
|
+
if resolved is not None:
|
|
443
|
+
return resolved
|
|
444
|
+
objective = obj or getattr(self, "obj", None)
|
|
445
|
+
if objective == 'count:poisson':
|
|
446
|
+
return 1.0
|
|
447
|
+
if objective == 'reg:gamma':
|
|
448
|
+
return 2.0
|
|
449
|
+
return 1.5
|
|
450
|
+
|
|
451
|
+
def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
|
|
452
|
+
"""Internal builder; allows trial overrides and returns None on failure."""
|
|
453
|
+
geo_cols = list(self.config.geo_feature_nmes or [])
|
|
454
|
+
if not geo_cols:
|
|
455
|
+
return None
|
|
456
|
+
|
|
457
|
+
available = [c for c in geo_cols if c in self.train_data.columns]
|
|
458
|
+
if not available:
|
|
459
|
+
return None
|
|
460
|
+
|
|
461
|
+
# Preprocess text/numeric: fill numeric with median, label-encode text, map unknowns.
|
|
462
|
+
proc_train = {}
|
|
463
|
+
proc_test = {}
|
|
464
|
+
for col in available:
|
|
465
|
+
s_train = self.train_data[col]
|
|
466
|
+
s_test = self.test_data[col]
|
|
467
|
+
if pd.api.types.is_numeric_dtype(s_train):
|
|
468
|
+
tr = pd.to_numeric(s_train, errors="coerce")
|
|
469
|
+
te = pd.to_numeric(s_test, errors="coerce")
|
|
470
|
+
med = np.nanmedian(tr)
|
|
471
|
+
proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
|
|
472
|
+
proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
|
|
473
|
+
else:
|
|
474
|
+
cats = pd.Categorical(s_train.astype(str))
|
|
475
|
+
tr_codes = cats.codes.astype(np.float32, copy=True)
|
|
476
|
+
tr_codes[tr_codes < 0] = len(cats.categories)
|
|
477
|
+
te_cats = pd.Categorical(
|
|
478
|
+
s_test.astype(str), categories=cats.categories)
|
|
479
|
+
te_codes = te_cats.codes.astype(np.float32, copy=True)
|
|
480
|
+
te_codes[te_codes < 0] = len(cats.categories)
|
|
481
|
+
proc_train[col] = tr_codes
|
|
482
|
+
proc_test[col] = te_codes
|
|
483
|
+
|
|
484
|
+
train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
|
|
485
|
+
test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
|
|
486
|
+
|
|
487
|
+
scaler = StandardScaler()
|
|
488
|
+
train_geo = pd.DataFrame(
|
|
489
|
+
scaler.fit_transform(train_geo_raw),
|
|
490
|
+
columns=available,
|
|
491
|
+
index=self.train_data.index
|
|
492
|
+
)
|
|
493
|
+
test_geo = pd.DataFrame(
|
|
494
|
+
scaler.transform(test_geo_raw),
|
|
495
|
+
columns=available,
|
|
496
|
+
index=self.test_data.index
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
tw_power = self.default_tweedie_power()
|
|
500
|
+
|
|
501
|
+
cfg = params_override or {}
|
|
502
|
+
try:
|
|
503
|
+
geo_gnn = GraphNeuralNetSklearn(
|
|
504
|
+
model_nme=f"{self.model_nme}_geo",
|
|
505
|
+
input_dim=len(available),
|
|
506
|
+
hidden_dim=cfg.get("geo_token_hidden_dim",
|
|
507
|
+
self.config.geo_token_hidden_dim),
|
|
508
|
+
num_layers=cfg.get("geo_token_layers",
|
|
509
|
+
self.config.geo_token_layers),
|
|
510
|
+
k_neighbors=cfg.get("geo_token_k_neighbors",
|
|
511
|
+
self.config.geo_token_k_neighbors),
|
|
512
|
+
dropout=cfg.get("geo_token_dropout",
|
|
513
|
+
self.config.geo_token_dropout),
|
|
514
|
+
learning_rate=cfg.get(
|
|
515
|
+
"geo_token_learning_rate", self.config.geo_token_learning_rate),
|
|
516
|
+
epochs=int(cfg.get("geo_token_epochs",
|
|
517
|
+
self.config.geo_token_epochs)),
|
|
518
|
+
patience=5,
|
|
519
|
+
task_type=self.task_type,
|
|
520
|
+
tweedie_power=tw_power,
|
|
521
|
+
loss_name=self.loss_name,
|
|
522
|
+
use_data_parallel=False,
|
|
523
|
+
use_ddp=False,
|
|
524
|
+
use_approx_knn=self.config.gnn_use_approx_knn,
|
|
525
|
+
approx_knn_threshold=self.config.gnn_approx_knn_threshold,
|
|
526
|
+
graph_cache_path=None,
|
|
527
|
+
max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
|
|
528
|
+
knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
|
|
529
|
+
knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
|
|
530
|
+
)
|
|
531
|
+
geo_gnn.fit(
|
|
532
|
+
train_geo,
|
|
533
|
+
self.train_data[self.resp_nme],
|
|
534
|
+
self.train_data[self.weight_nme]
|
|
535
|
+
)
|
|
536
|
+
train_embed = geo_gnn.encode(train_geo)
|
|
537
|
+
test_embed = geo_gnn.encode(test_geo)
|
|
538
|
+
cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
|
|
539
|
+
train_tokens = pd.DataFrame(
|
|
540
|
+
train_embed, index=self.train_data.index, columns=cols)
|
|
541
|
+
test_tokens = pd.DataFrame(
|
|
542
|
+
test_embed, index=self.test_data.index, columns=cols)
|
|
543
|
+
return train_tokens, test_tokens, cols, geo_gnn
|
|
544
|
+
except Exception as exc:
|
|
545
|
+
print(f"[GeoToken] Generation failed: {exc}")
|
|
546
|
+
return None
|
|
547
|
+
|
|
548
|
+
def _prepare_geo_tokens(self) -> None:
|
|
549
|
+
"""Build and persist geo tokens with default config values."""
|
|
550
|
+
gnn_trainer = self.trainers.get("gnn")
|
|
551
|
+
if gnn_trainer is not None and hasattr(gnn_trainer, "prepare_geo_tokens"):
|
|
552
|
+
try:
|
|
553
|
+
gnn_trainer.prepare_geo_tokens(force=False) # type: ignore[attr-defined]
|
|
554
|
+
return
|
|
555
|
+
except Exception as exc:
|
|
556
|
+
print(f"[GeoToken] GNNTrainer generation failed: {exc}")
|
|
557
|
+
|
|
558
|
+
result = self._build_geo_tokens()
|
|
559
|
+
if result is None:
|
|
560
|
+
return
|
|
561
|
+
train_tokens, test_tokens, cols, geo_gnn = result
|
|
562
|
+
self.train_geo_tokens = train_tokens
|
|
563
|
+
self.test_geo_tokens = test_tokens
|
|
564
|
+
self.geo_token_cols = cols
|
|
565
|
+
self.geo_gnn_model = geo_gnn
|
|
566
|
+
print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; injecting into FT.")
|
|
567
|
+
|
|
568
|
+
def _add_region_effect(self) -> None:
|
|
569
|
+
"""Partial pooling over province/city to create a smoothed region_effect feature."""
|
|
570
|
+
prov_col = self.config.region_province_col
|
|
571
|
+
city_col = self.config.region_city_col
|
|
572
|
+
if not prov_col or not city_col:
|
|
573
|
+
return
|
|
574
|
+
for col in [prov_col, city_col]:
|
|
575
|
+
if col not in self.train_data.columns:
|
|
576
|
+
print(f"[RegionEffect] Missing column {col}; skipped.")
|
|
577
|
+
return
|
|
578
|
+
|
|
579
|
+
def safe_mean(df: pd.DataFrame) -> float:
|
|
580
|
+
w = df[self.weight_nme]
|
|
581
|
+
y = df[self.resp_nme]
|
|
582
|
+
denom = max(float(w.sum()), EPS)
|
|
583
|
+
return float((y * w).sum() / denom)
|
|
584
|
+
|
|
585
|
+
global_mean = safe_mean(self.train_data)
|
|
586
|
+
alpha = max(float(self.config.region_effect_alpha), 0.0)
|
|
587
|
+
|
|
588
|
+
w_all = self.train_data[self.weight_nme]
|
|
589
|
+
y_all = self.train_data[self.resp_nme]
|
|
590
|
+
yw_all = y_all * w_all
|
|
591
|
+
|
|
592
|
+
prov_sumw = w_all.groupby(self.train_data[prov_col]).sum()
|
|
593
|
+
prov_sumyw = yw_all.groupby(self.train_data[prov_col]).sum()
|
|
594
|
+
prov_mean = (prov_sumyw / prov_sumw.clip(lower=EPS)).astype(float)
|
|
595
|
+
prov_mean = prov_mean.fillna(global_mean)
|
|
596
|
+
|
|
597
|
+
city_sumw = self.train_data.groupby([prov_col, city_col])[
|
|
598
|
+
self.weight_nme].sum()
|
|
599
|
+
city_sumyw = yw_all.groupby(
|
|
600
|
+
[self.train_data[prov_col], self.train_data[city_col]]).sum()
|
|
601
|
+
city_df = pd.DataFrame({
|
|
602
|
+
"sum_w": city_sumw,
|
|
603
|
+
"sum_yw": city_sumyw,
|
|
604
|
+
})
|
|
605
|
+
city_df["prior"] = city_df.index.get_level_values(0).map(
|
|
606
|
+
prov_mean).fillna(global_mean)
|
|
607
|
+
city_df["effect"] = (
|
|
608
|
+
city_df["sum_yw"] + alpha * city_df["prior"]
|
|
609
|
+
) / (city_df["sum_w"] + alpha).clip(lower=EPS)
|
|
610
|
+
city_effect = city_df["effect"]
|
|
611
|
+
|
|
612
|
+
def lookup_effect(df: pd.DataFrame) -> pd.Series:
|
|
613
|
+
idx = pd.MultiIndex.from_frame(df[[prov_col, city_col]])
|
|
614
|
+
effects = city_effect.reindex(idx).to_numpy(dtype=np.float64)
|
|
615
|
+
prov_fallback = df[prov_col].map(
|
|
616
|
+
prov_mean).fillna(global_mean).to_numpy(dtype=np.float64)
|
|
617
|
+
effects = np.where(np.isfinite(effects), effects, prov_fallback)
|
|
618
|
+
effects = np.where(np.isfinite(effects), effects, global_mean)
|
|
619
|
+
return pd.Series(effects, index=df.index, dtype=np.float32)
|
|
620
|
+
|
|
621
|
+
re_train = lookup_effect(self.train_data)
|
|
622
|
+
re_test = lookup_effect(self.test_data)
|
|
623
|
+
|
|
624
|
+
col_name = "region_effect"
|
|
625
|
+
self.train_data[col_name] = re_train
|
|
626
|
+
self.test_data[col_name] = re_test
|
|
627
|
+
|
|
628
|
+
# Sync into one-hot and scaled variants.
|
|
629
|
+
for df in [self.train_oht_data, self.test_oht_data]:
|
|
630
|
+
if df is not None:
|
|
631
|
+
df[col_name] = re_train if df is self.train_oht_data else re_test
|
|
632
|
+
|
|
633
|
+
# Standardize region_effect and propagate.
|
|
634
|
+
scaler = StandardScaler()
|
|
635
|
+
re_train_s = scaler.fit_transform(
|
|
636
|
+
re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
|
|
637
|
+
re_test_s = scaler.transform(
|
|
638
|
+
re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
|
|
639
|
+
for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
|
|
640
|
+
if df is not None:
|
|
641
|
+
df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
|
|
642
|
+
|
|
643
|
+
# Update feature lists.
|
|
644
|
+
if col_name not in self.factor_nmes:
|
|
645
|
+
self.factor_nmes.append(col_name)
|
|
646
|
+
if col_name not in self.num_features:
|
|
647
|
+
self.num_features.append(col_name)
|
|
648
|
+
if self.train_oht_scl_data is not None:
|
|
649
|
+
excluded = {self.weight_nme, self.resp_nme}
|
|
650
|
+
self.var_nmes = [
|
|
651
|
+
col for col in self.train_oht_scl_data.columns if col not in excluded
|
|
652
|
+
]
|
|
653
|
+
|
|
654
|
+
def _require_trainer(self, model_key: str) -> "TrainerBase":
|
|
655
|
+
trainer = self.trainers.get(model_key)
|
|
656
|
+
if trainer is None:
|
|
657
|
+
raise KeyError(f"Unknown model key: {model_key}")
|
|
658
|
+
return trainer
|
|
659
|
+
|
|
660
|
+
def _pred_vector_columns(self, pred_prefix: str) -> List[str]:
|
|
661
|
+
"""Return vector feature columns like pred_<prefix>_0.. sorted by suffix."""
|
|
662
|
+
col_prefix = f"pred_{pred_prefix}_"
|
|
663
|
+
cols = [c for c in self.train_data.columns if c.startswith(col_prefix)]
|
|
664
|
+
|
|
665
|
+
def sort_key(name: str):
|
|
666
|
+
tail = name.rsplit("_", 1)[-1]
|
|
667
|
+
try:
|
|
668
|
+
return (0, int(tail))
|
|
669
|
+
except Exception:
|
|
670
|
+
return (1, tail)
|
|
671
|
+
|
|
672
|
+
cols.sort(key=sort_key)
|
|
673
|
+
return cols
|
|
674
|
+
|
|
675
|
+
def _inject_pred_features(self, pred_prefix: str) -> List[str]:
|
|
676
|
+
"""Inject pred_<prefix> or pred_<prefix>_i columns into features and return names."""
|
|
677
|
+
cols = self._pred_vector_columns(pred_prefix)
|
|
678
|
+
if cols:
|
|
679
|
+
self.add_numeric_features_from_columns(cols)
|
|
680
|
+
return cols
|
|
681
|
+
scalar_col = f"pred_{pred_prefix}"
|
|
682
|
+
if scalar_col in self.train_data.columns:
|
|
683
|
+
self.add_numeric_feature_from_column(scalar_col)
|
|
684
|
+
return [scalar_col]
|
|
685
|
+
return []
|
|
686
|
+
|
|
687
|
+
def _maybe_load_best_params(self, model_key: str, trainer: "TrainerBase") -> None:
|
|
688
|
+
# 1) If best_params_files is specified, load and skip tuning.
|
|
689
|
+
best_params_files = getattr(self.config, "best_params_files", None) or {}
|
|
690
|
+
best_params_file = best_params_files.get(model_key)
|
|
691
|
+
if best_params_file and not trainer.best_params:
|
|
692
|
+
trainer.best_params = IOUtils.load_params_file(best_params_file)
|
|
693
|
+
trainer.best_trial = None
|
|
694
|
+
print(
|
|
695
|
+
f"[Optuna][{trainer.label}] Loaded best_params from {best_params_file}; skip tuning."
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
# 2) If reuse_best_params is enabled, prefer version snapshots; else load legacy CSV.
|
|
699
|
+
reuse_params = bool(getattr(self.config, "reuse_best_params", False))
|
|
700
|
+
if reuse_params and not trainer.best_params:
|
|
701
|
+
payload = self.version_manager.load_latest(f"{model_key}_best")
|
|
702
|
+
best_params = None if payload is None else payload.get("best_params")
|
|
703
|
+
if best_params:
|
|
704
|
+
trainer.best_params = best_params
|
|
705
|
+
trainer.best_trial = None
|
|
706
|
+
trainer.study_name = payload.get(
|
|
707
|
+
"study_name") if isinstance(payload, dict) else None
|
|
708
|
+
print(
|
|
709
|
+
f"[Optuna][{trainer.label}] Reusing best_params from versions snapshot.")
|
|
710
|
+
return
|
|
711
|
+
|
|
712
|
+
params_path = self.output_manager.result_path(
|
|
713
|
+
f'{self.model_nme}_bestparams_{trainer.label.lower()}.csv'
|
|
714
|
+
)
|
|
715
|
+
if os.path.exists(params_path):
|
|
716
|
+
try:
|
|
717
|
+
trainer.best_params = IOUtils.load_params_file(params_path)
|
|
718
|
+
trainer.best_trial = None
|
|
719
|
+
print(
|
|
720
|
+
f"[Optuna][{trainer.label}] Reusing best_params from {params_path}.")
|
|
721
|
+
except ValueError:
|
|
722
|
+
# Legacy compatibility: ignore empty files and continue tuning.
|
|
723
|
+
pass
|
|
724
|
+
|
|
725
|
+
# Generic optimization entry point.
|
|
726
|
+
def optimize_model(self, model_key: str, max_evals: int = 100):
|
|
727
|
+
if model_key not in self.trainers:
|
|
728
|
+
print(f"Warning: Unknown model key: {model_key}")
|
|
729
|
+
return
|
|
730
|
+
|
|
731
|
+
trainer = self._require_trainer(model_key)
|
|
732
|
+
self._maybe_load_best_params(model_key, trainer)
|
|
733
|
+
|
|
734
|
+
should_tune = not trainer.best_params
|
|
735
|
+
if should_tune:
|
|
736
|
+
if model_key == "ft" and str(self.config.ft_role) == "unsupervised_embedding":
|
|
737
|
+
if hasattr(trainer, "cross_val_unsupervised"):
|
|
738
|
+
trainer.tune(
|
|
739
|
+
max_evals,
|
|
740
|
+
objective_fn=getattr(trainer, "cross_val_unsupervised")
|
|
741
|
+
)
|
|
742
|
+
else:
|
|
743
|
+
raise RuntimeError(
|
|
744
|
+
"FT trainer does not support unsupervised Optuna objective.")
|
|
745
|
+
else:
|
|
746
|
+
trainer.tune(max_evals)
|
|
747
|
+
|
|
748
|
+
if model_key == "ft" and str(self.config.ft_role) != "model":
|
|
749
|
+
prefix = str(self.config.ft_feature_prefix or "ft_emb")
|
|
750
|
+
role = str(self.config.ft_role)
|
|
751
|
+
if role == "embedding":
|
|
752
|
+
trainer.train_as_feature(
|
|
753
|
+
pred_prefix=prefix, feature_mode="embedding")
|
|
754
|
+
elif role == "unsupervised_embedding":
|
|
755
|
+
trainer.pretrain_unsupervised_as_feature(
|
|
756
|
+
pred_prefix=prefix,
|
|
757
|
+
params=trainer.best_params
|
|
758
|
+
)
|
|
759
|
+
else:
|
|
760
|
+
raise ValueError(
|
|
761
|
+
f"Unsupported ft_role='{role}', expected 'model'/'embedding'/'unsupervised_embedding'.")
|
|
762
|
+
|
|
763
|
+
# Inject generated prediction/embedding columns as features (scalar or vector).
|
|
764
|
+
self._inject_pred_features(prefix)
|
|
765
|
+
# Do not add FT as a standalone model label; downstream models handle evaluation.
|
|
766
|
+
else:
|
|
767
|
+
trainer.train()
|
|
768
|
+
|
|
769
|
+
if bool(getattr(self.config, "final_ensemble", False)):
|
|
770
|
+
k = int(getattr(self.config, "final_ensemble_k", 3) or 3)
|
|
771
|
+
if k > 1:
|
|
772
|
+
if model_key == "ft" and str(self.config.ft_role) != "model":
|
|
773
|
+
pass
|
|
774
|
+
elif hasattr(trainer, "ensemble_predict"):
|
|
775
|
+
trainer.ensemble_predict(k)
|
|
776
|
+
else:
|
|
777
|
+
print(
|
|
778
|
+
f"[Ensemble] Trainer '{model_key}' does not support ensemble prediction.",
|
|
779
|
+
flush=True,
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
# Update context fields for backward compatibility.
|
|
783
|
+
setattr(self, f"{model_key}_best", trainer.model)
|
|
784
|
+
setattr(self, f"best_{model_key}_params", trainer.best_params)
|
|
785
|
+
setattr(self, f"best_{model_key}_trial", trainer.best_trial)
|
|
786
|
+
# Save a snapshot for traceability.
|
|
787
|
+
study_name = getattr(trainer, "study_name", None)
|
|
788
|
+
if study_name is None and trainer.best_trial is not None:
|
|
789
|
+
study_obj = getattr(trainer.best_trial, "study", None)
|
|
790
|
+
study_name = getattr(study_obj, "study_name", None)
|
|
791
|
+
snapshot = {
|
|
792
|
+
"model_key": model_key,
|
|
793
|
+
"timestamp": datetime.now().isoformat(),
|
|
794
|
+
"best_params": trainer.best_params,
|
|
795
|
+
"study_name": study_name,
|
|
796
|
+
"config": asdict(self.config),
|
|
797
|
+
}
|
|
798
|
+
self.version_manager.save(f"{model_key}_best", snapshot)
|
|
799
|
+
|
|
800
|
+
+    def add_numeric_feature_from_column(self, col_name: str) -> None:
+        """Add an existing column as a feature and sync one-hot/scaled tables."""
+        if col_name not in self.train_data.columns or col_name not in self.test_data.columns:
+            raise KeyError(
+                f"Column '{col_name}' must exist in both train_data and test_data.")
+
+        if col_name not in self.factor_nmes:
+            self.factor_nmes.append(col_name)
+        if col_name not in self.config.factor_nmes:
+            self.config.factor_nmes.append(col_name)
+
+        if col_name not in self.cate_list and col_name not in self.num_features:
+            self.num_features.append(col_name)
+
+        if self.train_oht_data is not None and self.test_oht_data is not None:
+            self.train_oht_data[col_name] = self.train_data[col_name].values
+            self.test_oht_data[col_name] = self.test_data[col_name].values
+        if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
+            scaler = StandardScaler()
+            tr = self.train_data[col_name].to_numpy(
+                dtype=np.float32, copy=False).reshape(-1, 1)
+            te = self.test_data[col_name].to_numpy(
+                dtype=np.float32, copy=False).reshape(-1, 1)
+            self.train_oht_scl_data[col_name] = scaler.fit_transform(
+                tr).reshape(-1)
+            self.test_oht_scl_data[col_name] = scaler.transform(te).reshape(-1)
+
+        if col_name not in self.var_nmes:
+            self.var_nmes.append(col_name)
+
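A hedged sketch of the intended call pattern for this helper; `ctx` and the column names are illustrative only. Note that the scaler is fit on train and merely applied to test, so no test-set statistics leak into the scaled table:

```python
import numpy as np

# Hypothetical engineered ratio, added to both raw frames first.
for frame in (ctx.train_data, ctx.test_data):
    frame["sum_insured_per_room"] = frame["sum_insured"] / np.maximum(frame["n_rooms"], 1)

# One call syncs factor_nmes/num_features/var_nmes and, when present,
# the one-hot and standardised copies of the data.
ctx.add_numeric_feature_from_column("sum_insured_per_room")
```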
+    def add_numeric_features_from_columns(self, col_names: List[str]) -> None:
+        """Vectorised variant of add_numeric_feature_from_column for several columns."""
+        if not col_names:
+            return
+
+        missing = [
+            col for col in col_names
+            if col not in self.train_data.columns or col not in self.test_data.columns
+        ]
+        if missing:
+            raise KeyError(
+                f"Column(s) {missing} must exist in both train_data and test_data."
+            )
+
+        for col_name in col_names:
+            if col_name not in self.factor_nmes:
+                self.factor_nmes.append(col_name)
+            if col_name not in self.config.factor_nmes:
+                self.config.factor_nmes.append(col_name)
+            if col_name not in self.cate_list and col_name not in self.num_features:
+                self.num_features.append(col_name)
+            if col_name not in self.var_nmes:
+                self.var_nmes.append(col_name)
+
+        if self.train_oht_data is not None and self.test_oht_data is not None:
+            self.train_oht_data[col_names] = self.train_data[col_names].to_numpy(copy=False)
+            self.test_oht_data[col_names] = self.test_data[col_names].to_numpy(copy=False)
+
+        if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
+            scaler = StandardScaler()
+            tr = self.train_data[col_names].to_numpy(dtype=np.float32, copy=False)
+            te = self.test_data[col_names].to_numpy(dtype=np.float32, copy=False)
+            self.train_oht_scl_data[col_names] = scaler.fit_transform(tr)
+            self.test_oht_scl_data[col_names] = scaler.transform(te)
+
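The batch variant relies on scikit-learn scaling each column independently, so a single `StandardScaler` fit over the stacked matrix matches looping the scalar helper column by column. A small self-contained check:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# StandardScaler standardises per column: means ~0 and stds ~1 for each feature.
X = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]], dtype=np.float32)
Xs = StandardScaler().fit_transform(X)
print(Xs.mean(axis=0), Xs.std(axis=0))
```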
+    def prepare_ft_as_feature(self, max_evals: int = 50, pred_prefix: str = "ft_feat") -> str:
+        """Train FT as a feature generator and return the downstream column name."""
+        ft_trainer = self._require_trainer("ft")
+        ft_trainer.tune(max_evals=max_evals)
+        if hasattr(ft_trainer, "train_as_feature"):
+            ft_trainer.train_as_feature(pred_prefix=pred_prefix)
+        else:
+            ft_trainer.train()
+        feature_col = f"pred_{pred_prefix}"
+        self.add_numeric_feature_from_column(feature_col)
+        return feature_col
+
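A usage sketch, with `ctx` hypothetical; the returned name follows the `pred_<prefix>` convention visible in the method body:

```python
# FT prediction becomes a single numeric input for a downstream model.
col = ctx.prepare_ft_as_feature(max_evals=30, pred_prefix="ft_feat")
assert col == "pred_ft_feat"
ctx.bayesopt_xgb(max_evals=100)
```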
+    def prepare_ft_embedding_as_features(self, max_evals: int = 50, pred_prefix: str = "ft_emb") -> List[str]:
+        """Train FT and inject pooled embeddings as vector features pred_<prefix>_0.. ."""
+        ft_trainer = self._require_trainer("ft")
+        ft_trainer.tune(max_evals=max_evals)
+        if hasattr(ft_trainer, "train_as_feature"):
+            ft_trainer.train_as_feature(
+                pred_prefix=pred_prefix, feature_mode="embedding")
+        else:
+            raise RuntimeError(
+                "FT trainer does not support embedding feature mode.")
+        cols = self._pred_vector_columns(pred_prefix)
+        if not cols:
+            raise RuntimeError(
+                f"No embedding columns were generated for prefix '{pred_prefix}'.")
+        self.add_numeric_features_from_columns(cols)
+        return cols
+
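An analogous sketch for the embedding variant; column names follow the `pred_<prefix>_<i>` pattern implied by the docstring:

```python
# Pooled FT embeddings become a block of numeric features.
emb_cols = ctx.prepare_ft_embedding_as_features(max_evals=30, pred_prefix="ft_emb")
print(emb_cols[:3])   # e.g. ['pred_ft_emb_0', 'pred_ft_emb_1', 'pred_ft_emb_2']
ctx.bayesopt_glm(max_evals=50)
```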
+    def prepare_ft_unsupervised_embedding_as_features(self,
+                                                      pred_prefix: str = "ft_uemb",
+                                                      params: Optional[Dict[str, Any]] = None,
+                                                      mask_prob_num: float = 0.15,
+                                                      mask_prob_cat: float = 0.15,
+                                                      num_loss_weight: float = 1.0,
+                                                      cat_loss_weight: float = 1.0) -> List[str]:
+        """Export embeddings after FT self-supervised masked reconstruction pretraining."""
+        ft_trainer = self._require_trainer("ft")
+        if not hasattr(ft_trainer, "pretrain_unsupervised_as_feature"):
+            raise RuntimeError(
+                "FT trainer does not support unsupervised pretraining.")
+        ft_trainer.pretrain_unsupervised_as_feature(
+            pred_prefix=pred_prefix,
+            params=params,
+            mask_prob_num=mask_prob_num,
+            mask_prob_cat=mask_prob_cat,
+            num_loss_weight=num_loss_weight,
+            cat_loss_weight=cat_loss_weight
+        )
+        cols = self._pred_vector_columns(pred_prefix)
+        if not cols:
+            raise RuntimeError(
+                f"No embedding columns were generated for prefix '{pred_prefix}'.")
+        self.add_numeric_features_from_columns(cols)
+        return cols
+
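A sketch of the unsupervised path, which needs no labels because it reconstructs masked cells; everything here beyond the signature defaults is an assumption:

```python
# Self-supervised pretraining: 15% of numeric and categorical cells are
# masked and reconstructed, then pooled embeddings are injected as features.
uemb_cols = ctx.prepare_ft_unsupervised_embedding_as_features(
    pred_prefix="ft_uemb",
    params=None,          # None falls back to trainer defaults per the signature
    mask_prob_num=0.15,
    mask_prob_cat=0.15,
)
ctx.bayesopt_xgb(max_evals=100)
```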
+    # GLM Bayesian optimization wrapper.
+    def bayesopt_glm(self, max_evals=50):
+        self.optimize_model('glm', max_evals)
+
+    # XGBoost Bayesian optimization wrapper.
+    def bayesopt_xgb(self, max_evals=100):
+        self.optimize_model('xgb', max_evals)
+
+    # ResNet Bayesian optimization wrapper.
+    def bayesopt_resnet(self, max_evals=100):
+        self.optimize_model('resn', max_evals)
+
+    # GNN Bayesian optimization wrapper.
+    def bayesopt_gnn(self, max_evals=50):
+        self.optimize_model('gnn', max_evals)
+
+    # FT-Transformer Bayesian optimization wrapper.
+    def bayesopt_ft(self, max_evals=50):
+        self.optimize_model('ft', max_evals)
+
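Each wrapper only pins the model key, so the two calls below are interchangeable (`ctx` hypothetical):

```python
ctx.bayesopt_resnet(max_evals=80)   # convenience wrapper
ctx.optimize_model('resn', 80)      # generic entry point, same effect
```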
+    def save_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].save()
+            else:
+                if model_name:  # Only warn when the user specifies a model name.
+                    print(f"[save_model] Warning: Unknown model key {key}")
+
+    def load_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].load()
+                # Sync context fields.
+                trainer = self.trainers[key]
+                if trainer.model is not None:
+                    setattr(self, f"{key}_best", trainer.model)
+                    # For legacy compatibility, also update xxx_load.
+                    # Old versions only tracked xgb_load/resn_load/ft_load (not glm_load/gnn_load).
+                    if key in ['xgb', 'resn', 'ft', 'gnn']:
+                        setattr(self, f"{key}_load", trainer.model)
+            else:
+                if model_name:
+                    print(f"[load_model] Warning: Unknown model key {key}")