ins_pricing-0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +60 -0
- ins_pricing/__init__.py +102 -0
- ins_pricing/governance/README.md +18 -0
- ins_pricing/governance/__init__.py +20 -0
- ins_pricing/governance/approval.py +93 -0
- ins_pricing/governance/audit.py +37 -0
- ins_pricing/governance/registry.py +99 -0
- ins_pricing/governance/release.py +159 -0
- ins_pricing/modelling/BayesOpt.py +146 -0
- ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
- ins_pricing/modelling/BayesOpt_entry.py +575 -0
- ins_pricing/modelling/BayesOpt_incremental.py +731 -0
- ins_pricing/modelling/Explain_Run.py +36 -0
- ins_pricing/modelling/Explain_entry.py +539 -0
- ins_pricing/modelling/Pricing_Run.py +36 -0
- ins_pricing/modelling/README.md +33 -0
- ins_pricing/modelling/__init__.py +44 -0
- ins_pricing/modelling/bayesopt/__init__.py +98 -0
- ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
- ins_pricing/modelling/bayesopt/core.py +1476 -0
- ins_pricing/modelling/bayesopt/models.py +2196 -0
- ins_pricing/modelling/bayesopt/trainers.py +2446 -0
- ins_pricing/modelling/bayesopt/utils.py +1021 -0
- ins_pricing/modelling/cli_common.py +136 -0
- ins_pricing/modelling/explain/__init__.py +55 -0
- ins_pricing/modelling/explain/gradients.py +334 -0
- ins_pricing/modelling/explain/metrics.py +176 -0
- ins_pricing/modelling/explain/permutation.py +155 -0
- ins_pricing/modelling/explain/shap_utils.py +146 -0
- ins_pricing/modelling/notebook_utils.py +284 -0
- ins_pricing/modelling/plotting/__init__.py +45 -0
- ins_pricing/modelling/plotting/common.py +63 -0
- ins_pricing/modelling/plotting/curves.py +572 -0
- ins_pricing/modelling/plotting/diagnostics.py +139 -0
- ins_pricing/modelling/plotting/geo.py +362 -0
- ins_pricing/modelling/plotting/importance.py +121 -0
- ins_pricing/modelling/run_logging.py +133 -0
- ins_pricing/modelling/tests/conftest.py +8 -0
- ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing/modelling/tests/test_explain.py +56 -0
- ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing/modelling/tests/test_plotting.py +63 -0
- ins_pricing/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing/modelling/watchdog_run.py +211 -0
- ins_pricing/pricing/README.md +44 -0
- ins_pricing/pricing/__init__.py +27 -0
- ins_pricing/pricing/calibration.py +39 -0
- ins_pricing/pricing/data_quality.py +117 -0
- ins_pricing/pricing/exposure.py +85 -0
- ins_pricing/pricing/factors.py +91 -0
- ins_pricing/pricing/monitoring.py +99 -0
- ins_pricing/pricing/rate_table.py +78 -0
- ins_pricing/production/__init__.py +21 -0
- ins_pricing/production/drift.py +30 -0
- ins_pricing/production/monitoring.py +143 -0
- ins_pricing/production/scoring.py +40 -0
- ins_pricing/reporting/README.md +20 -0
- ins_pricing/reporting/__init__.py +11 -0
- ins_pricing/reporting/report_builder.py +72 -0
- ins_pricing/reporting/scheduler.py +45 -0
- ins_pricing/setup.py +41 -0
- ins_pricing v2/__init__.py +23 -0
- ins_pricing v2/governance/__init__.py +20 -0
- ins_pricing v2/governance/approval.py +93 -0
- ins_pricing v2/governance/audit.py +37 -0
- ins_pricing v2/governance/registry.py +99 -0
- ins_pricing v2/governance/release.py +159 -0
- ins_pricing v2/modelling/Explain_Run.py +36 -0
- ins_pricing v2/modelling/Pricing_Run.py +36 -0
- ins_pricing v2/modelling/__init__.py +151 -0
- ins_pricing v2/modelling/cli_common.py +141 -0
- ins_pricing v2/modelling/config.py +249 -0
- ins_pricing v2/modelling/config_preprocess.py +254 -0
- ins_pricing v2/modelling/core.py +741 -0
- ins_pricing v2/modelling/data_container.py +42 -0
- ins_pricing v2/modelling/explain/__init__.py +55 -0
- ins_pricing v2/modelling/explain/gradients.py +334 -0
- ins_pricing v2/modelling/explain/metrics.py +176 -0
- ins_pricing v2/modelling/explain/permutation.py +155 -0
- ins_pricing v2/modelling/explain/shap_utils.py +146 -0
- ins_pricing v2/modelling/features.py +215 -0
- ins_pricing v2/modelling/model_manager.py +148 -0
- ins_pricing v2/modelling/model_plotting.py +463 -0
- ins_pricing v2/modelling/models.py +2203 -0
- ins_pricing v2/modelling/notebook_utils.py +294 -0
- ins_pricing v2/modelling/plotting/__init__.py +45 -0
- ins_pricing v2/modelling/plotting/common.py +63 -0
- ins_pricing v2/modelling/plotting/curves.py +572 -0
- ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
- ins_pricing v2/modelling/plotting/geo.py +362 -0
- ins_pricing v2/modelling/plotting/importance.py +121 -0
- ins_pricing v2/modelling/run_logging.py +133 -0
- ins_pricing v2/modelling/tests/conftest.py +8 -0
- ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing v2/modelling/tests/test_explain.py +56 -0
- ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing v2/modelling/tests/test_plotting.py +63 -0
- ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing v2/modelling/trainers.py +2447 -0
- ins_pricing v2/modelling/utils.py +1020 -0
- ins_pricing v2/modelling/watchdog_run.py +211 -0
- ins_pricing v2/pricing/__init__.py +27 -0
- ins_pricing v2/pricing/calibration.py +39 -0
- ins_pricing v2/pricing/data_quality.py +117 -0
- ins_pricing v2/pricing/exposure.py +85 -0
- ins_pricing v2/pricing/factors.py +91 -0
- ins_pricing v2/pricing/monitoring.py +99 -0
- ins_pricing v2/pricing/rate_table.py +78 -0
- ins_pricing v2/production/__init__.py +21 -0
- ins_pricing v2/production/drift.py +30 -0
- ins_pricing v2/production/monitoring.py +143 -0
- ins_pricing v2/production/scoring.py +40 -0
- ins_pricing v2/reporting/__init__.py +11 -0
- ins_pricing v2/reporting/report_builder.py +72 -0
- ins_pricing v2/reporting/scheduler.py +45 -0
- ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
- ins_pricing v2/scripts/Explain_entry.py +545 -0
- ins_pricing v2/scripts/__init__.py +1 -0
- ins_pricing v2/scripts/train.py +568 -0
- ins_pricing v2/setup.py +55 -0
- ins_pricing v2/smoke_test.py +28 -0
- ins_pricing-0.1.6.dist-info/METADATA +78 -0
- ins_pricing-0.1.6.dist-info/RECORD +169 -0
- ins_pricing-0.1.6.dist-info/WHEEL +5 -0
- ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
- user_packages/__init__.py +105 -0
- user_packages legacy/BayesOpt.py +5659 -0
- user_packages legacy/BayesOpt_entry.py +513 -0
- user_packages legacy/BayesOpt_incremental.py +685 -0
- user_packages legacy/Pricing_Run.py +36 -0
- user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
- user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
- user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
- user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
- user_packages legacy/Try/BayesOpt legacy.py +3280 -0
- user_packages legacy/Try/BayesOpt.py +838 -0
- user_packages legacy/Try/BayesOptAll.py +1569 -0
- user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
- user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
- user_packages legacy/Try/BayesOptSearch.py +830 -0
- user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
- user_packages legacy/Try/BayesOptV1.py +1911 -0
- user_packages legacy/Try/BayesOptV10.py +2973 -0
- user_packages legacy/Try/BayesOptV11.py +3001 -0
- user_packages legacy/Try/BayesOptV12.py +3001 -0
- user_packages legacy/Try/BayesOptV2.py +2065 -0
- user_packages legacy/Try/BayesOptV3.py +2209 -0
- user_packages legacy/Try/BayesOptV4.py +2342 -0
- user_packages legacy/Try/BayesOptV5.py +2372 -0
- user_packages legacy/Try/BayesOptV6.py +2759 -0
- user_packages legacy/Try/BayesOptV7.py +2832 -0
- user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
- user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
- user_packages legacy/Try/BayesOptV9.py +2927 -0
- user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
- user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
- user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
- user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
- user_packages legacy/Try/xgbbayesopt.py +523 -0
- user_packages legacy/__init__.py +19 -0
- user_packages legacy/cli_common.py +124 -0
- user_packages legacy/notebook_utils.py +228 -0
- user_packages legacy/watchdog_run.py +202 -0
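The largest addition in this release is the `BayesOptModel` orchestrator added in the 1,476-line hunk below (the line count matches `ins_pricing/modelling/bayesopt/core.py` in the listing above). A minimal usage sketch, assuming that module layout; the data paths, column names, and argument values here are illustrative only, not part of the package:

```python
import pandas as pd

# Import path assumed from the package layout listed above.
from ins_pricing.modelling.bayesopt.core import BayesOptModel

# Hypothetical claim-frequency dataset; file paths and column names are illustrative.
train_df = pd.read_parquet("train.parquet")
test_df = pd.read_parquet("test.parquet")

model = BayesOptModel(
    train_df, test_df,
    model_nme="motor_f",         # 'f' in the name selects the count:poisson objective
    resp_nme="claim_count",      # target column
    weight_nme="exposure",       # sample-weight column
    task_type="regression",
    output_dir="./bayesopt_out",
)

# Optuna tuning plus final fit per model family (wrapper methods defined in the hunk below).
model.bayesopt_glm(max_evals=30)
model.bayesopt_xgb(max_evals=100)

# Diagnostics are written under output_dir.
model.plot_oneway(n_bins=10)
model.plot_lift("Xgboost", pred_nme="pred_xgb")
```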
|
@@ -0,0 +1,1476 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import asdict
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
try: # matplotlib is optional; avoid hard import failures in headless/minimal envs
|
|
9
|
+
import matplotlib
|
|
10
|
+
if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
|
|
11
|
+
matplotlib.use("Agg")
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
_MPL_IMPORT_ERROR: Optional[BaseException] = None
|
|
14
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
15
|
+
plt = None # type: ignore[assignment]
|
|
16
|
+
_MPL_IMPORT_ERROR = exc
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pandas as pd
|
|
19
|
+
import torch
|
|
20
|
+
import statsmodels.api as sm
|
|
21
|
+
from sklearn.model_selection import ShuffleSplit
|
|
22
|
+
from sklearn.preprocessing import StandardScaler
|
|
23
|
+
|
|
24
|
+
from .config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
|
|
25
|
+
from .models import GraphNeuralNetSklearn
|
|
26
|
+
from .trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
|
|
27
|
+
from .utils import EPS, PlotUtils, infer_factor_and_cate_list, set_global_seed
|
|
28
|
+
try:
|
|
29
|
+
from ..plotting import curves as plot_curves
|
|
30
|
+
from ..plotting import diagnostics as plot_diagnostics
|
|
31
|
+
from ..plotting.common import PlotStyle, finalize_figure
|
|
32
|
+
from ..explain import gradients as explain_gradients
|
|
33
|
+
from ..explain import permutation as explain_permutation
|
|
34
|
+
from ..explain import shap_utils as explain_shap
|
|
35
|
+
except Exception: # pragma: no cover - optional for legacy imports
|
|
36
|
+
try: # best-effort for non-package imports
|
|
37
|
+
from ins_pricing.plotting import curves as plot_curves
|
|
38
|
+
from ins_pricing.plotting import diagnostics as plot_diagnostics
|
|
39
|
+
from ins_pricing.plotting.common import PlotStyle, finalize_figure
|
|
40
|
+
from ins_pricing.explain import gradients as explain_gradients
|
|
41
|
+
from ins_pricing.explain import permutation as explain_permutation
|
|
42
|
+
from ins_pricing.explain import shap_utils as explain_shap
|
|
43
|
+
except Exception: # pragma: no cover
|
|
44
|
+
plot_curves = None
|
|
45
|
+
plot_diagnostics = None
|
|
46
|
+
PlotStyle = None
|
|
47
|
+
finalize_figure = None
|
|
48
|
+
explain_gradients = None
|
|
49
|
+
explain_permutation = None
|
|
50
|
+
explain_shap = None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _plot_skip(label: str) -> None:
|
|
54
|
+
if _MPL_IMPORT_ERROR is not None:
|
|
55
|
+
print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
|
|
56
|
+
else:
|
|
57
|
+
print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
|
|
58
|
+
|
|
59
|
+
# BayesOpt orchestration and SHAP utilities
|
|
60
|
+
# =============================================================================
|
|
61
|
+
class BayesOptModel:
|
|
62
|
+
def __init__(self, train_data, test_data,
|
|
63
|
+
model_nme, resp_nme, weight_nme, factor_nmes: Optional[List[str]] = None, task_type='regression',
|
|
64
|
+
binary_resp_nme=None,
|
|
65
|
+
cate_list=None, prop_test=0.25, rand_seed=None,
|
|
66
|
+
epochs=100, use_gpu=True,
|
|
67
|
+
use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
|
|
68
|
+
use_gnn_data_parallel: bool = False,
|
|
69
|
+
use_resn_ddp: bool = False, use_ft_ddp: bool = False,
|
|
70
|
+
use_gnn_ddp: bool = False,
|
|
71
|
+
output_dir: Optional[str] = None,
|
|
72
|
+
gnn_use_approx_knn: bool = True,
|
|
73
|
+
gnn_approx_knn_threshold: int = 50000,
|
|
74
|
+
gnn_graph_cache: Optional[str] = None,
|
|
75
|
+
gnn_max_gpu_knn_nodes: Optional[int] = 200000,
|
|
76
|
+
gnn_knn_gpu_mem_ratio: float = 0.9,
|
|
77
|
+
gnn_knn_gpu_mem_overhead: float = 2.0,
|
|
78
|
+
ft_role: str = "model",
|
|
79
|
+
ft_feature_prefix: str = "ft_emb",
|
|
80
|
+
ft_num_numeric_tokens: Optional[int] = None,
|
|
81
|
+
infer_categorical_max_unique: int = 50,
|
|
82
|
+
infer_categorical_max_ratio: float = 0.05,
|
|
83
|
+
reuse_best_params: bool = False,
|
|
84
|
+
xgb_max_depth_max: int = 25,
|
|
85
|
+
xgb_n_estimators_max: int = 500,
|
|
86
|
+
resn_weight_decay: Optional[float] = None,
|
|
87
|
+
final_ensemble: bool = False,
|
|
88
|
+
final_ensemble_k: int = 3,
|
|
89
|
+
final_refit: bool = True,
|
|
90
|
+
optuna_storage: Optional[str] = None,
|
|
91
|
+
optuna_study_prefix: Optional[str] = None,
|
|
92
|
+
best_params_files: Optional[Dict[str, str]] = None):
|
|
93
|
+
"""Orchestrate BayesOpt training across multiple trainers.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
train_data: Training DataFrame.
|
|
97
|
+
test_data: Test DataFrame.
|
|
98
|
+
model_nme: Model name prefix used in outputs.
|
|
99
|
+
resp_nme: Target column name.
|
|
100
|
+
weight_nme: Sample weight column name.
|
|
101
|
+
factor_nmes: Feature column list.
|
|
102
|
+
task_type: "regression" or "classification".
|
|
103
|
+
binary_resp_nme: Optional binary target for lift curves.
|
|
104
|
+
cate_list: Categorical feature list.
|
|
105
|
+
prop_test: Validation split ratio in CV.
|
|
106
|
+
rand_seed: Random seed.
|
|
107
|
+
epochs: NN training epochs.
|
|
108
|
+
use_gpu: Prefer GPU when available.
|
|
109
|
+
use_resn_data_parallel: Enable DataParallel for ResNet.
|
|
110
|
+
use_ft_data_parallel: Enable DataParallel for FTTransformer.
|
|
111
|
+
use_gnn_data_parallel: Enable DataParallel for GNN.
|
|
112
|
+
use_resn_ddp: Enable DDP for ResNet.
|
|
113
|
+
use_ft_ddp: Enable DDP for FTTransformer.
|
|
114
|
+
use_gnn_ddp: Enable DDP for GNN.
|
|
115
|
+
output_dir: Output root for models/results/plots.
|
|
116
|
+
gnn_use_approx_knn: Use approximate kNN when available.
|
|
117
|
+
gnn_approx_knn_threshold: Row threshold to switch to approximate kNN.
|
|
118
|
+
gnn_graph_cache: Optional adjacency cache path.
|
|
119
|
+
gnn_max_gpu_knn_nodes: Force CPU kNN above this node count to avoid OOM.
|
|
120
|
+
gnn_knn_gpu_mem_ratio: Fraction of free GPU memory for kNN.
|
|
121
|
+
gnn_knn_gpu_mem_overhead: Temporary memory multiplier for GPU kNN.
|
|
122
|
+
ft_num_numeric_tokens: Number of numeric tokens for FT (None = auto).
|
|
123
|
+
final_ensemble: Enable k-fold model averaging at the final stage.
|
|
124
|
+
final_ensemble_k: Number of folds for averaging.
|
|
125
|
+
final_refit: Refit on full data using best stopping point.
|
|
126
|
+
"""
|
|
127
|
+
inferred_factors, inferred_cats = infer_factor_and_cate_list(
|
|
128
|
+
train_df=train_data,
|
|
129
|
+
test_df=test_data,
|
|
130
|
+
resp_nme=resp_nme,
|
|
131
|
+
weight_nme=weight_nme,
|
|
132
|
+
binary_resp_nme=binary_resp_nme,
|
|
133
|
+
factor_nmes=factor_nmes,
|
|
134
|
+
cate_list=cate_list,
|
|
135
|
+
infer_categorical_max_unique=int(infer_categorical_max_unique),
|
|
136
|
+
infer_categorical_max_ratio=float(infer_categorical_max_ratio),
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
cfg = BayesOptConfig(
|
|
140
|
+
model_nme=model_nme,
|
|
141
|
+
task_type=task_type,
|
|
142
|
+
resp_nme=resp_nme,
|
|
143
|
+
weight_nme=weight_nme,
|
|
144
|
+
factor_nmes=list(inferred_factors),
|
|
145
|
+
binary_resp_nme=binary_resp_nme,
|
|
146
|
+
cate_list=list(inferred_cats) if inferred_cats else None,
|
|
147
|
+
prop_test=prop_test,
|
|
148
|
+
rand_seed=rand_seed,
|
|
149
|
+
epochs=epochs,
|
|
150
|
+
use_gpu=use_gpu,
|
|
151
|
+
xgb_max_depth_max=int(xgb_max_depth_max),
|
|
152
|
+
xgb_n_estimators_max=int(xgb_n_estimators_max),
|
|
153
|
+
use_resn_data_parallel=use_resn_data_parallel,
|
|
154
|
+
use_ft_data_parallel=use_ft_data_parallel,
|
|
155
|
+
use_resn_ddp=use_resn_ddp,
|
|
156
|
+
use_gnn_data_parallel=use_gnn_data_parallel,
|
|
157
|
+
use_ft_ddp=use_ft_ddp,
|
|
158
|
+
use_gnn_ddp=use_gnn_ddp,
|
|
159
|
+
gnn_use_approx_knn=gnn_use_approx_knn,
|
|
160
|
+
gnn_approx_knn_threshold=gnn_approx_knn_threshold,
|
|
161
|
+
gnn_graph_cache=gnn_graph_cache,
|
|
162
|
+
gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
|
|
163
|
+
gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
|
|
164
|
+
gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
|
|
165
|
+
output_dir=output_dir,
|
|
166
|
+
optuna_storage=optuna_storage,
|
|
167
|
+
optuna_study_prefix=optuna_study_prefix,
|
|
168
|
+
best_params_files=best_params_files,
|
|
169
|
+
ft_role=str(ft_role or "model"),
|
|
170
|
+
ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
|
|
171
|
+
ft_num_numeric_tokens=ft_num_numeric_tokens,
|
|
172
|
+
reuse_best_params=bool(reuse_best_params),
|
|
173
|
+
resn_weight_decay=float(resn_weight_decay)
|
|
174
|
+
if resn_weight_decay is not None
|
|
175
|
+
else 1e-4,
|
|
176
|
+
final_ensemble=bool(final_ensemble),
|
|
177
|
+
final_ensemble_k=int(final_ensemble_k),
|
|
178
|
+
final_refit=bool(final_refit),
|
|
179
|
+
)
|
|
180
|
+
self.config = cfg
|
|
181
|
+
self.model_nme = cfg.model_nme
|
|
182
|
+
self.task_type = cfg.task_type
|
|
183
|
+
self.resp_nme = cfg.resp_nme
|
|
184
|
+
self.weight_nme = cfg.weight_nme
|
|
185
|
+
self.factor_nmes = cfg.factor_nmes
|
|
186
|
+
self.binary_resp_nme = cfg.binary_resp_nme
|
|
187
|
+
self.cate_list = list(cfg.cate_list or [])
|
|
188
|
+
self.prop_test = cfg.prop_test
|
|
189
|
+
self.epochs = cfg.epochs
|
|
190
|
+
self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
|
|
191
|
+
1, 10000)
|
|
192
|
+
set_global_seed(int(self.rand_seed))
|
|
193
|
+
self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
|
|
194
|
+
self.output_manager = OutputManager(
|
|
195
|
+
cfg.output_dir or os.getcwd(), self.model_nme)
|
|
196
|
+
|
|
197
|
+
preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
|
|
198
|
+
self.train_data = preprocessor.train_data
|
|
199
|
+
self.test_data = preprocessor.test_data
|
|
200
|
+
self.train_oht_data = preprocessor.train_oht_data
|
|
201
|
+
self.test_oht_data = preprocessor.test_oht_data
|
|
202
|
+
self.train_oht_scl_data = preprocessor.train_oht_scl_data
|
|
203
|
+
self.test_oht_scl_data = preprocessor.test_oht_scl_data
|
|
204
|
+
self.var_nmes = preprocessor.var_nmes
|
|
205
|
+
self.num_features = preprocessor.num_features
|
|
206
|
+
self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
|
|
207
|
+
self.geo_token_cols: List[str] = []
|
|
208
|
+
self.train_geo_tokens: Optional[pd.DataFrame] = None
|
|
209
|
+
self.test_geo_tokens: Optional[pd.DataFrame] = None
|
|
210
|
+
self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
|
|
211
|
+
self._add_region_effect()
|
|
212
|
+
|
|
213
|
+
self.cv = ShuffleSplit(n_splits=int(1/self.prop_test),
|
|
214
|
+
test_size=self.prop_test,
|
|
215
|
+
random_state=self.rand_seed)
|
|
216
|
+
if self.task_type == 'classification':
|
|
217
|
+
self.obj = 'binary:logistic'
|
|
218
|
+
else: # regression task
|
|
219
|
+
if 'f' in self.model_nme:
|
|
220
|
+
self.obj = 'count:poisson'
|
|
221
|
+
elif 's' in self.model_nme:
|
|
222
|
+
self.obj = 'reg:gamma'
|
|
223
|
+
elif 'bc' in self.model_nme:
|
|
224
|
+
self.obj = 'reg:tweedie'
|
|
225
|
+
else:
|
|
226
|
+
self.obj = 'reg:tweedie'
|
|
227
|
+
self.fit_params = {
|
|
228
|
+
'sample_weight': self.train_data[self.weight_nme].values
|
|
229
|
+
}
|
|
230
|
+
self.model_label: List[str] = []
|
|
231
|
+
self.optuna_storage = cfg.optuna_storage
|
|
232
|
+
self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
|
|
233
|
+
|
|
234
|
+
# Keep trainers in a dict for unified access and easy extension.
|
|
235
|
+
self.trainers: Dict[str, TrainerBase] = {
|
|
236
|
+
'glm': GLMTrainer(self),
|
|
237
|
+
'xgb': XGBTrainer(self),
|
|
238
|
+
'resn': ResNetTrainer(self),
|
|
239
|
+
'ft': FTTrainer(self),
|
|
240
|
+
'gnn': GNNTrainer(self),
|
|
241
|
+
}
|
|
242
|
+
self._prepare_geo_tokens()
|
|
243
|
+
self.xgb_best = None
|
|
244
|
+
self.resn_best = None
|
|
245
|
+
self.gnn_best = None
|
|
246
|
+
self.glm_best = None
|
|
247
|
+
self.ft_best = None
|
|
248
|
+
self.best_xgb_params = None
|
|
249
|
+
self.best_resn_params = None
|
|
250
|
+
self.best_gnn_params = None
|
|
251
|
+
self.best_ft_params = None
|
|
252
|
+
self.best_xgb_trial = None
|
|
253
|
+
self.best_resn_trial = None
|
|
254
|
+
self.best_gnn_trial = None
|
|
255
|
+
self.best_ft_trial = None
|
|
256
|
+
self.best_glm_params = None
|
|
257
|
+
self.best_glm_trial = None
|
|
258
|
+
self.xgb_load = None
|
|
259
|
+
self.resn_load = None
|
|
260
|
+
self.gnn_load = None
|
|
261
|
+
self.ft_load = None
|
|
262
|
+
self.version_manager = VersionManager(self.output_manager)
|
|
263
|
+
|
|
264
|
+
def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
|
|
265
|
+
if self.task_type == 'classification':
|
|
266
|
+
return None
|
|
267
|
+
objective = obj or getattr(self, "obj", None)
|
|
268
|
+
if objective == 'count:poisson':
|
|
269
|
+
return 1.0
|
|
270
|
+
if objective == 'reg:gamma':
|
|
271
|
+
return 2.0
|
|
272
|
+
return 1.5
|
|
273
|
+
|
|
274
|
+
def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
|
|
275
|
+
"""Internal builder; allows trial overrides and returns None on failure."""
|
|
276
|
+
geo_cols = list(self.config.geo_feature_nmes or [])
|
|
277
|
+
if not geo_cols:
|
|
278
|
+
return None
|
|
279
|
+
|
|
280
|
+
available = [c for c in geo_cols if c in self.train_data.columns]
|
|
281
|
+
if not available:
|
|
282
|
+
return None
|
|
283
|
+
|
|
284
|
+
# Preprocess text/numeric: fill numeric with median, label-encode text, map unknowns.
|
|
285
|
+
proc_train = {}
|
|
286
|
+
proc_test = {}
|
|
287
|
+
for col in available:
|
|
288
|
+
s_train = self.train_data[col]
|
|
289
|
+
s_test = self.test_data[col]
|
|
290
|
+
if pd.api.types.is_numeric_dtype(s_train):
|
|
291
|
+
tr = pd.to_numeric(s_train, errors="coerce")
|
|
292
|
+
te = pd.to_numeric(s_test, errors="coerce")
|
|
293
|
+
med = np.nanmedian(tr)
|
|
294
|
+
proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
|
|
295
|
+
proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
|
|
296
|
+
else:
|
|
297
|
+
cats = pd.Categorical(s_train.astype(str))
|
|
298
|
+
tr_codes = cats.codes.astype(np.float32, copy=True)
|
|
299
|
+
tr_codes[tr_codes < 0] = len(cats.categories)
|
|
300
|
+
te_cats = pd.Categorical(
|
|
301
|
+
s_test.astype(str), categories=cats.categories)
|
|
302
|
+
te_codes = te_cats.codes.astype(np.float32, copy=True)
|
|
303
|
+
te_codes[te_codes < 0] = len(cats.categories)
|
|
304
|
+
proc_train[col] = tr_codes
|
|
305
|
+
proc_test[col] = te_codes
|
|
306
|
+
|
|
307
|
+
train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
|
|
308
|
+
test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
|
|
309
|
+
|
|
310
|
+
scaler = StandardScaler()
|
|
311
|
+
train_geo = pd.DataFrame(
|
|
312
|
+
scaler.fit_transform(train_geo_raw),
|
|
313
|
+
columns=available,
|
|
314
|
+
index=self.train_data.index
|
|
315
|
+
)
|
|
316
|
+
test_geo = pd.DataFrame(
|
|
317
|
+
scaler.transform(test_geo_raw),
|
|
318
|
+
columns=available,
|
|
319
|
+
index=self.test_data.index
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
tw_power = self.default_tweedie_power()
|
|
323
|
+
|
|
324
|
+
cfg = params_override or {}
|
|
325
|
+
try:
|
|
326
|
+
geo_gnn = GraphNeuralNetSklearn(
|
|
327
|
+
model_nme=f"{self.model_nme}_geo",
|
|
328
|
+
input_dim=len(available),
|
|
329
|
+
hidden_dim=cfg.get("geo_token_hidden_dim",
|
|
330
|
+
self.config.geo_token_hidden_dim),
|
|
331
|
+
num_layers=cfg.get("geo_token_layers",
|
|
332
|
+
self.config.geo_token_layers),
|
|
333
|
+
k_neighbors=cfg.get("geo_token_k_neighbors",
|
|
334
|
+
self.config.geo_token_k_neighbors),
|
|
335
|
+
dropout=cfg.get("geo_token_dropout",
|
|
336
|
+
self.config.geo_token_dropout),
|
|
337
|
+
learning_rate=cfg.get(
|
|
338
|
+
"geo_token_learning_rate", self.config.geo_token_learning_rate),
|
|
339
|
+
epochs=int(cfg.get("geo_token_epochs",
|
|
340
|
+
self.config.geo_token_epochs)),
|
|
341
|
+
patience=5,
|
|
342
|
+
task_type=self.task_type,
|
|
343
|
+
tweedie_power=tw_power,
|
|
344
|
+
use_data_parallel=False,
|
|
345
|
+
use_ddp=False,
|
|
346
|
+
use_approx_knn=self.config.gnn_use_approx_knn,
|
|
347
|
+
approx_knn_threshold=self.config.gnn_approx_knn_threshold,
|
|
348
|
+
graph_cache_path=None,
|
|
349
|
+
max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
|
|
350
|
+
knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
|
|
351
|
+
knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
|
|
352
|
+
)
|
|
353
|
+
geo_gnn.fit(
|
|
354
|
+
train_geo,
|
|
355
|
+
self.train_data[self.resp_nme],
|
|
356
|
+
self.train_data[self.weight_nme]
|
|
357
|
+
)
|
|
358
|
+
train_embed = geo_gnn.encode(train_geo)
|
|
359
|
+
test_embed = geo_gnn.encode(test_geo)
|
|
360
|
+
cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
|
|
361
|
+
train_tokens = pd.DataFrame(
|
|
362
|
+
train_embed, index=self.train_data.index, columns=cols)
|
|
363
|
+
test_tokens = pd.DataFrame(
|
|
364
|
+
test_embed, index=self.test_data.index, columns=cols)
|
|
365
|
+
return train_tokens, test_tokens, cols, geo_gnn
|
|
366
|
+
except Exception as exc:
|
|
367
|
+
print(f"[GeoToken] Generation failed: {exc}")
|
|
368
|
+
return None
|
|
369
|
+
|
|
370
|
+
def _prepare_geo_tokens(self) -> None:
|
|
371
|
+
"""Build and persist geo tokens with default config values."""
|
|
372
|
+
gnn_trainer = self.trainers.get("gnn")
|
|
373
|
+
if gnn_trainer is not None and hasattr(gnn_trainer, "prepare_geo_tokens"):
|
|
374
|
+
try:
|
|
375
|
+
gnn_trainer.prepare_geo_tokens(force=False) # type: ignore[attr-defined]
|
|
376
|
+
return
|
|
377
|
+
except Exception as exc:
|
|
378
|
+
print(f"[GeoToken] GNNTrainer generation failed: {exc}")
|
|
379
|
+
|
|
380
|
+
result = self._build_geo_tokens()
|
|
381
|
+
if result is None:
|
|
382
|
+
return
|
|
383
|
+
train_tokens, test_tokens, cols, geo_gnn = result
|
|
384
|
+
self.train_geo_tokens = train_tokens
|
|
385
|
+
self.test_geo_tokens = test_tokens
|
|
386
|
+
self.geo_token_cols = cols
|
|
387
|
+
self.geo_gnn_model = geo_gnn
|
|
388
|
+
print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; injecting into FT.")
|
|
389
|
+
|
|
390
|
+
def _add_region_effect(self) -> None:
|
|
391
|
+
"""Partial pooling over province/city to create a smoothed region_effect feature."""
|
|
392
|
+
prov_col = self.config.region_province_col
|
|
393
|
+
city_col = self.config.region_city_col
|
|
394
|
+
if not prov_col or not city_col:
|
|
395
|
+
return
|
|
396
|
+
for col in [prov_col, city_col]:
|
|
397
|
+
if col not in self.train_data.columns:
|
|
398
|
+
print(f"[RegionEffect] Missing column {col}; skipped.")
|
|
399
|
+
return
|
|
400
|
+
|
|
401
|
+
def safe_mean(df: pd.DataFrame) -> float:
|
|
402
|
+
w = df[self.weight_nme]
|
|
403
|
+
y = df[self.resp_nme]
|
|
404
|
+
denom = max(float(w.sum()), EPS)
|
|
405
|
+
return float((y * w).sum() / denom)
|
|
406
|
+
|
|
407
|
+
global_mean = safe_mean(self.train_data)
|
|
408
|
+
alpha = max(float(self.config.region_effect_alpha), 0.0)
|
|
409
|
+
|
|
410
|
+
w_all = self.train_data[self.weight_nme]
|
|
411
|
+
y_all = self.train_data[self.resp_nme]
|
|
412
|
+
yw_all = y_all * w_all
|
|
413
|
+
|
|
414
|
+
prov_sumw = w_all.groupby(self.train_data[prov_col]).sum()
|
|
415
|
+
prov_sumyw = yw_all.groupby(self.train_data[prov_col]).sum()
|
|
416
|
+
prov_mean = (prov_sumyw / prov_sumw.clip(lower=EPS)).astype(float)
|
|
417
|
+
prov_mean = prov_mean.fillna(global_mean)
|
|
418
|
+
|
|
419
|
+
city_sumw = self.train_data.groupby([prov_col, city_col])[
|
|
420
|
+
self.weight_nme].sum()
|
|
421
|
+
city_sumyw = yw_all.groupby(
|
|
422
|
+
[self.train_data[prov_col], self.train_data[city_col]]).sum()
|
|
423
|
+
city_df = pd.DataFrame({
|
|
424
|
+
"sum_w": city_sumw,
|
|
425
|
+
"sum_yw": city_sumyw,
|
|
426
|
+
})
|
|
427
|
+
city_df["prior"] = city_df.index.get_level_values(0).map(
|
|
428
|
+
prov_mean).fillna(global_mean)
|
|
429
|
+
city_df["effect"] = (
|
|
430
|
+
city_df["sum_yw"] + alpha * city_df["prior"]
|
|
431
|
+
) / (city_df["sum_w"] + alpha).clip(lower=EPS)
|
|
432
|
+
city_effect = city_df["effect"]
|
|
433
|
+
|
|
434
|
+
def lookup_effect(df: pd.DataFrame) -> pd.Series:
|
|
435
|
+
idx = pd.MultiIndex.from_frame(df[[prov_col, city_col]])
|
|
436
|
+
effects = city_effect.reindex(idx).to_numpy(dtype=np.float64)
|
|
437
|
+
prov_fallback = df[prov_col].map(
|
|
438
|
+
prov_mean).fillna(global_mean).to_numpy(dtype=np.float64)
|
|
439
|
+
effects = np.where(np.isfinite(effects), effects, prov_fallback)
|
|
440
|
+
effects = np.where(np.isfinite(effects), effects, global_mean)
|
|
441
|
+
return pd.Series(effects, index=df.index, dtype=np.float32)
|
|
442
|
+
|
|
443
|
+
re_train = lookup_effect(self.train_data)
|
|
444
|
+
re_test = lookup_effect(self.test_data)
|
|
445
|
+
|
|
446
|
+
col_name = "region_effect"
|
|
447
|
+
self.train_data[col_name] = re_train
|
|
448
|
+
self.test_data[col_name] = re_test
|
|
449
|
+
|
|
450
|
+
# Sync into one-hot and scaled variants.
|
|
451
|
+
for df in [self.train_oht_data, self.test_oht_data]:
|
|
452
|
+
if df is not None:
|
|
453
|
+
df[col_name] = re_train if df is self.train_oht_data else re_test
|
|
454
|
+
|
|
455
|
+
# Standardize region_effect and propagate.
|
|
456
|
+
scaler = StandardScaler()
|
|
457
|
+
re_train_s = scaler.fit_transform(
|
|
458
|
+
re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
|
|
459
|
+
re_test_s = scaler.transform(
|
|
460
|
+
re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
|
|
461
|
+
for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
|
|
462
|
+
if df is not None:
|
|
463
|
+
df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
|
|
464
|
+
|
|
465
|
+
# Update feature lists.
|
|
466
|
+
if col_name not in self.factor_nmes:
|
|
467
|
+
self.factor_nmes.append(col_name)
|
|
468
|
+
if col_name not in self.num_features:
|
|
469
|
+
self.num_features.append(col_name)
|
|
470
|
+
if self.train_oht_scl_data is not None:
|
|
471
|
+
excluded = {self.weight_nme, self.resp_nme}
|
|
472
|
+
self.var_nmes = [
|
|
473
|
+
col for col in self.train_oht_scl_data.columns if col not in excluded
|
|
474
|
+
]
|
|
475
|
+
|
|
476
|
+
# Single-factor plotting helper.
|
|
477
|
+
def plot_oneway(self, n_bins=10):
|
|
478
|
+
if plt is None and plot_diagnostics is None:
|
|
479
|
+
_plot_skip("oneway plot")
|
|
480
|
+
return
|
|
481
|
+
if plot_diagnostics is None:
|
|
482
|
+
for c in self.factor_nmes:
|
|
483
|
+
fig = plt.figure(figsize=(7, 5))
|
|
484
|
+
if c in self.cate_list:
|
|
485
|
+
group_col = c
|
|
486
|
+
plot_source = self.train_data
|
|
487
|
+
else:
|
|
488
|
+
group_col = f'{c}_bins'
|
|
489
|
+
bins = pd.qcut(
|
|
490
|
+
self.train_data[c],
|
|
491
|
+
n_bins,
|
|
492
|
+
duplicates='drop' # Drop duplicate quantiles to avoid errors.
|
|
493
|
+
)
|
|
494
|
+
plot_source = self.train_data.assign(**{group_col: bins})
|
|
495
|
+
plot_data = plot_source.groupby(
|
|
496
|
+
[group_col], observed=True).sum(numeric_only=True)
|
|
497
|
+
plot_data.reset_index(inplace=True)
|
|
498
|
+
plot_data['act_v'] = plot_data['w_act'] / \
|
|
499
|
+
plot_data[self.weight_nme]
|
|
500
|
+
ax = fig.add_subplot(111)
|
|
501
|
+
ax.plot(plot_data.index, plot_data['act_v'],
|
|
502
|
+
label='Actual', color='red')
|
|
503
|
+
ax.set_title(
|
|
504
|
+
'Analysis of %s : Train Data' % group_col,
|
|
505
|
+
fontsize=8)
|
|
506
|
+
plt.xticks(plot_data.index,
|
|
507
|
+
list(plot_data[group_col].astype(str)),
|
|
508
|
+
rotation=90)
|
|
509
|
+
if len(list(plot_data[group_col].astype(str))) > 50:
|
|
510
|
+
plt.xticks(fontsize=3)
|
|
511
|
+
else:
|
|
512
|
+
plt.xticks(fontsize=6)
|
|
513
|
+
plt.yticks(fontsize=6)
|
|
514
|
+
ax2 = ax.twinx()
|
|
515
|
+
ax2.bar(plot_data.index,
|
|
516
|
+
plot_data[self.weight_nme],
|
|
517
|
+
alpha=0.5, color='seagreen')
|
|
518
|
+
plt.yticks(fontsize=6)
|
|
519
|
+
plt.margins(0.05)
|
|
520
|
+
plt.subplots_adjust(wspace=0.3)
|
|
521
|
+
save_path = self.output_manager.plot_path(
|
|
522
|
+
f'00_{self.model_nme}_{group_col}_oneway.png')
|
|
523
|
+
plt.savefig(save_path, dpi=300)
|
|
524
|
+
plt.close(fig)
|
|
525
|
+
return
|
|
526
|
+
|
|
527
|
+
if "w_act" not in self.train_data.columns:
|
|
528
|
+
print("[Oneway] Missing w_act column; skip plotting.", flush=True)
|
|
529
|
+
return
|
|
530
|
+
|
|
531
|
+
for c in self.factor_nmes:
|
|
532
|
+
is_cat = c in (self.cate_list or [])
|
|
533
|
+
group_col = c if is_cat else f"{c}_bins"
|
|
534
|
+
title = f"Analysis of {group_col} : Train Data"
|
|
535
|
+
save_path = self.output_manager.plot_path(
|
|
536
|
+
f"00_{self.model_nme}_{group_col}_oneway.png"
|
|
537
|
+
)
|
|
538
|
+
plot_diagnostics.plot_oneway(
|
|
539
|
+
self.train_data,
|
|
540
|
+
feature=c,
|
|
541
|
+
weight_col=self.weight_nme,
|
|
542
|
+
target_col="w_act",
|
|
543
|
+
n_bins=n_bins,
|
|
544
|
+
is_categorical=is_cat,
|
|
545
|
+
title=title,
|
|
546
|
+
save_path=save_path,
|
|
547
|
+
show=False,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
def _require_trainer(self, model_key: str) -> "TrainerBase":
|
|
551
|
+
trainer = self.trainers.get(model_key)
|
|
552
|
+
if trainer is None:
|
|
553
|
+
raise KeyError(f"Unknown model key: {model_key}")
|
|
554
|
+
return trainer
|
|
555
|
+
|
|
556
|
+
def _pred_vector_columns(self, pred_prefix: str) -> List[str]:
|
|
557
|
+
"""Return vector feature columns like pred_<prefix>_0.. sorted by suffix."""
|
|
558
|
+
col_prefix = f"pred_{pred_prefix}_"
|
|
559
|
+
cols = [c for c in self.train_data.columns if c.startswith(col_prefix)]
|
|
560
|
+
|
|
561
|
+
def sort_key(name: str):
|
|
562
|
+
tail = name.rsplit("_", 1)[-1]
|
|
563
|
+
try:
|
|
564
|
+
return (0, int(tail))
|
|
565
|
+
except Exception:
|
|
566
|
+
return (1, tail)
|
|
567
|
+
|
|
568
|
+
cols.sort(key=sort_key)
|
|
569
|
+
return cols
|
|
570
|
+
|
|
571
|
+
def _inject_pred_features(self, pred_prefix: str) -> List[str]:
|
|
572
|
+
"""Inject pred_<prefix> or pred_<prefix>_i columns into features and return names."""
|
|
573
|
+
cols = self._pred_vector_columns(pred_prefix)
|
|
574
|
+
if cols:
|
|
575
|
+
self.add_numeric_features_from_columns(cols)
|
|
576
|
+
return cols
|
|
577
|
+
scalar_col = f"pred_{pred_prefix}"
|
|
578
|
+
if scalar_col in self.train_data.columns:
|
|
579
|
+
self.add_numeric_feature_from_column(scalar_col)
|
|
580
|
+
return [scalar_col]
|
|
581
|
+
return []
|
|
582
|
+
|
|
583
|
+
def _maybe_load_best_params(self, model_key: str, trainer: "TrainerBase") -> None:
|
|
584
|
+
# 1) If best_params_files is specified, load and skip tuning.
|
|
585
|
+
best_params_files = getattr(self.config, "best_params_files", None) or {}
|
|
586
|
+
best_params_file = best_params_files.get(model_key)
|
|
587
|
+
if best_params_file and not trainer.best_params:
|
|
588
|
+
trainer.best_params = IOUtils.load_params_file(best_params_file)
|
|
589
|
+
trainer.best_trial = None
|
|
590
|
+
print(
|
|
591
|
+
f"[Optuna][{trainer.label}] Loaded best_params from {best_params_file}; skip tuning."
|
|
592
|
+
)
|
|
593
|
+
|
|
594
|
+
# 2) If reuse_best_params is enabled, prefer version snapshots; else load legacy CSV.
|
|
595
|
+
reuse_params = bool(getattr(self.config, "reuse_best_params", False))
|
|
596
|
+
if reuse_params and not trainer.best_params:
|
|
597
|
+
payload = self.version_manager.load_latest(f"{model_key}_best")
|
|
598
|
+
best_params = None if payload is None else payload.get("best_params")
|
|
599
|
+
if best_params:
|
|
600
|
+
trainer.best_params = best_params
|
|
601
|
+
trainer.best_trial = None
|
|
602
|
+
trainer.study_name = payload.get(
|
|
603
|
+
"study_name") if isinstance(payload, dict) else None
|
|
604
|
+
print(
|
|
605
|
+
f"[Optuna][{trainer.label}] Reusing best_params from versions snapshot.")
|
|
606
|
+
return
|
|
607
|
+
|
|
608
|
+
params_path = self.output_manager.result_path(
|
|
609
|
+
f'{self.model_nme}_bestparams_{trainer.label.lower()}.csv'
|
|
610
|
+
)
|
|
611
|
+
if os.path.exists(params_path):
|
|
612
|
+
try:
|
|
613
|
+
trainer.best_params = IOUtils.load_params_file(params_path)
|
|
614
|
+
trainer.best_trial = None
|
|
615
|
+
print(
|
|
616
|
+
f"[Optuna][{trainer.label}] Reusing best_params from {params_path}.")
|
|
617
|
+
except ValueError:
|
|
618
|
+
# Legacy compatibility: ignore empty files and continue tuning.
|
|
619
|
+
pass
|
|
620
|
+
|
|
621
|
+
# Generic optimization entry point.
|
|
622
|
+
def optimize_model(self, model_key: str, max_evals: int = 100):
|
|
623
|
+
if model_key not in self.trainers:
|
|
624
|
+
print(f"Warning: Unknown model key: {model_key}")
|
|
625
|
+
return
|
|
626
|
+
|
|
627
|
+
trainer = self._require_trainer(model_key)
|
|
628
|
+
self._maybe_load_best_params(model_key, trainer)
|
|
629
|
+
|
|
630
|
+
should_tune = not trainer.best_params
|
|
631
|
+
if should_tune:
|
|
632
|
+
if model_key == "ft" and str(self.config.ft_role) == "unsupervised_embedding":
|
|
633
|
+
if hasattr(trainer, "cross_val_unsupervised"):
|
|
634
|
+
trainer.tune(
|
|
635
|
+
max_evals,
|
|
636
|
+
objective_fn=getattr(trainer, "cross_val_unsupervised")
|
|
637
|
+
)
|
|
638
|
+
else:
|
|
639
|
+
raise RuntimeError(
|
|
640
|
+
"FT trainer does not support unsupervised Optuna objective.")
|
|
641
|
+
else:
|
|
642
|
+
trainer.tune(max_evals)
|
|
643
|
+
|
|
644
|
+
if model_key == "ft" and str(self.config.ft_role) != "model":
|
|
645
|
+
prefix = str(self.config.ft_feature_prefix or "ft_emb")
|
|
646
|
+
role = str(self.config.ft_role)
|
|
647
|
+
if role == "embedding":
|
|
648
|
+
trainer.train_as_feature(
|
|
649
|
+
pred_prefix=prefix, feature_mode="embedding")
|
|
650
|
+
elif role == "unsupervised_embedding":
|
|
651
|
+
trainer.pretrain_unsupervised_as_feature(
|
|
652
|
+
pred_prefix=prefix,
|
|
653
|
+
params=trainer.best_params
|
|
654
|
+
)
|
|
655
|
+
else:
|
|
656
|
+
raise ValueError(
|
|
657
|
+
f"Unsupported ft_role='{role}', expected 'model'/'embedding'/'unsupervised_embedding'.")
|
|
658
|
+
|
|
659
|
+
# Inject generated prediction/embedding columns as features (scalar or vector).
|
|
660
|
+
self._inject_pred_features(prefix)
|
|
661
|
+
# Do not add FT as a standalone model label; downstream models handle evaluation.
|
|
662
|
+
else:
|
|
663
|
+
trainer.train()
|
|
664
|
+
|
|
665
|
+
if bool(getattr(self.config, "final_ensemble", False)):
|
|
666
|
+
k = int(getattr(self.config, "final_ensemble_k", 3) or 3)
|
|
667
|
+
if k > 1:
|
|
668
|
+
if model_key == "ft" and str(self.config.ft_role) != "model":
|
|
669
|
+
pass
|
|
670
|
+
elif hasattr(trainer, "ensemble_predict"):
|
|
671
|
+
trainer.ensemble_predict(k)
|
|
672
|
+
else:
|
|
673
|
+
print(
|
|
674
|
+
f"[Ensemble] Trainer '{model_key}' does not support ensemble prediction.",
|
|
675
|
+
flush=True,
|
|
676
|
+
)
|
|
677
|
+
|
|
678
|
+
# Update context fields for backward compatibility.
|
|
679
|
+
setattr(self, f"{model_key}_best", trainer.model)
|
|
680
|
+
setattr(self, f"best_{model_key}_params", trainer.best_params)
|
|
681
|
+
setattr(self, f"best_{model_key}_trial", trainer.best_trial)
|
|
682
|
+
# Save a snapshot for traceability.
|
|
683
|
+
study_name = getattr(trainer, "study_name", None)
|
|
684
|
+
if study_name is None and trainer.best_trial is not None:
|
|
685
|
+
study_obj = getattr(trainer.best_trial, "study", None)
|
|
686
|
+
study_name = getattr(study_obj, "study_name", None)
|
|
687
|
+
snapshot = {
|
|
688
|
+
"model_key": model_key,
|
|
689
|
+
"timestamp": datetime.now().isoformat(),
|
|
690
|
+
"best_params": trainer.best_params,
|
|
691
|
+
"study_name": study_name,
|
|
692
|
+
"config": asdict(self.config),
|
|
693
|
+
}
|
|
694
|
+
self.version_manager.save(f"{model_key}_best", snapshot)
|
|
695
|
+
|
|
696
|
+
def add_numeric_feature_from_column(self, col_name: str) -> None:
|
|
697
|
+
"""Add an existing column as a feature and sync one-hot/scaled tables."""
|
|
698
|
+
if col_name not in self.train_data.columns or col_name not in self.test_data.columns:
|
|
699
|
+
raise KeyError(
|
|
700
|
+
f"Column '{col_name}' must exist in both train_data and test_data.")
|
|
701
|
+
|
|
702
|
+
if col_name not in self.factor_nmes:
|
|
703
|
+
self.factor_nmes.append(col_name)
|
|
704
|
+
if col_name not in self.config.factor_nmes:
|
|
705
|
+
self.config.factor_nmes.append(col_name)
|
|
706
|
+
|
|
707
|
+
if col_name not in self.cate_list and col_name not in self.num_features:
|
|
708
|
+
self.num_features.append(col_name)
|
|
709
|
+
|
|
710
|
+
if self.train_oht_data is not None and self.test_oht_data is not None:
|
|
711
|
+
self.train_oht_data[col_name] = self.train_data[col_name].values
|
|
712
|
+
self.test_oht_data[col_name] = self.test_data[col_name].values
|
|
713
|
+
if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
|
|
714
|
+
scaler = StandardScaler()
|
|
715
|
+
tr = self.train_data[col_name].to_numpy(
|
|
716
|
+
dtype=np.float32, copy=False).reshape(-1, 1)
|
|
717
|
+
te = self.test_data[col_name].to_numpy(
|
|
718
|
+
dtype=np.float32, copy=False).reshape(-1, 1)
|
|
719
|
+
self.train_oht_scl_data[col_name] = scaler.fit_transform(
|
|
720
|
+
tr).reshape(-1)
|
|
721
|
+
self.test_oht_scl_data[col_name] = scaler.transform(te).reshape(-1)
|
|
722
|
+
|
|
723
|
+
if col_name not in self.var_nmes:
|
|
724
|
+
self.var_nmes.append(col_name)
|
|
725
|
+
|
|
726
|
+
def add_numeric_features_from_columns(self, col_names: List[str]) -> None:
|
|
727
|
+
if not col_names:
|
|
728
|
+
return
|
|
729
|
+
|
|
730
|
+
missing = [
|
|
731
|
+
col for col in col_names
|
|
732
|
+
if col not in self.train_data.columns or col not in self.test_data.columns
|
|
733
|
+
]
|
|
734
|
+
if missing:
|
|
735
|
+
raise KeyError(
|
|
736
|
+
f"Column(s) {missing} must exist in both train_data and test_data."
|
|
737
|
+
)
|
|
738
|
+
|
|
739
|
+
for col_name in col_names:
|
|
740
|
+
if col_name not in self.factor_nmes:
|
|
741
|
+
self.factor_nmes.append(col_name)
|
|
742
|
+
if col_name not in self.config.factor_nmes:
|
|
743
|
+
self.config.factor_nmes.append(col_name)
|
|
744
|
+
if col_name not in self.cate_list and col_name not in self.num_features:
|
|
745
|
+
self.num_features.append(col_name)
|
|
746
|
+
if col_name not in self.var_nmes:
|
|
747
|
+
self.var_nmes.append(col_name)
|
|
748
|
+
|
|
749
|
+
if self.train_oht_data is not None and self.test_oht_data is not None:
|
|
750
|
+
self.train_oht_data.loc[:, col_names] = self.train_data[col_names].to_numpy(copy=False)
|
|
751
|
+
self.test_oht_data.loc[:, col_names] = self.test_data[col_names].to_numpy(copy=False)
|
|
752
|
+
|
|
753
|
+
if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
|
|
754
|
+
scaler = StandardScaler()
|
|
755
|
+
tr = self.train_data[col_names].to_numpy(dtype=np.float32, copy=False)
|
|
756
|
+
te = self.test_data[col_names].to_numpy(dtype=np.float32, copy=False)
|
|
757
|
+
self.train_oht_scl_data.loc[:, col_names] = scaler.fit_transform(tr)
|
|
758
|
+
self.test_oht_scl_data.loc[:, col_names] = scaler.transform(te)
|
|
759
|
+
|
|
760
|
+
def prepare_ft_as_feature(self, max_evals: int = 50, pred_prefix: str = "ft_feat") -> str:
|
|
761
|
+
"""Train FT as a feature generator and return the downstream column name."""
|
|
762
|
+
ft_trainer = self._require_trainer("ft")
|
|
763
|
+
ft_trainer.tune(max_evals=max_evals)
|
|
764
|
+
if hasattr(ft_trainer, "train_as_feature"):
|
|
765
|
+
ft_trainer.train_as_feature(pred_prefix=pred_prefix)
|
|
766
|
+
else:
|
|
767
|
+
ft_trainer.train()
|
|
768
|
+
feature_col = f"pred_{pred_prefix}"
|
|
769
|
+
self.add_numeric_feature_from_column(feature_col)
|
|
770
|
+
return feature_col
|
|
771
|
+
|
|
772
|
+
def prepare_ft_embedding_as_features(self, max_evals: int = 50, pred_prefix: str = "ft_emb") -> List[str]:
|
|
773
|
+
"""Train FT and inject pooled embeddings as vector features pred_<prefix>_0.. ."""
|
|
774
|
+
ft_trainer = self._require_trainer("ft")
|
|
775
|
+
ft_trainer.tune(max_evals=max_evals)
|
|
776
|
+
if hasattr(ft_trainer, "train_as_feature"):
|
|
777
|
+
ft_trainer.train_as_feature(
|
|
778
|
+
pred_prefix=pred_prefix, feature_mode="embedding")
|
|
779
|
+
else:
|
|
780
|
+
raise RuntimeError(
|
|
781
|
+
"FT trainer does not support embedding feature mode.")
|
|
782
|
+
cols = self._pred_vector_columns(pred_prefix)
|
|
783
|
+
if not cols:
|
|
784
|
+
raise RuntimeError(
|
|
785
|
+
f"No embedding columns were generated for prefix '{pred_prefix}'.")
|
|
786
|
+
self.add_numeric_features_from_columns(cols)
|
|
787
|
+
return cols
|
|
788
|
+
|
|
789
|
+
def prepare_ft_unsupervised_embedding_as_features(self,
|
|
790
|
+
pred_prefix: str = "ft_uemb",
|
|
791
|
+
params: Optional[Dict[str,
|
|
792
|
+
Any]] = None,
|
|
793
|
+
mask_prob_num: float = 0.15,
|
|
794
|
+
mask_prob_cat: float = 0.15,
|
|
795
|
+
num_loss_weight: float = 1.0,
|
|
796
|
+
cat_loss_weight: float = 1.0) -> List[str]:
|
|
797
|
+
"""Export embeddings after FT self-supervised masked reconstruction pretraining."""
|
|
798
|
+
ft_trainer = self._require_trainer("ft")
|
|
799
|
+
if not hasattr(ft_trainer, "pretrain_unsupervised_as_feature"):
|
|
800
|
+
raise RuntimeError(
|
|
801
|
+
"FT trainer does not support unsupervised pretraining.")
|
|
802
|
+
ft_trainer.pretrain_unsupervised_as_feature(
|
|
803
|
+
pred_prefix=pred_prefix,
|
|
804
|
+
params=params,
|
|
805
|
+
mask_prob_num=mask_prob_num,
|
|
806
|
+
mask_prob_cat=mask_prob_cat,
|
|
807
|
+
num_loss_weight=num_loss_weight,
|
|
808
|
+
cat_loss_weight=cat_loss_weight
|
|
809
|
+
)
|
|
810
|
+
cols = self._pred_vector_columns(pred_prefix)
|
|
811
|
+
if not cols:
|
|
812
|
+
raise RuntimeError(
|
|
813
|
+
f"No embedding columns were generated for prefix '{pred_prefix}'.")
|
|
814
|
+
self.add_numeric_features_from_columns(cols)
|
|
815
|
+
return cols
|
|
816
|
+
|
|
817
|
+
# GLM Bayesian optimization wrapper.
|
|
818
|
+
def bayesopt_glm(self, max_evals=50):
|
|
819
|
+
self.optimize_model('glm', max_evals)
|
|
820
|
+
|
|
821
|
+
# XGBoost Bayesian optimization wrapper.
|
|
822
|
+
def bayesopt_xgb(self, max_evals=100):
|
|
823
|
+
self.optimize_model('xgb', max_evals)
|
|
824
|
+
|
|
825
|
+
# ResNet Bayesian optimization wrapper.
|
|
826
|
+
def bayesopt_resnet(self, max_evals=100):
|
|
827
|
+
self.optimize_model('resn', max_evals)
|
|
828
|
+
|
|
829
|
+
# GNN Bayesian optimization wrapper.
|
|
830
|
+
def bayesopt_gnn(self, max_evals=50):
|
|
831
|
+
self.optimize_model('gnn', max_evals)
|
|
832
|
+
|
|
833
|
+
# FT-Transformer Bayesian optimization wrapper.
|
|
834
|
+
def bayesopt_ft(self, max_evals=50):
|
|
835
|
+
self.optimize_model('ft', max_evals)
|
|
836
|
+
|
|
837
|
+
# Lift curve plotting.
|
|
838
|
+
def plot_lift(self, model_label, pred_nme, n_bins=10):
|
|
839
|
+
if plt is None:
|
|
840
|
+
_plot_skip("lift plot")
|
|
841
|
+
return
|
|
842
|
+
model_map = {
|
|
843
|
+
'Xgboost': 'pred_xgb',
|
|
844
|
+
'ResNet': 'pred_resn',
|
|
845
|
+
'ResNetClassifier': 'pred_resn',
|
|
846
|
+
'GLM': 'pred_glm',
|
|
847
|
+
'GNN': 'pred_gnn',
|
|
848
|
+
}
|
|
849
|
+
if str(self.config.ft_role) == "model":
|
|
850
|
+
model_map.update({
|
|
851
|
+
'FTTransformer': 'pred_ft',
|
|
852
|
+
'FTTransformerClassifier': 'pred_ft',
|
|
853
|
+
})
|
|
854
|
+
for k, v in model_map.items():
|
|
855
|
+
if model_label.startswith(k):
|
|
856
|
+
pred_nme = v
|
|
857
|
+
break
|
|
858
|
+
|
|
859
|
+
datasets = []
|
|
860
|
+
for title, data in [
|
|
861
|
+
('Lift Chart on Train Data', self.train_data),
|
|
862
|
+
('Lift Chart on Test Data', self.test_data),
|
|
863
|
+
]:
|
|
864
|
+
if 'w_act' not in data.columns or data['w_act'].isna().all():
|
|
865
|
+
print(
|
|
866
|
+
f"[Lift] Missing labels for {title}; skip.",
|
|
867
|
+
flush=True,
|
|
868
|
+
)
|
|
869
|
+
continue
|
|
870
|
+
datasets.append((title, data))
|
|
871
|
+
|
|
872
|
+
if not datasets:
|
|
873
|
+
print("[Lift] No labeled data available; skip plotting.", flush=True)
|
|
874
|
+
return
|
|
875
|
+
|
|
876
|
+
if plot_curves is None:
|
|
877
|
+
fig = plt.figure(figsize=(11, 5))
|
|
878
|
+
positions = [111] if len(datasets) == 1 else [121, 122]
|
|
879
|
+
for pos, (title, data) in zip(positions, datasets):
|
|
880
|
+
if pred_nme not in data.columns or f'w_{pred_nme}' not in data.columns:
|
|
881
|
+
print(
|
|
882
|
+
f"[Lift] Missing prediction columns in {title}; skip.",
|
|
883
|
+
flush=True,
|
|
884
|
+
)
|
|
885
|
+
continue
|
|
886
|
+
lift_df = pd.DataFrame({
|
|
887
|
+
'pred': data[pred_nme].values,
|
|
888
|
+
'w_pred': data[f'w_{pred_nme}'].values,
|
|
889
|
+
'act': data['w_act'].values,
|
|
890
|
+
'weight': data[self.weight_nme].values
|
|
891
|
+
})
|
|
892
|
+
plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
|
|
893
|
+
denom = np.maximum(plot_data['weight'], EPS)
|
|
894
|
+
plot_data['exp_v'] = plot_data['w_pred'] / denom
|
|
895
|
+
plot_data['act_v'] = plot_data['act'] / denom
|
|
896
|
+
plot_data = plot_data.reset_index()
|
|
897
|
+
|
|
898
|
+
ax = fig.add_subplot(pos)
|
|
899
|
+
PlotUtils.plot_lift_ax(ax, plot_data, title)
|
|
900
|
+
|
|
901
|
+
plt.subplots_adjust(wspace=0.3)
|
|
902
|
+
save_path = self.output_manager.plot_path(
|
|
903
|
+
f'01_{self.model_nme}_{model_label}_lift.png')
|
|
904
|
+
plt.savefig(save_path, dpi=300)
|
|
905
|
+
plt.show()
|
|
906
|
+
plt.close(fig)
|
|
907
|
+
return
|
|
908
|
+
|
|
909
|
+
style = PlotStyle() if PlotStyle else None
|
|
910
|
+
fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
|
|
911
|
+
if len(datasets) == 1:
|
|
912
|
+
axes = [axes]
|
|
913
|
+
|
|
914
|
+
for ax, (title, data) in zip(axes, datasets):
|
|
915
|
+
pred_vals = None
|
|
916
|
+
if pred_nme in data.columns:
|
|
917
|
+
pred_vals = data[pred_nme].values
|
|
918
|
+
else:
|
|
919
|
+
w_pred_col = f"w_{pred_nme}"
|
|
920
|
+
if w_pred_col in data.columns:
|
|
921
|
+
denom = np.maximum(data[self.weight_nme].values, EPS)
|
|
922
|
+
pred_vals = data[w_pred_col].values / denom
|
|
923
|
+
if pred_vals is None:
|
|
924
|
+
print(
|
|
925
|
+
f"[Lift] Missing prediction columns in {title}; skip.",
|
|
926
|
+
flush=True,
|
|
927
|
+
)
|
|
928
|
+
continue
|
|
929
|
+
|
|
930
|
+
plot_curves.plot_lift_curve(
|
|
931
|
+
pred_vals,
|
|
932
|
+
data['w_act'].values,
|
|
933
|
+
data[self.weight_nme].values,
|
|
934
|
+
n_bins=n_bins,
|
|
935
|
+
title=title,
|
|
936
|
+
pred_label="Predicted",
|
|
937
|
+
act_label="Actual",
|
|
938
|
+
weight_label="Earned Exposure",
|
|
939
|
+
pred_weighted=False,
|
|
940
|
+
actual_weighted=True,
|
|
941
|
+
ax=ax,
|
|
942
|
+
show=False,
|
|
943
|
+
style=style,
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
plt.subplots_adjust(wspace=0.3)
|
|
947
|
+
save_path = self.output_manager.plot_path(
|
|
948
|
+
f'01_{self.model_nme}_{model_label}_lift.png')
|
|
949
|
+
if finalize_figure:
|
|
950
|
+
finalize_figure(fig, save_path=save_path, show=True, style=style)
|
|
951
|
+
else:
|
|
952
|
+
plt.savefig(save_path, dpi=300)
|
|
953
|
+
plt.show()
|
|
954
|
+
plt.close(fig)
|
|
955
|
+
|
|
956
|
+
# Double lift curve plot.
|
|
957
|
+
def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
|
|
958
|
+
# Compare two models across bins.
|
|
959
|
+
# Args:
|
|
960
|
+
# model_comp: model keys to compare (e.g., ['xgb', 'resn']).
|
|
961
|
+
# n_bins: number of bins for lift curves.
|
|
962
|
+
if plt is None:
|
|
963
|
+
_plot_skip("double lift plot")
|
|
964
|
+
return
|
|
965
|
+
if len(model_comp) != 2:
|
|
966
|
+
raise ValueError("`model_comp` must contain two models to compare.")
|
|
967
|
+
|
|
968
|
+
model_name_map = {
|
|
969
|
+
'xgb': 'Xgboost',
|
|
970
|
+
'resn': 'ResNet',
|
|
971
|
+
'glm': 'GLM',
|
|
972
|
+
'gnn': 'GNN',
|
|
973
|
+
}
|
|
974
|
+
if str(self.config.ft_role) == "model":
|
|
975
|
+
model_name_map['ft'] = 'FTTransformer'
|
|
976
|
+
|
|
977
|
+
name1, name2 = model_comp
|
|
978
|
+
if name1 not in model_name_map or name2 not in model_name_map:
|
|
979
|
+
raise ValueError(f"Unsupported model key. Choose from {list(model_name_map.keys())}.")
|
|
980
|
+
|
|
981
|
+
datasets = []
|
|
982
|
+
for data_name, data in [('Train Data', self.train_data),
|
|
983
|
+
('Test Data', self.test_data)]:
|
|
984
|
+
if 'w_act' not in data.columns or data['w_act'].isna().all():
|
|
985
|
+
print(
|
|
986
|
+
f"[Double Lift] Missing labels for {data_name}; skip.",
|
|
987
|
+
flush=True,
|
|
988
|
+
)
|
|
989
|
+
continue
|
|
990
|
+
            datasets.append((data_name, data))

        if not datasets:
            print("[Double Lift] No labeled data available; skip plotting.", flush=True)
            return

        if plot_curves is None:
            fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
            if len(datasets) == 1:
                axes = [axes]

            for ax, (data_name, data) in zip(axes, datasets):
                pred1_col = f'w_pred_{name1}'
                pred2_col = f'w_pred_{name2}'

                if pred1_col not in data.columns or pred2_col not in data.columns:
                    print(
                        f"Warning: missing prediction columns {pred1_col} or {pred2_col} in {data_name}. Skip plot.")
                    continue

                lift_data = pd.DataFrame({
                    'pred1': data[pred1_col].values,
                    'pred2': data[pred2_col].values,
                    'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
                    'act': data['w_act'].values,
                    'weight': data[self.weight_nme].values
                })
                plot_data = PlotUtils.split_data(
                    lift_data, 'diff_ly', 'weight', n_bins)
                denom = np.maximum(plot_data['act'], EPS)
                plot_data['exp_v1'] = plot_data['pred1'] / denom
                plot_data['exp_v2'] = plot_data['pred2'] / denom
                plot_data['act_v'] = plot_data['act'] / denom
                plot_data.reset_index(inplace=True)

                label1 = model_name_map[name1]
                label2 = model_name_map[name2]

                PlotUtils.plot_dlift_ax(
                    ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)

            plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
            save_path = self.output_manager.plot_path(
                f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
            plt.savefig(save_path, dpi=300)
            plt.show()
            plt.close(fig)
            return

        style = PlotStyle() if PlotStyle else None
        fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
        if len(datasets) == 1:
            axes = [axes]

        label1 = model_name_map[name1]
        label2 = model_name_map[name2]

        for ax, (data_name, data) in zip(axes, datasets):
            weight_vals = data[self.weight_nme].values
            pred1 = None
            pred2 = None

            pred1_col = f"pred_{name1}"
            pred2_col = f"pred_{name2}"
            if pred1_col in data.columns:
                pred1 = data[pred1_col].values
            else:
                w_pred1_col = f"w_pred_{name1}"
                if w_pred1_col in data.columns:
                    pred1 = data[w_pred1_col].values / np.maximum(weight_vals, EPS)

            if pred2_col in data.columns:
                pred2 = data[pred2_col].values
            else:
                w_pred2_col = f"w_pred_{name2}"
                if w_pred2_col in data.columns:
                    pred2 = data[w_pred2_col].values / np.maximum(weight_vals, EPS)

            if pred1 is None or pred2 is None:
                print(
                    f"Warning: missing pred_{name1}/pred_{name2} or w_pred columns in {data_name}. Skip plot.")
                continue

            plot_curves.plot_double_lift_curve(
                pred1,
                pred2,
                data['w_act'].values,
                weight_vals,
                n_bins=n_bins,
                title=f"Double Lift Chart on {data_name}",
                label1=label1,
                label2=label2,
                pred1_weighted=False,
                pred2_weighted=False,
                actual_weighted=True,
                ax=ax,
                show=False,
                style=style,
            )

        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
        save_path = self.output_manager.plot_path(
            f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
        if finalize_figure:
            finalize_figure(fig, save_path=save_path, show=True, style=style)
        else:
            plt.savefig(save_path, dpi=300)
            plt.show()
            plt.close(fig)

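The fallback branch above sorts records by the ratio of the two candidate predictions, forms weight-balanced bins, and compares each model's predicted-to-actual ratio per bin. A minimal standalone sketch of that binning idea, assuming synthetic inputs and a small EPS constant like the one used in core.py (the function and names below are illustrative, not package code):

    import numpy as np
    import pandas as pd

    EPS = 1e-8  # assumption: stands in for the package's small constant

    def double_lift_table(pred1, pred2, actual, weight, n_bins=10):
        # Rank by the ratio of the two predictions, then build equal-weight bins.
        df = pd.DataFrame({'pred1': pred1, 'pred2': pred2,
                           'act': actual, 'weight': weight})
        df['ratio'] = df['pred1'] / np.maximum(df['pred2'], EPS)
        df = df.sort_values('ratio')
        cum_w = df['weight'].cumsum()
        df['bin'] = np.minimum((cum_w / cum_w.iloc[-1] * n_bins).astype(int), n_bins - 1)
        agg = df.groupby('bin')[['pred1', 'pred2', 'act']].sum()
        denom = np.maximum(agg['act'], EPS)
        # Values near 1.0 in every bin mean the two models track actuals equally well.
        return pd.DataFrame({'exp_v1': agg['pred1'] / denom,
                             'exp_v2': agg['pred2'] / denom})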
    # Conversion lift curve plot.
    def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
        if plt is None:
            _plot_skip("conversion lift plot")
            return
        if not self.binary_resp_nme:
            print("Error: `binary_resp_nme` not provided at BayesOptModel init; cannot plot conversion lift.")
            return

        if plot_curves is None:
            fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
            datasets = {
                'Train Data': self.train_data,
                'Test Data': self.test_data
            }

            for ax, (data_name, data) in zip(axes, datasets.items()):
                if model_pred_col not in data.columns:
                    print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
                    continue

                # Sort by model prediction and compute bins.
                plot_data = data.sort_values(by=model_pred_col).copy()
                plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
                total_weight = plot_data[self.weight_nme].sum()

                if total_weight > EPS:
                    plot_data['bin'] = pd.cut(
                        plot_data['cum_weight'],
                        bins=n_bins,
                        labels=False,
                        right=False
                    )
                else:
                    plot_data['bin'] = 0

                # Aggregate by bins.
                lift_agg = plot_data.groupby('bin').agg(
                    total_weight=(self.weight_nme, 'sum'),
                    actual_conversions=(self.binary_resp_nme, 'sum'),
                    weighted_conversions=('w_binary_act', 'sum'),
                    avg_pred=(model_pred_col, 'mean')
                ).reset_index()

                # Compute conversion rate.
                lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
                    lift_agg['total_weight']

                # Compute overall average conversion rate.
                overall_conversion_rate = data['w_binary_act'].sum(
                ) / data[self.weight_nme].sum()
                ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
                           label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')

                ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
                        marker='o', linestyle='-', label='Actual Conversion Rate')
                ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
                ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
                ax.set_ylabel('Conversion Rate')
                ax.grid(True, linestyle='--', alpha=0.6)
                ax.legend()

            plt.tight_layout()
            plt.show()
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
        datasets = {
            'Train Data': self.train_data,
            'Test Data': self.test_data
        }

        for ax, (data_name, data) in zip(axes, datasets.items()):
            if model_pred_col not in data.columns:
                print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
                continue

            plot_curves.plot_conversion_lift(
                data[model_pred_col].values,
                data[self.binary_resp_nme].values,
                data[self.weight_nme].values,
                n_bins=n_bins,
                title=f'Conversion Rate Lift Chart on {data_name}',
                ax=ax,
                show=False,
            )

        plt.tight_layout()
        plt.show()

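plot_conversion_lift needs the binary response configured at BayesOptModel init plus a prediction column present in both the train and test frames. A hedged usage sketch (the instance and column names are illustrative, not from the package):

    # Hypothetical usage; `model` is a fitted BayesOptModel and 'pred_xgb'
    # an existing prediction column -- both names are illustrative.
    model.plot_conversion_lift('pred_xgb', n_bins=10)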
    # ========= Lightweight explainability: Permutation Importance =========
    def compute_permutation_importance(self,
                                       model_key: str,
                                       on_train: bool = True,
                                       metric: Any = "auto",
                                       n_repeats: int = 5,
                                       max_rows: int = 5000,
                                       random_state: Optional[int] = None):
        if explain_permutation is None:
            raise RuntimeError("explain.permutation is not available.")

        model_key = str(model_key)
        data = self.train_data if on_train else self.test_data
        if self.resp_nme not in data.columns:
            raise RuntimeError("Missing response column for permutation importance.")
        y = data[self.resp_nme]
        w = data[self.weight_nme] if self.weight_nme in data.columns else None

        if model_key == "resn":
            if self.resn_best is None:
                raise RuntimeError("ResNet model not trained.")
            X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
            if X is None:
                raise RuntimeError("Missing standardized features for ResNet.")
            X = X[self.var_nmes]
            predict_fn = lambda df: self.resn_best.predict(df)
        elif model_key == "ft":
            if self.ft_best is None:
                raise RuntimeError("FT model not trained.")
            if str(self.config.ft_role) != "model":
                raise RuntimeError("FT role is not 'model'; FT predictions unavailable.")
            X = data[self.factor_nmes]
            geo_tokens = self.train_geo_tokens if on_train else self.test_geo_tokens
            geo_np = None
            if geo_tokens is not None:
                geo_np = geo_tokens.to_numpy(dtype=np.float32, copy=False)
            predict_fn = lambda df, geo=geo_np: self.ft_best.predict(df, geo_tokens=geo)
        elif model_key == "xgb":
            if self.xgb_best is None:
                raise RuntimeError("XGB model not trained.")
            X = data[self.factor_nmes]
            predict_fn = lambda df: self.xgb_best.predict(df)
        else:
            raise ValueError("Unsupported model_key for permutation importance.")

        return explain_permutation.permutation_importance(
            predict_fn,
            X,
            y,
            sample_weight=w,
            metric=metric,
            task_type=self.task_type,
            n_repeats=n_repeats,
            random_state=random_state,
            max_rows=max_rows,
        )

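compute_permutation_importance only needs a trained model key and the stored train/test frames; the scoring itself is delegated to explain.permutation. A hedged usage sketch (the instance name is illustrative):

    # Hypothetical usage; `model` is a fitted BayesOptModel (illustrative name).
    imp = model.compute_permutation_importance(
        model_key="xgb",     # supported keys per the branches above: "resn", "ft", "xgb"
        on_train=False,      # score on the held-out test frame
        n_repeats=5,
        max_rows=5000,
    )
    # The returned object comes from explain.permutation.permutation_importance
    # (ins_pricing/modelling/explain/permutation.py).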
    # ========= Deep explainability: Integrated Gradients =========
    def compute_integrated_gradients_resn(self,
                                          on_train: bool = True,
                                          baseline: Any = None,
                                          steps: int = 50,
                                          batch_size: int = 256,
                                          target: Optional[int] = None):
        if explain_gradients is None:
            raise RuntimeError("explain.gradients is not available.")
        if self.resn_best is None:
            raise RuntimeError("ResNet model not trained.")
        X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
        if X is None:
            raise RuntimeError("Missing standardized features for ResNet.")
        X = X[self.var_nmes]
        return explain_gradients.resnet_integrated_gradients(
            self.resn_best,
            X,
            baseline=baseline,
            steps=steps,
            batch_size=batch_size,
            target=target,
        )

    def compute_integrated_gradients_ft(self,
                                        on_train: bool = True,
                                        geo_tokens: Optional[np.ndarray] = None,
                                        baseline_num: Any = None,
                                        baseline_geo: Any = None,
                                        steps: int = 50,
                                        batch_size: int = 256,
                                        target: Optional[int] = None):
        if explain_gradients is None:
            raise RuntimeError("explain.gradients is not available.")
        if self.ft_best is None:
            raise RuntimeError("FT model not trained.")
        if str(self.config.ft_role) != "model":
            raise RuntimeError("FT role is not 'model'; FT explanations unavailable.")

        data = self.train_data if on_train else self.test_data
        X = data[self.factor_nmes]

        if geo_tokens is None and getattr(self.ft_best, "num_geo", 0) > 0:
            tokens_df = self.train_geo_tokens if on_train else self.test_geo_tokens
            if tokens_df is not None:
                geo_tokens = tokens_df.to_numpy(dtype=np.float32, copy=False)

        return explain_gradients.ft_integrated_gradients(
            self.ft_best,
            X,
            geo_tokens=geo_tokens,
            baseline_num=baseline_num,
            baseline_geo=baseline_geo,
            steps=steps,
            batch_size=batch_size,
            target=target,
        )

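Both integrated-gradients wrappers differ only in how features (and optional geo tokens) are assembled before delegating to explain.gradients. A hedged usage sketch (instance name illustrative):

    # Hypothetical usage; `model` is a fitted BayesOptModel (illustrative name).
    ig_resn = model.compute_integrated_gradients_resn(on_train=False, steps=50)
    ig_ft = model.compute_integrated_gradients_ft(on_train=False, steps=50, batch_size=256)
    # Outputs are whatever explain.gradients.resnet_integrated_gradients /
    # ft_integrated_gradients return (ins_pricing/modelling/explain/gradients.py).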
    # Save model
    def save_model(self, model_name=None):
        keys = [model_name] if model_name else self.trainers.keys()
        for key in keys:
            if key in self.trainers:
                self.trainers[key].save()
            else:
                if model_name:  # Only warn when the user specifies a model name.
                    print(f"[save_model] Warning: Unknown model key {key}")

    def load_model(self, model_name=None):
        keys = [model_name] if model_name else self.trainers.keys()
        for key in keys:
            if key in self.trainers:
                self.trainers[key].load()
                # Sync context fields.
                trainer = self.trainers[key]
                if trainer.model is not None:
                    setattr(self, f"{key}_best", trainer.model)
                    # For legacy compatibility, also update xxx_load.
                    # Old versions only tracked xgb_load/resn_load/ft_load (not glm_load/gnn_load).
                    if key in ['xgb', 'resn', 'ft', 'gnn']:
                        setattr(self, f"{key}_load", trainer.model)
            else:
                if model_name:
                    print(f"[load_model] Warning: Unknown model key {key}")

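save_model and load_model iterate over the registered trainers: with no argument they touch every trainer, with a key they touch only that trainer and warn on unknown keys. Illustrative calls (instance name assumed):

    # Hypothetical usage; `model` is a BayesOptModel (illustrative name).
    model.save_model()        # persist every registered trainer
    model.load_model("xgb")   # reload only the XGBoost trainer; this also refreshes
                              # model.xgb_best and the legacy model.xgb_load alias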
    def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
        if len(data) == 0:
            return data
        return data.sample(min(len(data), n), random_state=self.rand_seed)

    @staticmethod
    def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
        min_needed = arr.shape[1] + 2
        return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))

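_shap_nsamples bounds the SHAP sample count below by n_features + 2 and above by the smaller of max_nsamples and rows * features. Two worked cases (hypothetical shapes):

    # arr of shape (5000, 20): max(22, min(300, 100000)) -> 300
    # arr of shape (3, 4):     max(6,  min(300, 12))     -> 12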
    def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
        matrices = []
        for col in self.factor_nmes:
            s = data[col]
            if col in self.cate_list:
                cats = pd.Categorical(
                    s,
                    categories=self.cat_categories_for_shap[col]
                )
                codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
                matrices.append(codes)
            else:
                vals = pd.to_numeric(s, errors="coerce")
                arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
                matrices.append(arr)
        X_mat = np.concatenate(matrices, axis=1)  # Result shape (N, F)
        return X_mat

    def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
        data_dict = {}
        for j, col in enumerate(self.factor_nmes):
            col_vals = X_mat[:, j]
            if col in self.cate_list:
                cats = self.cat_categories_for_shap[col]
                codes = np.round(col_vals).astype(int)
                codes = np.clip(codes, -1, len(cats) - 1)
                cat_series = pd.Categorical.from_codes(
                    codes,
                    categories=cats
                )
                data_dict[col] = cat_series
            else:
                data_dict[col] = col_vals.astype(float)

        df = pd.DataFrame(data_dict, columns=self.factor_nmes)
        for col in self.cate_list:
            if col in df.columns:
                df[col] = df[col].astype("category")
        return df

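The encode/decode pair above turns a mixed-type frame into a purely numeric matrix (categorical levels become float codes) and back again, which is what lets a KernelSHAP-style explainer perturb FT-Transformer inputs. A simplified standalone sketch of the same round trip (hypothetical columns, without the class bookkeeping):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'region': pd.Categorical(['N', 'S', 'N']),
                       'veh_age': [2.0, 7.0, 4.0]})

    # Encode: categorical -> integer codes stored as floats, numerics passed through.
    codes = df['region'].cat.codes.to_numpy(dtype=np.float64)
    X_mat = np.column_stack([codes, df['veh_age'].to_numpy(dtype=np.float64)])

    # Decode: round the perturbed codes, clip to the valid range, rebuild the categories.
    back = np.clip(np.round(X_mat[:, 0]).astype(int), -1, len(df['region'].cat.categories) - 1)
    decoded = pd.DataFrame({
        'region': pd.Categorical.from_codes(back, categories=df['region'].cat.categories),
        'veh_age': X_mat[:, 1],
    })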
    def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
        X = data[self.var_nmes]
        return sm.add_constant(X, has_constant='add')

    def _compute_shap_core(self,
                           model_key: str,
                           n_background: int,
                           n_samples: int,
                           on_train: bool,
                           X_df: pd.DataFrame,
                           prep_fn,
                           predict_fn,
                           cleanup_fn=None):
        if explain_shap is None:
            raise RuntimeError("explain.shap_utils is not available.")
        return explain_shap.compute_shap_core(
            self,
            model_key,
            n_background,
            n_samples,
            on_train,
            X_df=X_df,
            prep_fn=prep_fn,
            predict_fn=predict_fn,
            cleanup_fn=cleanup_fn,
        )

    # ========= GLM SHAP explainability =========
    def compute_shap_glm(self, n_background: int = 500,
                         n_samples: int = 200,
                         on_train: bool = True):
        if explain_shap is None:
            raise RuntimeError("explain.shap_utils is not available.")
        self.shap_glm = explain_shap.compute_shap_glm(
            self,
            n_background=n_background,
            n_samples=n_samples,
            on_train=on_train,
        )
        return self.shap_glm

    # ========= XGBoost SHAP explainability =========
    def compute_shap_xgb(self, n_background: int = 500,
                         n_samples: int = 200,
                         on_train: bool = True):
        if explain_shap is None:
            raise RuntimeError("explain.shap_utils is not available.")
        self.shap_xgb = explain_shap.compute_shap_xgb(
            self,
            n_background=n_background,
            n_samples=n_samples,
            on_train=on_train,
        )
        return self.shap_xgb

    # ========= ResNet SHAP explainability =========
    def _resn_predict_wrapper(self, X_np):
        model = self.resn_best.resnet.to("cpu")
        with torch.no_grad():
            X_tensor = torch.tensor(X_np, dtype=torch.float32)
            y_pred = model(X_tensor).cpu().numpy()
        y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.reshape(-1)

    def compute_shap_resn(self, n_background: int = 500,
                          n_samples: int = 200,
                          on_train: bool = True):
        if explain_shap is None:
            raise RuntimeError("explain.shap_utils is not available.")
        self.shap_resn = explain_shap.compute_shap_resn(
            self,
            n_background=n_background,
            n_samples=n_samples,
            on_train=on_train,
        )
        return self.shap_resn

    # ========= FT-Transformer SHAP explainability =========
    def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
        df_input = self._decode_ft_shap_matrix_to_df(X_mat)
        y_pred = self.ft_best.predict(df_input)
        return np.asarray(y_pred, dtype=np.float64).reshape(-1)

    def compute_shap_ft(self, n_background: int = 500,
                        n_samples: int = 200,
                        on_train: bool = True):
        if explain_shap is None:
            raise RuntimeError("explain.shap_utils is not available.")
        self.shap_ft = explain_shap.compute_shap_ft(
            self,
            n_background=n_background,
            n_samples=n_samples,
            on_train=on_train,
        )
        return self.shap_ft
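All four compute_shap_* entry points share one signature and cache their result on the instance (shap_glm, shap_xgb, shap_resn, shap_ft). A hedged usage sketch (instance name illustrative; the returned object is whatever explain.shap_utils produces):

    # Hypothetical usage; `model` is a fitted BayesOptModel (illustrative name).
    shap_xgb = model.compute_shap_xgb(n_background=500, n_samples=200, on_train=True)
    shap_ft = model.compute_shap_ft(n_background=200, n_samples=100, on_train=False)
    # Results are also kept on the instance:
    assert shap_xgb is model.shap_xgb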