ins-pricing 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/docs/LOSS_FUNCTIONS.md +78 -0
- ins_pricing/frontend/QUICKSTART.md +152 -0
- ins_pricing/frontend/README.md +388 -0
- ins_pricing/frontend/__init__.py +10 -0
- ins_pricing/frontend/app.py +903 -0
- ins_pricing/frontend/config_builder.py +352 -0
- ins_pricing/frontend/example_config.json +36 -0
- ins_pricing/frontend/example_workflows.py +979 -0
- ins_pricing/frontend/ft_workflow.py +316 -0
- ins_pricing/frontend/runner.py +388 -0
- ins_pricing/production/predict.py +693 -664
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/METADATA +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/RECORD +16 -6
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/WHEEL +1 -1
- {ins_pricing-0.3.4.dist-info → ins_pricing-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,664 +1,693 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
|
6
|
-
|
|
7
|
-
import joblib
|
|
8
|
-
import numpy as np
|
|
9
|
-
import pandas as pd
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from .
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
if
|
|
61
|
-
return
|
|
62
|
-
if
|
|
63
|
-
return
|
|
64
|
-
return
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
return "
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
def
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
return pd.
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
model_name
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
)
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
"
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
cfg
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import joblib
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
try: # statsmodels is optional when GLM inference is not used
|
|
11
|
+
import statsmodels.api as sm
|
|
12
|
+
_SM_IMPORT_ERROR: Optional[BaseException] = None
|
|
13
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
14
|
+
sm = None # type: ignore[assignment]
|
|
15
|
+
_SM_IMPORT_ERROR = exc
|
|
16
|
+
|
|
17
|
+
from .preprocess import (
|
|
18
|
+
apply_preprocess_artifacts,
|
|
19
|
+
load_preprocess_artifacts,
|
|
20
|
+
prepare_raw_features,
|
|
21
|
+
)
|
|
22
|
+
from .scoring import batch_score
|
|
23
|
+
from ..modelling.core.bayesopt.utils.losses import (
|
|
24
|
+
infer_loss_name_from_model_name,
|
|
25
|
+
normalize_loss_name,
|
|
26
|
+
resolve_tweedie_power,
|
|
27
|
+
)
|
|
28
|
+
from ins_pricing.utils.logging import get_logger
|
|
29
|
+
|
|
30
|
+
_logger = get_logger("ins_pricing.production.predict")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from ..modelling.core.bayesopt.models.model_gnn import GraphNeuralNetSklearn
|
|
35
|
+
from ..modelling.core.bayesopt.models.model_resn import ResNetSklearn
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _torch_load(*args, **kwargs):
    """Proxy to the shared torch-compat loader, imported lazily.

    The local import keeps torch out of module import time for callers that
    only use the non-torch code paths (e.g. xgb/glm inference).
    """
    from ins_pricing.utils.torch_compat import torch_load

    return torch_load(*args, **kwargs)
|
|
41
|
+
|
|
42
|
+
def _get_device_manager():
    """Return the shared ``DeviceManager`` class, imported lazily.

    Deferred import avoids pulling device/torch machinery in at module load.
    """
    from ins_pricing.utils.device import DeviceManager

    return DeviceManager
|
|
45
|
+
|
|
46
|
+
# Maps the short model key used throughout configs to the file-name prefix
# that training uses when persisting checkpoints (see _model_file_path).
MODEL_PREFIX = {
    "xgb": "Xgboost",
    "glm": "GLM",
    "resn": "ResNet",
    "ft": "FTTransformer",
    "gnn": "GNN",
}

# Model keys whose inference consumes one-hot/scaled features and therefore
# requires saved preprocess artifacts.
OHT_MODELS = {"resn", "gnn", "glm"}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _default_tweedie_power(model_name: str, task_type: str) -> Optional[float]:
|
|
58
|
+
if task_type == "classification":
|
|
59
|
+
return None
|
|
60
|
+
if "f" in model_name:
|
|
61
|
+
return 1.0
|
|
62
|
+
if "s" in model_name:
|
|
63
|
+
return 2.0
|
|
64
|
+
return 1.5
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _resolve_loss_name(cfg: Dict[str, Any], model_name: str, task_type: str) -> str:
    """Resolve the effective loss name for a model.

    An explicit (non-"auto") configured loss wins.  "auto" falls back to
    logloss for classification, otherwise to the loss implied by the
    model-name convention.
    """
    resolved = normalize_loss_name(cfg.get("loss_name"), task_type)
    if resolved != "auto":
        return resolved
    if task_type == "classification":
        return "logloss"
    return infer_loss_name_from_model_name(model_name)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _resolve_value(
|
|
77
|
+
value: Any,
|
|
78
|
+
*,
|
|
79
|
+
model_name: str,
|
|
80
|
+
base_dir: Path,
|
|
81
|
+
) -> Optional[Path]:
|
|
82
|
+
if value is None:
|
|
83
|
+
return None
|
|
84
|
+
if isinstance(value, dict):
|
|
85
|
+
value = value.get(model_name)
|
|
86
|
+
if value is None:
|
|
87
|
+
return None
|
|
88
|
+
path_str = str(value)
|
|
89
|
+
try:
|
|
90
|
+
path_str = path_str.format(model_name=model_name)
|
|
91
|
+
except Exception:
|
|
92
|
+
pass
|
|
93
|
+
candidate = Path(path_str)
|
|
94
|
+
if candidate.is_absolute():
|
|
95
|
+
return candidate
|
|
96
|
+
return (base_dir / candidate).resolve()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _load_json(path: Path) -> Dict[str, Any]:
|
|
100
|
+
return json.loads(path.read_text(encoding="utf-8"))
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _infer_format(path: Path) -> str:
|
|
104
|
+
suffix = path.suffix.lower()
|
|
105
|
+
if suffix in {".parquet", ".pq"}:
|
|
106
|
+
return "parquet"
|
|
107
|
+
if suffix in {".feather", ".ft"}:
|
|
108
|
+
return "feather"
|
|
109
|
+
return "csv"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _load_dataset(path: Path, chunksize: Optional[int] = None) -> pd.DataFrame:
|
|
113
|
+
"""Load dataset with optional chunked reading for large CSV files.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
path: Path to the dataset file
|
|
117
|
+
chunksize: If specified for CSV files, reads in chunks and concatenates.
|
|
118
|
+
Useful for large files that may not fit in memory at once.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
DataFrame containing the full dataset
|
|
122
|
+
"""
|
|
123
|
+
fmt = _infer_format(path)
|
|
124
|
+
if fmt == "parquet":
|
|
125
|
+
return pd.read_parquet(path)
|
|
126
|
+
if fmt == "feather":
|
|
127
|
+
return pd.read_feather(path)
|
|
128
|
+
|
|
129
|
+
# For CSV, support chunked reading for large files
|
|
130
|
+
if chunksize is not None:
|
|
131
|
+
chunks = []
|
|
132
|
+
for chunk in pd.read_csv(path, low_memory=False, chunksize=chunksize):
|
|
133
|
+
chunks.append(chunk)
|
|
134
|
+
return pd.concat(chunks, ignore_index=True)
|
|
135
|
+
return pd.read_csv(path, low_memory=False)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _model_file_path(output_dir: Path, model_name: str, model_key: str) -> Path:
    """Return the checkpoint path training used for this model.

    Raises:
        ValueError: if ``model_key`` is not a known key in MODEL_PREFIX.
    """
    prefix = MODEL_PREFIX.get(model_key)
    if prefix is None:
        raise ValueError(f"Unsupported model key: {model_key}")
    # sklearn-style models are joblib pickles; torch models are .pth files.
    extension = "pkl" if model_key in {"xgb", "glm"} else "pth"
    return output_dir / "model" / f"01_{model_name}_{prefix}.{extension}"
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _load_preprocess_from_model_file(
    output_dir: Path,
    model_name: str,
    model_key: str,
) -> Optional[Dict[str, Any]]:
    """Fetch preprocess artifacts embedded in a saved model checkpoint.

    Returns ``None`` when the checkpoint is missing or carries no
    ``preprocess_artifacts`` entry.
    """
    checkpoint = _model_file_path(output_dir, model_name, model_key)
    if not checkpoint.exists():
        return None
    if model_key in {"xgb", "glm"}:
        payload = joblib.load(checkpoint)
    else:
        payload = _torch_load(checkpoint, map_location="cpu")
    if isinstance(payload, dict):
        return payload.get("preprocess_artifacts")
    return None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _move_to_device(model_obj: Any) -> None:
    """Move model to the best available device and switch it to eval mode."""
    manager = _get_device_manager()
    manager.move_to_device(model_obj)
    # Non-torch payloads (e.g. sklearn models) have no eval(); skip quietly.
    if hasattr(model_obj, "eval"):
        model_obj.eval()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def load_best_params(
    output_dir: str | Path,
    model_name: str,
    model_key: str,
) -> Optional[Dict[str, Any]]:
    """Recover the tuned hyper-parameters saved by training.

    Looks first at the versioned JSON snapshots under ``Results/versions``
    (newest wins, by sorted filename), then falls back to the legacy
    per-model CSV.  Returns ``None`` when neither source has params.
    """
    root = Path(output_dir)

    versions_dir = root / "Results" / "versions"
    if versions_dir.exists():
        snapshots = sorted(versions_dir.glob(f"*_{model_key}_best.json"))
        if snapshots:
            params = _load_json(snapshots[-1]).get("best_params")
            if params:
                return params

    # Legacy CSV fallback uses a longer label for some model keys.
    label_map = {
        "xgb": "xgboost",
        "resn": "resnet",
        "ft": "fttransformer",
        "glm": "glm",
        "gnn": "gnn",
    }
    label = label_map.get(model_key, model_key)
    csv_path = root / "Results" / f"{model_name}_bestparams_{label}.csv"
    if csv_path.exists():
        frame = pd.read_csv(csv_path)
        if not frame.empty:
            return frame.iloc[0].to_dict()
    return None
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _build_resn_model(
    *,
    model_name: str,
    input_dim: int,
    task_type: str,
    epochs: int,
    resn_weight_decay: float,
    loss_name: str,
    params: Dict[str, Any],
) -> ResNetSklearn:
    """Reconstruct a ResNet estimator from tuned params for inference.

    Hyper-parameters come from ``params`` with the same defaults training
    uses; data-parallel/DDP are disabled because this path is single-process.
    """
    from ..modelling.core.bayesopt.models.model_resn import ResNetSklearn

    if loss_name == "tweedie":
        raw_power = params.get(
            "tw_power", _default_tweedie_power(model_name, task_type))
        power = float(raw_power) if raw_power is not None else None
    else:
        power = resolve_tweedie_power(loss_name, default=1.5)

    return ResNetSklearn(
        model_nme=model_name,
        input_dim=input_dim,
        hidden_dim=int(params.get("hidden_dim", 64)),
        block_num=int(params.get("block_num", 2)),
        task_type=task_type,
        epochs=int(epochs),
        tweedie_power=power,
        learning_rate=float(params.get("learning_rate", 0.01)),
        patience=int(params.get("patience", 10)),
        use_layernorm=True,
        dropout=float(params.get("dropout", 0.1)),
        residual_scale=float(params.get("residual_scale", 0.1)),
        stochastic_depth=float(params.get("stochastic_depth", 0.0)),
        weight_decay=float(params.get("weight_decay", resn_weight_decay)),
        use_data_parallel=False,
        use_ddp=False,
        loss_name=loss_name,
    )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _build_gnn_model(
    *,
    model_name: str,
    input_dim: int,
    task_type: str,
    epochs: int,
    cfg: Dict[str, Any],
    loss_name: str,
    params: Dict[str, Any],
) -> GraphNeuralNetSklearn:
    """Reconstruct a GNN estimator from tuned params for inference.

    Architecture/optimizer settings come from ``params``; kNN-graph knobs
    come from ``cfg``.  Parallel training modes are disabled.
    """
    from ..modelling.core.bayesopt.models.model_gnn import GraphNeuralNetSklearn

    if loss_name == "tweedie":
        raw_power = params.get(
            "tw_power", _default_tweedie_power(model_name, task_type))
        tw_power = float(raw_power) if raw_power is not None else None
    else:
        tw_power = resolve_tweedie_power(loss_name, default=1.5)

    return GraphNeuralNetSklearn(
        model_nme=f"{model_name}_gnn",
        input_dim=input_dim,
        hidden_dim=int(params.get("hidden_dim", 64)),
        num_layers=int(params.get("num_layers", 2)),
        k_neighbors=int(params.get("k_neighbors", 10)),
        dropout=float(params.get("dropout", 0.1)),
        learning_rate=float(params.get("learning_rate", 1e-3)),
        epochs=int(params.get("epochs", epochs)),
        patience=int(params.get("patience", 5)),
        task_type=task_type,
        tweedie_power=tw_power,
        weight_decay=float(params.get("weight_decay", 0.0)),
        use_data_parallel=False,
        use_ddp=False,
        use_approx_knn=bool(cfg.get("gnn_use_approx_knn", True)),
        approx_knn_threshold=int(cfg.get("gnn_approx_knn_threshold", 50000)),
        graph_cache_path=cfg.get("gnn_graph_cache"),
        max_gpu_knn_nodes=cfg.get("gnn_max_gpu_knn_nodes"),
        knn_gpu_mem_ratio=cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
        knn_gpu_mem_overhead=cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
        loss_name=loss_name,
    )
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def load_saved_model(
    *,
    output_dir: str | Path,
    model_name: str,
    model_key: str,
    task_type: str,
    input_dim: Optional[int],
    cfg: Dict[str, Any],
) -> Any:
    """Load a trained model checkpoint for inference.

    Dispatch by ``model_key``:
      * ``xgb``/``glm`` — joblib pickles (possibly wrapped in a dict).
      * ``ft`` — state_dict + model_config checkpoints (DDP-safe format),
        with fallbacks for two older full-object formats.
      * ``resn``/``gnn`` — state_dict checkpoints rebuilt from tuned params
        (``input_dim`` is required for these).

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: for unknown keys, missing ``input_dim``, or a malformed
            GNN checkpoint.
        RuntimeError: when ResNet best-params cannot be recovered.
    """
    model_path = _model_file_path(Path(output_dir), model_name, model_key)
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if model_key in {"xgb", "glm"}:
        payload = joblib.load(model_path)
        if isinstance(payload, dict) and "model" in payload:
            return payload.get("model")
        return payload

    if model_key == "ft":
        payload = _torch_load(
            model_path, map_location="cpu", weights_only=False)
        if isinstance(payload, dict):
            if "state_dict" in payload and "model_config" in payload:
                # New format: state_dict + model_config (DDP-safe).
                state_dict = payload.get("state_dict")
                model_config = payload.get("model_config", {})

                from ..modelling.core.bayesopt.models import FTTransformerSklearn
                from ..modelling.core.bayesopt.models.model_ft_components import FTTransformerCore

                resolved_loss = model_config.get("loss_name")
                if not resolved_loss:
                    resolved_loss = _resolve_loss_name(
                        cfg, model_name, task_type)

                # Rebuild the sklearn-style wrapper from the saved config.
                model = FTTransformerSklearn(
                    model_nme=model_config.get("model_nme", ""),
                    num_cols=model_config.get("num_cols", []),
                    cat_cols=model_config.get("cat_cols", []),
                    d_model=model_config.get("d_model", 64),
                    n_heads=model_config.get("n_heads", 8),
                    n_layers=model_config.get("n_layers", 4),
                    dropout=model_config.get("dropout", 0.1),
                    task_type=model_config.get("task_type", "regression"),
                    loss_name=resolved_loss,
                    tweedie_power=model_config.get("tw_power", 1.5),
                    num_numeric_tokens=model_config.get("num_numeric_tokens"),
                    use_data_parallel=False,
                    use_ddp=False,
                )

                # Restore fitted internal state (category maps, scalers).
                model.num_geo = model_config.get("num_geo", 0)
                model.cat_cardinalities = model_config.get("cat_cardinalities")
                model.cat_categories = {k: pd.Index(
                    v) for k, v in model_config.get("cat_categories", {}).items()}
                if model_config.get("_num_mean") is not None:
                    model._num_mean = np.array(
                        model_config["_num_mean"], dtype=np.float32)
                if model_config.get("_num_std") is not None:
                    model._num_std = np.array(
                        model_config["_num_std"], dtype=np.float32)

                # Build the network and load the trained weights.
                if model.cat_cardinalities is not None:
                    core = FTTransformerCore(
                        num_numeric=len(model.num_cols),
                        cat_cardinalities=model.cat_cardinalities,
                        d_model=model.d_model,
                        n_heads=model.n_heads,
                        n_layers=model.n_layers,
                        dropout=model.dropout,
                        task_type=model.task_type,
                        num_geo=model.num_geo,
                        num_numeric_tokens=model.num_numeric_tokens,
                    )
                    model.ft = core
                    model.ft.load_state_dict(state_dict)

                _move_to_device(model)
                return model
            elif "model" in payload:
                # Legacy format: full model object inside a dict.
                model = payload.get("model")
                _move_to_device(model)
                return model
        # Very old format: the payload IS the model object.
        _move_to_device(payload)
        return payload

    if model_key == "resn":
        if input_dim is None:
            raise ValueError("input_dim is required for ResNet loading")
        payload = _torch_load(model_path, map_location="cpu")
        if isinstance(payload, dict) and "state_dict" in payload:
            state_dict = payload.get("state_dict")
            params = payload.get("best_params") or load_best_params(
                output_dir, model_name, model_key
            )
        else:
            state_dict = payload
            params = load_best_params(output_dir, model_name, model_key)
        if params is None:
            raise RuntimeError("Best params not found for resn")
        loss_name = _resolve_loss_name(cfg, model_name, task_type)
        model = _build_resn_model(
            model_name=model_name,
            input_dim=input_dim,
            task_type=task_type,
            epochs=int(cfg.get("epochs", 50)),
            resn_weight_decay=float(cfg.get("resn_weight_decay", 1e-4)),
            loss_name=loss_name,
            params=params,
        )
        model.resnet.load_state_dict(state_dict)
        _move_to_device(model)
        return model

    if model_key == "gnn":
        if input_dim is None:
            raise ValueError("input_dim is required for GNN loading")
        payload = _torch_load(model_path, map_location="cpu")
        if not isinstance(payload, dict):
            raise ValueError(f"Invalid GNN checkpoint: {model_path}")
        params = payload.get("best_params") or {}
        state_dict = payload.get("state_dict")
        loss_name = _resolve_loss_name(cfg, model_name, task_type)
        model = _build_gnn_model(
            model_name=model_name,
            input_dim=input_dim,
            task_type=task_type,
            epochs=int(cfg.get("epochs", 50)),
            cfg=cfg,
            loss_name=loss_name,
            params=params,
        )
        model.set_params(dict(params))
        # strict=False: the rebuilt graph net may lack buffers the trained
        # one had (e.g. cached graph tensors) — load what matches.
        base_gnn = getattr(model, "_unwrap_gnn", lambda: None)()
        if base_gnn is not None and state_dict is not None:
            base_gnn.load_state_dict(state_dict, strict=False)
        _move_to_device(model)
        return model

    raise ValueError(f"Unsupported model key: {model_key}")
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _build_artifacts_from_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
430
|
+
factor_nmes = list(cfg.get("feature_list") or [])
|
|
431
|
+
cate_list = list(cfg.get("categorical_features") or [])
|
|
432
|
+
num_features = [c for c in factor_nmes if c not in cate_list]
|
|
433
|
+
return {
|
|
434
|
+
"factor_nmes": factor_nmes,
|
|
435
|
+
"cate_list": cate_list,
|
|
436
|
+
"num_features": num_features,
|
|
437
|
+
"cat_categories": {},
|
|
438
|
+
"var_nmes": [],
|
|
439
|
+
"numeric_scalers": {},
|
|
440
|
+
"drop_first": True,
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _prepare_features(
    df: pd.DataFrame,
    *,
    model_key: str,
    cfg: Dict[str, Any],
    artifacts: Optional[Dict[str, Any]],
) -> pd.DataFrame:
    """Transform raw rows into the feature frame a model expects.

    One-hot style models (OHT_MODELS) must have saved artifacts; other
    models fall back to artifacts synthesized from the config.
    """
    if model_key in OHT_MODELS:
        if artifacts is None:
            raise RuntimeError(
                f"Preprocess artifacts are required for {model_key} inference. "
                "Enable save_preprocess during training or provide preprocess_artifact_path."
            )
        return apply_preprocess_artifacts(df, artifacts)

    effective = artifacts if artifacts is not None else _build_artifacts_from_config(cfg)
    return prepare_raw_features(df, effective)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def _predict_with_model(
    *,
    model: Any,
    model_key: str,
    task_type: str,
    features: pd.DataFrame,
) -> np.ndarray:
    """Run inference with a loaded model on prepared features.

    For xgb classification, returns the positive-class probability; for GLM,
    prepends the intercept column before predicting (requires statsmodels).
    """
    if model_key == "xgb":
        if task_type == "classification" and hasattr(model, "predict_proba"):
            return model.predict_proba(features)[:, 1]
        return model.predict(features)

    if model_key == "glm":
        if sm is None:
            raise RuntimeError(
                f"statsmodels is required for GLM inference ({_SM_IMPORT_ERROR})."
            )
        design = sm.add_constant(features, has_constant="add")
        return model.predict(design)

    # Torch-backed models expose a plain predict().
    return model.predict(features)
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
class SavedModelPredictor:
    """Pairs a loaded model with its preprocessing artifacts.

    The model checkpoint is loaded eagerly at construction; ``predict``
    applies preprocessing and then model inference in one call.
    """

    def __init__(
        self,
        *,
        model_key: str,
        model_name: str,
        task_type: str,
        cfg: Dict[str, Any],
        output_dir: Path,
        artifacts: Optional[Dict[str, Any]],
    ) -> None:
        self.model_key = model_key
        self.model_name = model_name
        self.task_type = task_type
        self.cfg = cfg
        self.output_dir = output_dir
        self.artifacts = artifacts

        # FT checkpoints are only usable here when trained as a standalone
        # model and without geo tokens.
        if model_key == "ft" and str(cfg.get("ft_role", "model")) != "model":
            raise ValueError("FT predictions require ft_role == 'model'.")
        if model_key == "ft" and cfg.get("geo_feature_nmes"):
            raise ValueError(
                "FT inference with geo tokens is not supported in this helper.")

        # One-hot style models need the encoded feature width to rebuild
        # the network; derive it from the saved artifact column list.
        input_dim = None
        if model_key in OHT_MODELS and artifacts is not None:
            encoded_cols = list(artifacts.get("var_nmes") or [])
            if encoded_cols:
                input_dim = len(encoded_cols)

        self.model = load_saved_model(
            output_dir=output_dir,
            model_name=model_name,
            model_key=model_key,
            task_type=task_type,
            input_dim=input_dim,
            cfg=cfg,
        )

    def predict(self, df: pd.DataFrame) -> np.ndarray:
        """Preprocess ``df`` and return model predictions."""
        features = _prepare_features(
            df,
            model_key=self.model_key,
            cfg=self.cfg,
            artifacts=self.artifacts,
        )
        return _predict_with_model(
            model=self.model,
            model_key=self.model_key,
            task_type=self.task_type,
            features=features,
        )
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def load_predictor_from_config(
    config_path: str | Path,
    model_key: str,
    *,
    model_name: Optional[str] = None,
    output_dir: Optional[str | Path] = None,
    preprocess_artifact_path: Optional[str | Path] = None,
) -> SavedModelPredictor:
    """Build a :class:`SavedModelPredictor` from a training config file.

    ``model_name`` may be omitted only for single-model configs.  Explicit
    ``output_dir``/``preprocess_artifact_path`` arguments override the
    config; preprocess artifacts fall back to the conventional
    ``Results/<name>_preprocess.json`` location and finally to artifacts
    embedded in the model checkpoint.
    """
    config_path = Path(config_path).resolve()
    cfg = _load_json(config_path)
    base_dir = config_path.parent

    if model_name is None:
        model_list = list(cfg.get("model_list") or [])
        model_categories = list(cfg.get("model_categories") or [])
        if len(model_list) != 1 or len(model_categories) != 1:
            raise ValueError(
                "Provide model_name when config has multiple models.")
        model_name = f"{model_list[0]}_{model_categories[0]}"

    if output_dir is not None:
        resolved_output = Path(output_dir).resolve()
    else:
        resolved_output = _resolve_value(
            cfg.get("output_dir"), model_name=model_name, base_dir=base_dir)
    if resolved_output is None:
        raise ValueError("output_dir is required to locate saved models.")

    if preprocess_artifact_path is not None:
        resolved_artifact = Path(preprocess_artifact_path).resolve()
    else:
        resolved_artifact = _resolve_value(
            cfg.get("preprocess_artifact_path"),
            model_name=model_name,
            base_dir=base_dir,
        )

    if resolved_artifact is None:
        # Conventional artifact location next to the training results.
        candidate = resolved_output / "Results" / \
            f"{model_name}_preprocess.json"
        if candidate.exists():
            resolved_artifact = candidate

    artifacts = None
    if resolved_artifact is not None and resolved_artifact.exists():
        artifacts = load_preprocess_artifacts(resolved_artifact)
    if artifacts is None:
        # Last resort: artifacts embedded in the checkpoint itself.
        artifacts = _load_preprocess_from_model_file(
            resolved_output, model_name, model_key
        )

    return SavedModelPredictor(
        model_key=model_key,
        model_name=model_name,
        task_type=str(cfg.get("task_type", "regression")),
        cfg=cfg,
        output_dir=resolved_output,
        artifacts=artifacts,
    )
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def predict_from_config(
    config_path: str | Path,
    input_path: str | Path,
    *,
    model_keys: Sequence[str],
    model_name: Optional[str] = None,
    output_path: Optional[str | Path] = None,
    output_col_prefix: str = "pred_",
    batch_size: int = 10000,
    chunksize: Optional[int] = None,
    parallel_load: bool = False,
    n_jobs: int = -1,
) -> pd.DataFrame:
    """Predict from multiple models with optional parallel loading.

    Args:
        config_path: Path to configuration file.
        input_path: Path to input data.
        model_keys: List of model keys to use for prediction.
        model_name: Optional model name override passed through to
            ``load_predictor_from_config``.
        output_path: Optional path to save results. Format is chosen from
            the suffix (``.parquet``/``.pq``, ``.feather``/``.ft``,
            otherwise CSV).
        output_col_prefix: Prefix for output columns; each model ``key``
            produces a column ``f"{output_col_prefix}{key}"``.
        batch_size: Batch size for scoring.
        chunksize: Optional chunk size for CSV reading.
        parallel_load: If True, load and score models in parallel
            (faster for multiple models).
        n_jobs: Number of parallel jobs for model loading (-1 = all cores).

    Returns:
        DataFrame with the input columns plus one prediction column per
        model key.
    """
    input_path = Path(input_path).resolve()
    data = _load_dataset(input_path, chunksize=chunksize)

    result = data.copy()

    def _load_and_score(key: str) -> tuple[str, np.ndarray]:
        # Single source of truth for per-model scoring, shared by the
        # parallel and sequential paths (previously duplicated verbatim).
        predictor = load_predictor_from_config(
            config_path,
            key,
            model_name=model_name,
        )
        output_col = f"{output_col_prefix}{key}"
        scored = batch_score(
            predictor.predict,
            data,
            output_col=output_col,
            batch_size=batch_size,
            keep_input=False,
        )
        return output_col, scored[output_col].values

    if parallel_load and len(model_keys) > 1:
        # Threads (not processes) so the shared `data` frame is not pickled
        # per worker; model loading is mostly disk I/O, which releases the GIL.
        from joblib import Parallel, delayed

        scored_columns = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(_load_and_score)(key) for key in model_keys
        )
    else:
        # Sequential loading (original behavior).
        scored_columns = [_load_and_score(key) for key in model_keys]

    for output_col, predictions in scored_columns:
        result[output_col] = predictions

    if output_path:
        # Use a separate local instead of rebinding the str|Path parameter.
        out_path = Path(output_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        suffix = out_path.suffix.lower()
        if suffix in {".parquet", ".pq"}:
            result.to_parquet(out_path, index=False)
        elif suffix in {".feather", ".ft"}:
            result.to_feather(out_path)
        else:
            result.to_csv(out_path, index=False)

    return result