ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,1476 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict
4
+ from datetime import datetime
5
+ import os
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ try: # matplotlib is optional; avoid hard import failures in headless/minimal envs
9
+ import matplotlib
10
+ if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+ _MPL_IMPORT_ERROR: Optional[BaseException] = None
14
+ except Exception as exc: # pragma: no cover - optional dependency
15
+ plt = None # type: ignore[assignment]
16
+ _MPL_IMPORT_ERROR = exc
17
+ import numpy as np
18
+ import pandas as pd
19
+ import torch
20
+ import statsmodels.api as sm
21
+ from sklearn.model_selection import ShuffleSplit
22
+ from sklearn.preprocessing import StandardScaler
23
+
24
+ from .config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
25
+ from .models import GraphNeuralNetSklearn
26
+ from .trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
27
+ from .utils import EPS, IOUtils, PlotUtils, infer_factor_and_cate_list, set_global_seed  # IOUtils is used by _maybe_load_best_params below
28
+ try:
29
+ from ..plotting import curves as plot_curves
30
+ from ..plotting import diagnostics as plot_diagnostics
31
+ from ..plotting.common import PlotStyle, finalize_figure
32
+ from ..explain import gradients as explain_gradients
33
+ from ..explain import permutation as explain_permutation
34
+ from ..explain import shap_utils as explain_shap
35
+ except Exception: # pragma: no cover - optional for legacy imports
36
+ try: # best-effort for non-package imports
37
+ from ins_pricing.modelling.plotting import curves as plot_curves
38
+ from ins_pricing.modelling.plotting import diagnostics as plot_diagnostics
39
+ from ins_pricing.modelling.plotting.common import PlotStyle, finalize_figure
40
+ from ins_pricing.modelling.explain import gradients as explain_gradients
41
+ from ins_pricing.modelling.explain import permutation as explain_permutation
42
+ from ins_pricing.modelling.explain import shap_utils as explain_shap
43
+ except Exception: # pragma: no cover
44
+ plot_curves = None
45
+ plot_diagnostics = None
46
+ PlotStyle = None
47
+ finalize_figure = None
48
+ explain_gradients = None
49
+ explain_permutation = None
50
+ explain_shap = None
51
+
52
+
53
+ def _plot_skip(label: str) -> None:
54
+ if _MPL_IMPORT_ERROR is not None:
55
+ print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
56
+ else:
57
+ print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
58
+
59
+ # BayesOpt orchestration and SHAP utilities
60
+ # =============================================================================
61
+ class BayesOptModel:
62
+ def __init__(self, train_data, test_data,
63
+ model_nme, resp_nme, weight_nme, factor_nmes: Optional[List[str]] = None, task_type='regression',
64
+ binary_resp_nme=None,
65
+ cate_list=None, prop_test=0.25, rand_seed=None,
66
+ epochs=100, use_gpu=True,
67
+ use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
68
+ use_gnn_data_parallel: bool = False,
69
+ use_resn_ddp: bool = False, use_ft_ddp: bool = False,
70
+ use_gnn_ddp: bool = False,
71
+ output_dir: Optional[str] = None,
72
+ gnn_use_approx_knn: bool = True,
73
+ gnn_approx_knn_threshold: int = 50000,
74
+ gnn_graph_cache: Optional[str] = None,
75
+ gnn_max_gpu_knn_nodes: Optional[int] = 200000,
76
+ gnn_knn_gpu_mem_ratio: float = 0.9,
77
+ gnn_knn_gpu_mem_overhead: float = 2.0,
78
+ ft_role: str = "model",
79
+ ft_feature_prefix: str = "ft_emb",
80
+ ft_num_numeric_tokens: Optional[int] = None,
81
+ infer_categorical_max_unique: int = 50,
82
+ infer_categorical_max_ratio: float = 0.05,
83
+ reuse_best_params: bool = False,
84
+ xgb_max_depth_max: int = 25,
85
+ xgb_n_estimators_max: int = 500,
86
+ resn_weight_decay: Optional[float] = None,
87
+ final_ensemble: bool = False,
88
+ final_ensemble_k: int = 3,
89
+ final_refit: bool = True,
90
+ optuna_storage: Optional[str] = None,
91
+ optuna_study_prefix: Optional[str] = None,
92
+ best_params_files: Optional[Dict[str, str]] = None):
93
+ """Orchestrate BayesOpt training across multiple trainers.
94
+
95
+ Args:
96
+ train_data: Training DataFrame.
97
+ test_data: Test DataFrame.
98
+ model_nme: Model name prefix used in outputs.
99
+ resp_nme: Target column name.
100
+ weight_nme: Sample weight column name.
101
+ factor_nmes: Feature column list.
102
+ task_type: "regression" or "classification".
103
+ binary_resp_nme: Optional binary target for lift curves.
104
+ cate_list: Categorical feature list.
105
+ prop_test: Validation split ratio in CV.
106
+ rand_seed: Random seed.
107
+ epochs: NN training epochs.
108
+ use_gpu: Prefer GPU when available.
109
+ use_resn_data_parallel: Enable DataParallel for ResNet.
110
+ use_ft_data_parallel: Enable DataParallel for FTTransformer.
111
+ use_gnn_data_parallel: Enable DataParallel for GNN.
112
+ use_resn_ddp: Enable DDP for ResNet.
113
+ use_ft_ddp: Enable DDP for FTTransformer.
114
+ use_gnn_ddp: Enable DDP for GNN.
115
+ output_dir: Output root for models/results/plots.
116
+ gnn_use_approx_knn: Use approximate kNN when available.
117
+ gnn_approx_knn_threshold: Row threshold to switch to approximate kNN.
118
+ gnn_graph_cache: Optional adjacency cache path.
119
+ gnn_max_gpu_knn_nodes: Force CPU kNN above this node count to avoid OOM.
120
+ gnn_knn_gpu_mem_ratio: Fraction of free GPU memory for kNN.
121
+ gnn_knn_gpu_mem_overhead: Temporary memory multiplier for GPU kNN.
122
+ ft_num_numeric_tokens: Number of numeric tokens for FT (None = auto).
123
+ final_ensemble: Enable k-fold model averaging at the final stage.
124
+ final_ensemble_k: Number of folds for averaging.
125
+ final_refit: Refit on full data using best stopping point.
126
+ """
127
+ inferred_factors, inferred_cats = infer_factor_and_cate_list(
128
+ train_df=train_data,
129
+ test_df=test_data,
130
+ resp_nme=resp_nme,
131
+ weight_nme=weight_nme,
132
+ binary_resp_nme=binary_resp_nme,
133
+ factor_nmes=factor_nmes,
134
+ cate_list=cate_list,
135
+ infer_categorical_max_unique=int(infer_categorical_max_unique),
136
+ infer_categorical_max_ratio=float(infer_categorical_max_ratio),
137
+ )
138
+
139
+ cfg = BayesOptConfig(
140
+ model_nme=model_nme,
141
+ task_type=task_type,
142
+ resp_nme=resp_nme,
143
+ weight_nme=weight_nme,
144
+ factor_nmes=list(inferred_factors),
145
+ binary_resp_nme=binary_resp_nme,
146
+ cate_list=list(inferred_cats) if inferred_cats else None,
147
+ prop_test=prop_test,
148
+ rand_seed=rand_seed,
149
+ epochs=epochs,
150
+ use_gpu=use_gpu,
151
+ xgb_max_depth_max=int(xgb_max_depth_max),
152
+ xgb_n_estimators_max=int(xgb_n_estimators_max),
153
+ use_resn_data_parallel=use_resn_data_parallel,
154
+ use_ft_data_parallel=use_ft_data_parallel,
155
+ use_resn_ddp=use_resn_ddp,
156
+ use_gnn_data_parallel=use_gnn_data_parallel,
157
+ use_ft_ddp=use_ft_ddp,
158
+ use_gnn_ddp=use_gnn_ddp,
159
+ gnn_use_approx_knn=gnn_use_approx_knn,
160
+ gnn_approx_knn_threshold=gnn_approx_knn_threshold,
161
+ gnn_graph_cache=gnn_graph_cache,
162
+ gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
163
+ gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
164
+ gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
165
+ output_dir=output_dir,
166
+ optuna_storage=optuna_storage,
167
+ optuna_study_prefix=optuna_study_prefix,
168
+ best_params_files=best_params_files,
169
+ ft_role=str(ft_role or "model"),
170
+ ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
171
+ ft_num_numeric_tokens=ft_num_numeric_tokens,
172
+ reuse_best_params=bool(reuse_best_params),
173
+ resn_weight_decay=float(resn_weight_decay)
174
+ if resn_weight_decay is not None
175
+ else 1e-4,
176
+ final_ensemble=bool(final_ensemble),
177
+ final_ensemble_k=int(final_ensemble_k),
178
+ final_refit=bool(final_refit),
179
+ )
180
+ self.config = cfg
181
+ self.model_nme = cfg.model_nme
182
+ self.task_type = cfg.task_type
183
+ self.resp_nme = cfg.resp_nme
184
+ self.weight_nme = cfg.weight_nme
185
+ self.factor_nmes = cfg.factor_nmes
186
+ self.binary_resp_nme = cfg.binary_resp_nme
187
+ self.cate_list = list(cfg.cate_list or [])
188
+ self.prop_test = cfg.prop_test
189
+ self.epochs = cfg.epochs
190
+ self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
191
+ 1, 10000)
192
+ set_global_seed(int(self.rand_seed))
193
+ self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
194
+ self.output_manager = OutputManager(
195
+ cfg.output_dir or os.getcwd(), self.model_nme)
196
+
197
+ preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
198
+ self.train_data = preprocessor.train_data
199
+ self.test_data = preprocessor.test_data
200
+ self.train_oht_data = preprocessor.train_oht_data
201
+ self.test_oht_data = preprocessor.test_oht_data
202
+ self.train_oht_scl_data = preprocessor.train_oht_scl_data
203
+ self.test_oht_scl_data = preprocessor.test_oht_scl_data
204
+ self.var_nmes = preprocessor.var_nmes
205
+ self.num_features = preprocessor.num_features
206
+ self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
207
+ self.geo_token_cols: List[str] = []
208
+ self.train_geo_tokens: Optional[pd.DataFrame] = None
209
+ self.test_geo_tokens: Optional[pd.DataFrame] = None
210
+ self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
211
+ self._add_region_effect()
212
+
213
+ self.cv = ShuffleSplit(n_splits=int(1/self.prop_test),
214
+ test_size=self.prop_test,
215
+ random_state=self.rand_seed)
216
+ if self.task_type == 'classification':
217
+ self.obj = 'binary:logistic'
218
+ else: # regression task: pick the objective from model-name markers ('f' -> count:poisson, 's' -> reg:gamma, otherwise reg:tweedie)
219
+ if 'f' in self.model_nme:
220
+ self.obj = 'count:poisson'
221
+ elif 's' in self.model_nme:
222
+ self.obj = 'reg:gamma'
223
+ elif 'bc' in self.model_nme:
224
+ self.obj = 'reg:tweedie'
225
+ else:
226
+ self.obj = 'reg:tweedie'
227
+ self.fit_params = {
228
+ 'sample_weight': self.train_data[self.weight_nme].values
229
+ }
230
+ self.model_label: List[str] = []
231
+ self.optuna_storage = cfg.optuna_storage
232
+ self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
233
+
234
+ # Keep trainers in a dict for unified access and easy extension.
235
+ self.trainers: Dict[str, TrainerBase] = {
236
+ 'glm': GLMTrainer(self),
237
+ 'xgb': XGBTrainer(self),
238
+ 'resn': ResNetTrainer(self),
239
+ 'ft': FTTrainer(self),
240
+ 'gnn': GNNTrainer(self),
241
+ }
242
+ self._prepare_geo_tokens()
243
+ self.xgb_best = None
244
+ self.resn_best = None
245
+ self.gnn_best = None
246
+ self.glm_best = None
247
+ self.ft_best = None
248
+ self.best_xgb_params = None
249
+ self.best_resn_params = None
250
+ self.best_gnn_params = None
251
+ self.best_ft_params = None
252
+ self.best_xgb_trial = None
253
+ self.best_resn_trial = None
254
+ self.best_gnn_trial = None
255
+ self.best_ft_trial = None
256
+ self.best_glm_params = None
257
+ self.best_glm_trial = None
258
+ self.xgb_load = None
259
+ self.resn_load = None
260
+ self.gnn_load = None
261
+ self.ft_load = None
262
+ self.version_manager = VersionManager(self.output_manager)
263
+
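
For orientation, a minimal usage sketch of the constructor above, with toy data and hypothetical column names (everything else left at its defaults):

import numpy as np
import pandas as pd
from ins_pricing.modelling.bayesopt.core import BayesOptModel

rng = np.random.default_rng(42)
n = 2000
df = pd.DataFrame({
    "driver_age": rng.integers(18, 80, size=n),
    "region": rng.choice(["north", "south", "east", "west"], size=n),
    "exposure": rng.uniform(0.1, 1.0, size=n),
})
df["claim_count"] = rng.poisson(0.08 * df["exposure"].to_numpy())
train_df, test_df = df.iloc[:1600].copy(), df.iloc[1600:].copy()

bo = BayesOptModel(
    train_df, test_df,
    model_nme="motor_f",                   # 'f' in the name selects the count:poisson objective
    resp_nme="claim_count",
    weight_nme="exposure",
    factor_nmes=["driver_age", "region"],  # omit to let infer_factor_and_cate_list decide
    cate_list=["region"],
    epochs=10,
    use_gpu=False,
    output_dir="./bayesopt_runs",
)
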
264
+ def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
265
+ if self.task_type == 'classification':
266
+ return None
267
+ objective = obj or getattr(self, "obj", None)
268
+ if objective == 'count:poisson':
269
+ return 1.0
270
+ if objective == 'reg:gamma':
271
+ return 2.0
272
+ return 1.5
273
+
274
+ def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
275
+ """Internal builder; allows trial overrides and returns None on failure."""
276
+ geo_cols = list(self.config.geo_feature_nmes or [])
277
+ if not geo_cols:
278
+ return None
279
+
280
+ available = [c for c in geo_cols if c in self.train_data.columns]
281
+ if not available:
282
+ return None
283
+
284
+ # Preprocess text/numeric: fill numeric with median, label-encode text, map unknowns.
285
+ proc_train = {}
286
+ proc_test = {}
287
+ for col in available:
288
+ s_train = self.train_data[col]
289
+ s_test = self.test_data[col]
290
+ if pd.api.types.is_numeric_dtype(s_train):
291
+ tr = pd.to_numeric(s_train, errors="coerce")
292
+ te = pd.to_numeric(s_test, errors="coerce")
293
+ med = np.nanmedian(tr)
294
+ proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
295
+ proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
296
+ else:
297
+ cats = pd.Categorical(s_train.astype(str))
298
+ tr_codes = cats.codes.astype(np.float32, copy=True)
299
+ tr_codes[tr_codes < 0] = len(cats.categories)
300
+ te_cats = pd.Categorical(
301
+ s_test.astype(str), categories=cats.categories)
302
+ te_codes = te_cats.codes.astype(np.float32, copy=True)
303
+ te_codes[te_codes < 0] = len(cats.categories)
304
+ proc_train[col] = tr_codes
305
+ proc_test[col] = te_codes
306
+
307
+ train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
308
+ test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
309
+
310
+ scaler = StandardScaler()
311
+ train_geo = pd.DataFrame(
312
+ scaler.fit_transform(train_geo_raw),
313
+ columns=available,
314
+ index=self.train_data.index
315
+ )
316
+ test_geo = pd.DataFrame(
317
+ scaler.transform(test_geo_raw),
318
+ columns=available,
319
+ index=self.test_data.index
320
+ )
321
+
322
+ tw_power = self.default_tweedie_power()
323
+
324
+ cfg = params_override or {}
325
+ try:
326
+ geo_gnn = GraphNeuralNetSklearn(
327
+ model_nme=f"{self.model_nme}_geo",
328
+ input_dim=len(available),
329
+ hidden_dim=cfg.get("geo_token_hidden_dim",
330
+ self.config.geo_token_hidden_dim),
331
+ num_layers=cfg.get("geo_token_layers",
332
+ self.config.geo_token_layers),
333
+ k_neighbors=cfg.get("geo_token_k_neighbors",
334
+ self.config.geo_token_k_neighbors),
335
+ dropout=cfg.get("geo_token_dropout",
336
+ self.config.geo_token_dropout),
337
+ learning_rate=cfg.get(
338
+ "geo_token_learning_rate", self.config.geo_token_learning_rate),
339
+ epochs=int(cfg.get("geo_token_epochs",
340
+ self.config.geo_token_epochs)),
341
+ patience=5,
342
+ task_type=self.task_type,
343
+ tweedie_power=tw_power,
344
+ use_data_parallel=False,
345
+ use_ddp=False,
346
+ use_approx_knn=self.config.gnn_use_approx_knn,
347
+ approx_knn_threshold=self.config.gnn_approx_knn_threshold,
348
+ graph_cache_path=None,
349
+ max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
350
+ knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
351
+ knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
352
+ )
353
+ geo_gnn.fit(
354
+ train_geo,
355
+ self.train_data[self.resp_nme],
356
+ self.train_data[self.weight_nme]
357
+ )
358
+ train_embed = geo_gnn.encode(train_geo)
359
+ test_embed = geo_gnn.encode(test_geo)
360
+ cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
361
+ train_tokens = pd.DataFrame(
362
+ train_embed, index=self.train_data.index, columns=cols)
363
+ test_tokens = pd.DataFrame(
364
+ test_embed, index=self.test_data.index, columns=cols)
365
+ return train_tokens, test_tokens, cols, geo_gnn
366
+ except Exception as exc:
367
+ print(f"[GeoToken] Generation failed: {exc}")
368
+ return None
369
+
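
The per-column handling above, shown in isolation as a small sketch with hypothetical columns: numerics are coerced and median-filled, text columns get train-fitted category codes with unseen test levels mapped to an extra bucket, and everything is standardized before the geo GNN sees it.

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

train = pd.DataFrame({"latitude": [31.2, np.nan, 30.9, 31.0], "city": ["A", "B", "A", "B"]})
test = pd.DataFrame({"latitude": [np.nan, 31.5], "city": ["C", "B"]})   # "C" unseen in train

# Numeric: coerce, then fill NaN with the train median.
tr_num = pd.to_numeric(train["latitude"], errors="coerce")
med = np.nanmedian(tr_num)
tr_lat = np.nan_to_num(tr_num, nan=med).astype(np.float32)
te_lat = np.nan_to_num(pd.to_numeric(test["latitude"], errors="coerce"), nan=med).astype(np.float32)

# Categorical: codes fitted on train; unseen test levels map to an extra "unknown" code.
cats = pd.Categorical(train["city"].astype(str))
tr_city = cats.codes.astype(np.float32, copy=True)
tr_city[tr_city < 0] = len(cats.categories)
te_city = pd.Categorical(test["city"].astype(str), categories=cats.categories).codes.astype(np.float32, copy=True)
te_city[te_city < 0] = len(cats.categories)       # "C" -> 2.0, the unknown bucket

scaler = StandardScaler()
train_geo = scaler.fit_transform(np.column_stack([tr_lat, tr_city]))
test_geo = scaler.transform(np.column_stack([te_lat, te_city]))
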
370
+ def _prepare_geo_tokens(self) -> None:
371
+ """Build and persist geo tokens with default config values."""
372
+ gnn_trainer = self.trainers.get("gnn")
373
+ if gnn_trainer is not None and hasattr(gnn_trainer, "prepare_geo_tokens"):
374
+ try:
375
+ gnn_trainer.prepare_geo_tokens(force=False) # type: ignore[attr-defined]
376
+ return
377
+ except Exception as exc:
378
+ print(f"[GeoToken] GNNTrainer generation failed: {exc}")
379
+
380
+ result = self._build_geo_tokens()
381
+ if result is None:
382
+ return
383
+ train_tokens, test_tokens, cols, geo_gnn = result
384
+ self.train_geo_tokens = train_tokens
385
+ self.test_geo_tokens = test_tokens
386
+ self.geo_token_cols = cols
387
+ self.geo_gnn_model = geo_gnn
388
+ print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; injecting into FT.")
389
+
390
+ def _add_region_effect(self) -> None:
391
+ """Partial pooling over province/city to create a smoothed region_effect feature."""
392
+ prov_col = self.config.region_province_col
393
+ city_col = self.config.region_city_col
394
+ if not prov_col or not city_col:
395
+ return
396
+ for col in [prov_col, city_col]:
397
+ if col not in self.train_data.columns:
398
+ print(f"[RegionEffect] Missing column {col}; skipped.")
399
+ return
400
+
401
+ def safe_mean(df: pd.DataFrame) -> float:
402
+ w = df[self.weight_nme]
403
+ y = df[self.resp_nme]
404
+ denom = max(float(w.sum()), EPS)
405
+ return float((y * w).sum() / denom)
406
+
407
+ global_mean = safe_mean(self.train_data)
408
+ alpha = max(float(self.config.region_effect_alpha), 0.0)
409
+
410
+ w_all = self.train_data[self.weight_nme]
411
+ y_all = self.train_data[self.resp_nme]
412
+ yw_all = y_all * w_all
413
+
414
+ prov_sumw = w_all.groupby(self.train_data[prov_col]).sum()
415
+ prov_sumyw = yw_all.groupby(self.train_data[prov_col]).sum()
416
+ prov_mean = (prov_sumyw / prov_sumw.clip(lower=EPS)).astype(float)
417
+ prov_mean = prov_mean.fillna(global_mean)
418
+
419
+ city_sumw = self.train_data.groupby([prov_col, city_col])[
420
+ self.weight_nme].sum()
421
+ city_sumyw = yw_all.groupby(
422
+ [self.train_data[prov_col], self.train_data[city_col]]).sum()
423
+ city_df = pd.DataFrame({
424
+ "sum_w": city_sumw,
425
+ "sum_yw": city_sumyw,
426
+ })
427
+ city_df["prior"] = city_df.index.get_level_values(0).map(
428
+ prov_mean).fillna(global_mean)
429
+ city_df["effect"] = (
430
+ city_df["sum_yw"] + alpha * city_df["prior"]
431
+ ) / (city_df["sum_w"] + alpha).clip(lower=EPS)
432
+ city_effect = city_df["effect"]
433
+
434
+ def lookup_effect(df: pd.DataFrame) -> pd.Series:
435
+ idx = pd.MultiIndex.from_frame(df[[prov_col, city_col]])
436
+ effects = city_effect.reindex(idx).to_numpy(dtype=np.float64)
437
+ prov_fallback = df[prov_col].map(
438
+ prov_mean).fillna(global_mean).to_numpy(dtype=np.float64)
439
+ effects = np.where(np.isfinite(effects), effects, prov_fallback)
440
+ effects = np.where(np.isfinite(effects), effects, global_mean)
441
+ return pd.Series(effects, index=df.index, dtype=np.float32)
442
+
443
+ re_train = lookup_effect(self.train_data)
444
+ re_test = lookup_effect(self.test_data)
445
+
446
+ col_name = "region_effect"
447
+ self.train_data[col_name] = re_train
448
+ self.test_data[col_name] = re_test
449
+
450
+ # Sync into one-hot and scaled variants.
451
+ for df in [self.train_oht_data, self.test_oht_data]:
452
+ if df is not None:
453
+ df[col_name] = re_train if df is self.train_oht_data else re_test
454
+
455
+ # Standardize region_effect and propagate.
456
+ scaler = StandardScaler()
457
+ re_train_s = scaler.fit_transform(
458
+ re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
459
+ re_test_s = scaler.transform(
460
+ re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
461
+ for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
462
+ if df is not None:
463
+ df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
464
+
465
+ # Update feature lists.
466
+ if col_name not in self.factor_nmes:
467
+ self.factor_nmes.append(col_name)
468
+ if col_name not in self.num_features:
469
+ self.num_features.append(col_name)
470
+ if self.train_oht_scl_data is not None:
471
+ excluded = {self.weight_nme, self.resp_nme}
472
+ self.var_nmes = [
473
+ col for col in self.train_oht_scl_data.columns if col not in excluded
474
+ ]
475
+
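
The smoothing above is a credibility-style shrinkage: each city's weighted mean is pulled toward its province prior by the pseudo-weight region_effect_alpha. A toy calculation (hypothetical numbers) makes the effect visible:

# City-level sums within one province (toy numbers).
sum_yw = 12.0   # sum of response * weight for the city
sum_w = 40.0    # sum of weights (exposure) for the city
prior = 0.25    # province-level weighted mean used as the prior
alpha = 50.0    # region_effect_alpha: pseudo-exposure granted to the prior

raw_mean = sum_yw / sum_w                            # 0.30 with no pooling
city_effect = (sum_yw + alpha * prior) / max(sum_w + alpha, 1e-12)
print(raw_mean, city_effect)                         # 0.30 vs ~0.272: pulled toward the 0.25 prior
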
476
+ # Single-factor plotting helper.
477
+ def plot_oneway(self, n_bins=10):
478
+ if plt is None and plot_diagnostics is None:
479
+ _plot_skip("oneway plot")
480
+ return
481
+ if plot_diagnostics is None:
482
+ for c in self.factor_nmes:
483
+ fig = plt.figure(figsize=(7, 5))
484
+ if c in self.cate_list:
485
+ group_col = c
486
+ plot_source = self.train_data
487
+ else:
488
+ group_col = f'{c}_bins'
489
+ bins = pd.qcut(
490
+ self.train_data[c],
491
+ n_bins,
492
+ duplicates='drop' # Drop duplicate quantiles to avoid errors.
493
+ )
494
+ plot_source = self.train_data.assign(**{group_col: bins})
495
+ plot_data = plot_source.groupby(
496
+ [group_col], observed=True).sum(numeric_only=True)
497
+ plot_data.reset_index(inplace=True)
498
+ plot_data['act_v'] = plot_data['w_act'] / \
499
+ plot_data[self.weight_nme]
500
+ ax = fig.add_subplot(111)
501
+ ax.plot(plot_data.index, plot_data['act_v'],
502
+ label='Actual', color='red')
503
+ ax.set_title(
504
+ 'Analysis of %s : Train Data' % group_col,
505
+ fontsize=8)
506
+ plt.xticks(plot_data.index,
507
+ list(plot_data[group_col].astype(str)),
508
+ rotation=90)
509
+ if len(list(plot_data[group_col].astype(str))) > 50:
510
+ plt.xticks(fontsize=3)
511
+ else:
512
+ plt.xticks(fontsize=6)
513
+ plt.yticks(fontsize=6)
514
+ ax2 = ax.twinx()
515
+ ax2.bar(plot_data.index,
516
+ plot_data[self.weight_nme],
517
+ alpha=0.5, color='seagreen')
518
+ plt.yticks(fontsize=6)
519
+ plt.margins(0.05)
520
+ plt.subplots_adjust(wspace=0.3)
521
+ save_path = self.output_manager.plot_path(
522
+ f'00_{self.model_nme}_{group_col}_oneway.png')
523
+ plt.savefig(save_path, dpi=300)
524
+ plt.close(fig)
525
+ return
526
+
527
+ if "w_act" not in self.train_data.columns:
528
+ print("[Oneway] Missing w_act column; skip plotting.", flush=True)
529
+ return
530
+
531
+ for c in self.factor_nmes:
532
+ is_cat = c in (self.cate_list or [])
533
+ group_col = c if is_cat else f"{c}_bins"
534
+ title = f"Analysis of {group_col} : Train Data"
535
+ save_path = self.output_manager.plot_path(
536
+ f"00_{self.model_nme}_{group_col}_oneway.png"
537
+ )
538
+ plot_diagnostics.plot_oneway(
539
+ self.train_data,
540
+ feature=c,
541
+ weight_col=self.weight_nme,
542
+ target_col="w_act",
543
+ n_bins=n_bins,
544
+ is_categorical=is_cat,
545
+ title=title,
546
+ save_path=save_path,
547
+ show=False,
548
+ )
549
+
550
+ def _require_trainer(self, model_key: str) -> "TrainerBase":
551
+ trainer = self.trainers.get(model_key)
552
+ if trainer is None:
553
+ raise KeyError(f"Unknown model key: {model_key}")
554
+ return trainer
555
+
556
+ def _pred_vector_columns(self, pred_prefix: str) -> List[str]:
557
+ """Return vector feature columns like pred_<prefix>_0.. sorted by suffix."""
558
+ col_prefix = f"pred_{pred_prefix}_"
559
+ cols = [c for c in self.train_data.columns if c.startswith(col_prefix)]
560
+
561
+ def sort_key(name: str):
562
+ tail = name.rsplit("_", 1)[-1]
563
+ try:
564
+ return (0, int(tail))
565
+ except Exception:
566
+ return (1, tail)
567
+
568
+ cols.sort(key=sort_key)
569
+ return cols
570
+
571
+ def _inject_pred_features(self, pred_prefix: str) -> List[str]:
572
+ """Inject pred_<prefix> or pred_<prefix>_i columns into features and return names."""
573
+ cols = self._pred_vector_columns(pred_prefix)
574
+ if cols:
575
+ self.add_numeric_features_from_columns(cols)
576
+ return cols
577
+ scalar_col = f"pred_{pred_prefix}"
578
+ if scalar_col in self.train_data.columns:
579
+ self.add_numeric_feature_from_column(scalar_col)
580
+ return [scalar_col]
581
+ return []
582
+
583
+ def _maybe_load_best_params(self, model_key: str, trainer: "TrainerBase") -> None:
584
+ # 1) If best_params_files is specified, load and skip tuning.
585
+ best_params_files = getattr(self.config, "best_params_files", None) or {}
586
+ best_params_file = best_params_files.get(model_key)
587
+ if best_params_file and not trainer.best_params:
588
+ trainer.best_params = IOUtils.load_params_file(best_params_file)
589
+ trainer.best_trial = None
590
+ print(
591
+ f"[Optuna][{trainer.label}] Loaded best_params from {best_params_file}; skip tuning."
592
+ )
593
+
594
+ # 2) If reuse_best_params is enabled, prefer version snapshots; else load legacy CSV.
595
+ reuse_params = bool(getattr(self.config, "reuse_best_params", False))
596
+ if reuse_params and not trainer.best_params:
597
+ payload = self.version_manager.load_latest(f"{model_key}_best")
598
+ best_params = None if payload is None else payload.get("best_params")
599
+ if best_params:
600
+ trainer.best_params = best_params
601
+ trainer.best_trial = None
602
+ trainer.study_name = payload.get(
603
+ "study_name") if isinstance(payload, dict) else None
604
+ print(
605
+ f"[Optuna][{trainer.label}] Reusing best_params from versions snapshot.")
606
+ return
607
+
608
+ params_path = self.output_manager.result_path(
609
+ f'{self.model_nme}_bestparams_{trainer.label.lower()}.csv'
610
+ )
611
+ if os.path.exists(params_path):
612
+ try:
613
+ trainer.best_params = IOUtils.load_params_file(params_path)
614
+ trainer.best_trial = None
615
+ print(
616
+ f"[Optuna][{trainer.label}] Reusing best_params from {params_path}.")
617
+ except ValueError:
618
+ # Legacy compatibility: ignore empty files and continue tuning.
619
+ pass
620
+
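
The lookup order above is: explicit best_params_files per model key, then the latest VersionManager snapshot when reuse_best_params is set, then a legacy params file next to the results. An illustrative sketch (train_df/test_df as in the constructor sketch above; the path is hypothetical and its format must match whatever IOUtils.load_params_file expects):

# Reuse previously tuned hyperparameters instead of re-running Optuna.
bo = BayesOptModel(
    train_df, test_df,
    model_nme="motor_f",
    resp_nme="claim_count",
    weight_nme="exposure",
    output_dir="./bayesopt_runs",
    reuse_best_params=True,                    # prefer the latest versions snapshot per model
    best_params_files={                        # or pin an explicit file per model key
        "xgb": "path/to/xgb_best_params.csv",  # hypothetical path
    },
)
bo.optimize_model("xgb", max_evals=50)         # tuning is skipped when params are found
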
621
+ # Generic optimization entry point.
622
+ def optimize_model(self, model_key: str, max_evals: int = 100):
623
+ if model_key not in self.trainers:
624
+ print(f"Warning: Unknown model key: {model_key}")
625
+ return
626
+
627
+ trainer = self._require_trainer(model_key)
628
+ self._maybe_load_best_params(model_key, trainer)
629
+
630
+ should_tune = not trainer.best_params
631
+ if should_tune:
632
+ if model_key == "ft" and str(self.config.ft_role) == "unsupervised_embedding":
633
+ if hasattr(trainer, "cross_val_unsupervised"):
634
+ trainer.tune(
635
+ max_evals,
636
+ objective_fn=getattr(trainer, "cross_val_unsupervised")
637
+ )
638
+ else:
639
+ raise RuntimeError(
640
+ "FT trainer does not support unsupervised Optuna objective.")
641
+ else:
642
+ trainer.tune(max_evals)
643
+
644
+ if model_key == "ft" and str(self.config.ft_role) != "model":
645
+ prefix = str(self.config.ft_feature_prefix or "ft_emb")
646
+ role = str(self.config.ft_role)
647
+ if role == "embedding":
648
+ trainer.train_as_feature(
649
+ pred_prefix=prefix, feature_mode="embedding")
650
+ elif role == "unsupervised_embedding":
651
+ trainer.pretrain_unsupervised_as_feature(
652
+ pred_prefix=prefix,
653
+ params=trainer.best_params
654
+ )
655
+ else:
656
+ raise ValueError(
657
+ f"Unsupported ft_role='{role}', expected 'model'/'embedding'/'unsupervised_embedding'.")
658
+
659
+ # Inject generated prediction/embedding columns as features (scalar or vector).
660
+ self._inject_pred_features(prefix)
661
+ # Do not add FT as a standalone model label; downstream models handle evaluation.
662
+ else:
663
+ trainer.train()
664
+
665
+ if bool(getattr(self.config, "final_ensemble", False)):
666
+ k = int(getattr(self.config, "final_ensemble_k", 3) or 3)
667
+ if k > 1:
668
+ if model_key == "ft" and str(self.config.ft_role) != "model":
669
+ pass
670
+ elif hasattr(trainer, "ensemble_predict"):
671
+ trainer.ensemble_predict(k)
672
+ else:
673
+ print(
674
+ f"[Ensemble] Trainer '{model_key}' does not support ensemble prediction.",
675
+ flush=True,
676
+ )
677
+
678
+ # Update context fields for backward compatibility.
679
+ setattr(self, f"{model_key}_best", trainer.model)
680
+ setattr(self, f"best_{model_key}_params", trainer.best_params)
681
+ setattr(self, f"best_{model_key}_trial", trainer.best_trial)
682
+ # Save a snapshot for traceability.
683
+ study_name = getattr(trainer, "study_name", None)
684
+ if study_name is None and trainer.best_trial is not None:
685
+ study_obj = getattr(trainer.best_trial, "study", None)
686
+ study_name = getattr(study_obj, "study_name", None)
687
+ snapshot = {
688
+ "model_key": model_key,
689
+ "timestamp": datetime.now().isoformat(),
690
+ "best_params": trainer.best_params,
691
+ "study_name": study_name,
692
+ "config": asdict(self.config),
693
+ }
694
+ self.version_manager.save(f"{model_key}_best", snapshot)
695
+
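
Taken together with the wrapper methods defined further down (bayesopt_xgb and friends) and the plotting helpers, a typical tuning pass on the bo instance from the constructor sketch above might look like this (evaluation budgets are illustrative):

bo.bayesopt_xgb(max_evals=30)      # tune + train XGBoost; equivalent to bo.optimize_model("xgb", 30)
bo.bayesopt_resnet(max_evals=20)   # same flow for the ResNet trainer

bo.plot_oneway(n_bins=10)                       # factor-level sanity checks
bo.plot_lift("Xgboost", "pred_xgb", n_bins=10)  # lift on train/test predictions
bo.plot_dlift(["xgb", "resn"], n_bins=10)       # head-to-head double lift
bo.save_model("xgb")                            # persist via the trainer's save()
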
696
+ def add_numeric_feature_from_column(self, col_name: str) -> None:
697
+ """Add an existing column as a feature and sync one-hot/scaled tables."""
698
+ if col_name not in self.train_data.columns or col_name not in self.test_data.columns:
699
+ raise KeyError(
700
+ f"Column '{col_name}' must exist in both train_data and test_data.")
701
+
702
+ if col_name not in self.factor_nmes:
703
+ self.factor_nmes.append(col_name)
704
+ if col_name not in self.config.factor_nmes:
705
+ self.config.factor_nmes.append(col_name)
706
+
707
+ if col_name not in self.cate_list and col_name not in self.num_features:
708
+ self.num_features.append(col_name)
709
+
710
+ if self.train_oht_data is not None and self.test_oht_data is not None:
711
+ self.train_oht_data[col_name] = self.train_data[col_name].values
712
+ self.test_oht_data[col_name] = self.test_data[col_name].values
713
+ if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
714
+ scaler = StandardScaler()
715
+ tr = self.train_data[col_name].to_numpy(
716
+ dtype=np.float32, copy=False).reshape(-1, 1)
717
+ te = self.test_data[col_name].to_numpy(
718
+ dtype=np.float32, copy=False).reshape(-1, 1)
719
+ self.train_oht_scl_data[col_name] = scaler.fit_transform(
720
+ tr).reshape(-1)
721
+ self.test_oht_scl_data[col_name] = scaler.transform(te).reshape(-1)
722
+
723
+ if col_name not in self.var_nmes:
724
+ self.var_nmes.append(col_name)
725
+
726
+ def add_numeric_features_from_columns(self, col_names: List[str]) -> None:
727
+ if not col_names:
728
+ return
729
+
730
+ missing = [
731
+ col for col in col_names
732
+ if col not in self.train_data.columns or col not in self.test_data.columns
733
+ ]
734
+ if missing:
735
+ raise KeyError(
736
+ f"Column(s) {missing} must exist in both train_data and test_data."
737
+ )
738
+
739
+ for col_name in col_names:
740
+ if col_name not in self.factor_nmes:
741
+ self.factor_nmes.append(col_name)
742
+ if col_name not in self.config.factor_nmes:
743
+ self.config.factor_nmes.append(col_name)
744
+ if col_name not in self.cate_list and col_name not in self.num_features:
745
+ self.num_features.append(col_name)
746
+ if col_name not in self.var_nmes:
747
+ self.var_nmes.append(col_name)
748
+
749
+ if self.train_oht_data is not None and self.test_oht_data is not None:
750
+ self.train_oht_data.loc[:, col_names] = self.train_data[col_names].to_numpy(copy=False)
751
+ self.test_oht_data.loc[:, col_names] = self.test_data[col_names].to_numpy(copy=False)
752
+
753
+ if self.train_oht_scl_data is not None and self.test_oht_scl_data is not None:
754
+ scaler = StandardScaler()
755
+ tr = self.train_data[col_names].to_numpy(dtype=np.float32, copy=False)
756
+ te = self.test_data[col_names].to_numpy(dtype=np.float32, copy=False)
757
+ self.train_oht_scl_data.loc[:, col_names] = scaler.fit_transform(tr)
758
+ self.test_oht_scl_data.loc[:, col_names] = scaler.transform(te)
759
+
760
+ def prepare_ft_as_feature(self, max_evals: int = 50, pred_prefix: str = "ft_feat") -> str:
761
+ """Train FT as a feature generator and return the downstream column name."""
762
+ ft_trainer = self._require_trainer("ft")
763
+ ft_trainer.tune(max_evals=max_evals)
764
+ if hasattr(ft_trainer, "train_as_feature"):
765
+ ft_trainer.train_as_feature(pred_prefix=pred_prefix)
766
+ else:
767
+ ft_trainer.train()
768
+ feature_col = f"pred_{pred_prefix}"
769
+ self.add_numeric_feature_from_column(feature_col)
770
+ return feature_col
771
+
772
+ def prepare_ft_embedding_as_features(self, max_evals: int = 50, pred_prefix: str = "ft_emb") -> List[str]:
773
+ """Train FT and inject pooled embeddings as vector features pred_<prefix>_0.. ."""
774
+ ft_trainer = self._require_trainer("ft")
775
+ ft_trainer.tune(max_evals=max_evals)
776
+ if hasattr(ft_trainer, "train_as_feature"):
777
+ ft_trainer.train_as_feature(
778
+ pred_prefix=pred_prefix, feature_mode="embedding")
779
+ else:
780
+ raise RuntimeError(
781
+ "FT trainer does not support embedding feature mode.")
782
+ cols = self._pred_vector_columns(pred_prefix)
783
+ if not cols:
784
+ raise RuntimeError(
785
+ f"No embedding columns were generated for prefix '{pred_prefix}'.")
786
+ self.add_numeric_features_from_columns(cols)
787
+ return cols
788
+
789
+ def prepare_ft_unsupervised_embedding_as_features(self,
790
+ pred_prefix: str = "ft_uemb",
791
+ params: Optional[Dict[str,
792
+ Any]] = None,
793
+ mask_prob_num: float = 0.15,
794
+ mask_prob_cat: float = 0.15,
795
+ num_loss_weight: float = 1.0,
796
+ cat_loss_weight: float = 1.0) -> List[str]:
797
+ """Export embeddings after FT self-supervised masked reconstruction pretraining."""
798
+ ft_trainer = self._require_trainer("ft")
799
+ if not hasattr(ft_trainer, "pretrain_unsupervised_as_feature"):
800
+ raise RuntimeError(
801
+ "FT trainer does not support unsupervised pretraining.")
802
+ ft_trainer.pretrain_unsupervised_as_feature(
803
+ pred_prefix=pred_prefix,
804
+ params=params,
805
+ mask_prob_num=mask_prob_num,
806
+ mask_prob_cat=mask_prob_cat,
807
+ num_loss_weight=num_loss_weight,
808
+ cat_loss_weight=cat_loss_weight
809
+ )
810
+ cols = self._pred_vector_columns(pred_prefix)
811
+ if not cols:
812
+ raise RuntimeError(
813
+ f"No embedding columns were generated for prefix '{pred_prefix}'.")
814
+ self.add_numeric_features_from_columns(cols)
815
+ return cols
816
+
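
The three FT helpers above use the FT-Transformer as a feature generator rather than a standalone model. A sketch of the embedding route feeding the downstream XGBoost trainer, continuing the bo instance from the earlier sketch (budgets illustrative; the embedding width comes from the tuned FT model):

# FT produces pooled embeddings injected as pred_ft_emb_0 .. pred_ft_emb_<d-1> features.
emb_cols = bo.prepare_ft_embedding_as_features(max_evals=15, pred_prefix="ft_emb")
print(len(emb_cols), emb_cols[:3])

# Downstream models now see the embedding columns through factor_nmes/var_nmes.
bo.optimize_model("xgb", max_evals=30)

# Self-supervised variant: masked-reconstruction pretraining, no labels used for FT.
uemb_cols = bo.prepare_ft_unsupervised_embedding_as_features(pred_prefix="ft_uemb",
                                                             mask_prob_num=0.15,
                                                             mask_prob_cat=0.15)
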
817
+ # GLM Bayesian optimization wrapper.
818
+ def bayesopt_glm(self, max_evals=50):
819
+ self.optimize_model('glm', max_evals)
820
+
821
+ # XGBoost Bayesian optimization wrapper.
822
+ def bayesopt_xgb(self, max_evals=100):
823
+ self.optimize_model('xgb', max_evals)
824
+
825
+ # ResNet Bayesian optimization wrapper.
826
+ def bayesopt_resnet(self, max_evals=100):
827
+ self.optimize_model('resn', max_evals)
828
+
829
+ # GNN Bayesian optimization wrapper.
830
+ def bayesopt_gnn(self, max_evals=50):
831
+ self.optimize_model('gnn', max_evals)
832
+
833
+ # FT-Transformer Bayesian optimization wrapper.
834
+ def bayesopt_ft(self, max_evals=50):
835
+ self.optimize_model('ft', max_evals)
836
+
837
+ # Lift curve plotting.
838
+ def plot_lift(self, model_label, pred_nme, n_bins=10):
839
+ if plt is None:
840
+ _plot_skip("lift plot")
841
+ return
842
+ model_map = {
843
+ 'Xgboost': 'pred_xgb',
844
+ 'ResNet': 'pred_resn',
845
+ 'ResNetClassifier': 'pred_resn',
846
+ 'GLM': 'pred_glm',
847
+ 'GNN': 'pred_gnn',
848
+ }
849
+ if str(self.config.ft_role) == "model":
850
+ model_map.update({
851
+ 'FTTransformer': 'pred_ft',
852
+ 'FTTransformerClassifier': 'pred_ft',
853
+ })
854
+ for k, v in model_map.items():
855
+ if model_label.startswith(k):
856
+ pred_nme = v
857
+ break
858
+
859
+ datasets = []
860
+ for title, data in [
861
+ ('Lift Chart on Train Data', self.train_data),
862
+ ('Lift Chart on Test Data', self.test_data),
863
+ ]:
864
+ if 'w_act' not in data.columns or data['w_act'].isna().all():
865
+ print(
866
+ f"[Lift] Missing labels for {title}; skip.",
867
+ flush=True,
868
+ )
869
+ continue
870
+ datasets.append((title, data))
871
+
872
+ if not datasets:
873
+ print("[Lift] No labeled data available; skip plotting.", flush=True)
874
+ return
875
+
876
+ if plot_curves is None:
877
+ fig = plt.figure(figsize=(11, 5))
878
+ positions = [111] if len(datasets) == 1 else [121, 122]
879
+ for pos, (title, data) in zip(positions, datasets):
880
+ if pred_nme not in data.columns or f'w_{pred_nme}' not in data.columns:
881
+ print(
882
+ f"[Lift] Missing prediction columns in {title}; skip.",
883
+ flush=True,
884
+ )
885
+ continue
886
+ lift_df = pd.DataFrame({
887
+ 'pred': data[pred_nme].values,
888
+ 'w_pred': data[f'w_{pred_nme}'].values,
889
+ 'act': data['w_act'].values,
890
+ 'weight': data[self.weight_nme].values
891
+ })
892
+ plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
893
+ denom = np.maximum(plot_data['weight'], EPS)
894
+ plot_data['exp_v'] = plot_data['w_pred'] / denom
895
+ plot_data['act_v'] = plot_data['act'] / denom
896
+ plot_data = plot_data.reset_index()
897
+
898
+ ax = fig.add_subplot(pos)
899
+ PlotUtils.plot_lift_ax(ax, plot_data, title)
900
+
901
+ plt.subplots_adjust(wspace=0.3)
902
+ save_path = self.output_manager.plot_path(
903
+ f'01_{self.model_nme}_{model_label}_lift.png')
904
+ plt.savefig(save_path, dpi=300)
905
+ plt.show()
906
+ plt.close(fig)
907
+ return
908
+
909
+ style = PlotStyle() if PlotStyle else None
910
+ fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
911
+ if len(datasets) == 1:
912
+ axes = [axes]
913
+
914
+ for ax, (title, data) in zip(axes, datasets):
915
+ pred_vals = None
916
+ if pred_nme in data.columns:
917
+ pred_vals = data[pred_nme].values
918
+ else:
919
+ w_pred_col = f"w_{pred_nme}"
920
+ if w_pred_col in data.columns:
921
+ denom = np.maximum(data[self.weight_nme].values, EPS)
922
+ pred_vals = data[w_pred_col].values / denom
923
+ if pred_vals is None:
924
+ print(
925
+ f"[Lift] Missing prediction columns in {title}; skip.",
926
+ flush=True,
927
+ )
928
+ continue
929
+
930
+ plot_curves.plot_lift_curve(
931
+ pred_vals,
932
+ data['w_act'].values,
933
+ data[self.weight_nme].values,
934
+ n_bins=n_bins,
935
+ title=title,
936
+ pred_label="Predicted",
937
+ act_label="Actual",
938
+ weight_label="Earned Exposure",
939
+ pred_weighted=False,
940
+ actual_weighted=True,
941
+ ax=ax,
942
+ show=False,
943
+ style=style,
944
+ )
945
+
946
+ plt.subplots_adjust(wspace=0.3)
947
+ save_path = self.output_manager.plot_path(
948
+ f'01_{self.model_nme}_{model_label}_lift.png')
949
+ if finalize_figure:
950
+ finalize_figure(fig, save_path=save_path, show=True, style=style)
951
+ else:
952
+ plt.savefig(save_path, dpi=300)
953
+ plt.show()
954
+ plt.close(fig)
955
+
956
+ # Double lift curve plot.
957
+ def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
958
+ # Compare two models across bins.
959
+ # Args:
960
+ # model_comp: model keys to compare (e.g., ['xgb', 'resn']).
961
+ # n_bins: number of bins for lift curves.
962
+ if plt is None:
963
+ _plot_skip("double lift plot")
964
+ return
965
+ if len(model_comp) != 2:
966
+ raise ValueError("`model_comp` must contain two models to compare.")
967
+
968
+ model_name_map = {
969
+ 'xgb': 'Xgboost',
970
+ 'resn': 'ResNet',
971
+ 'glm': 'GLM',
972
+ 'gnn': 'GNN',
973
+ }
974
+ if str(self.config.ft_role) == "model":
975
+ model_name_map['ft'] = 'FTTransformer'
976
+
977
+ name1, name2 = model_comp
978
+ if name1 not in model_name_map or name2 not in model_name_map:
979
+ raise ValueError(f"Unsupported model key. Choose from {list(model_name_map.keys())}.")
980
+
981
+ datasets = []
982
+ for data_name, data in [('Train Data', self.train_data),
983
+ ('Test Data', self.test_data)]:
984
+ if 'w_act' not in data.columns or data['w_act'].isna().all():
985
+ print(
986
+ f"[Double Lift] Missing labels for {data_name}; skip.",
987
+ flush=True,
988
+ )
989
+ continue
990
+ datasets.append((data_name, data))
991
+
992
+ if not datasets:
993
+ print("[Double Lift] No labeled data available; skip plotting.", flush=True)
994
+ return
995
+
996
+ if plot_curves is None:
997
+ fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
998
+ if len(datasets) == 1:
999
+ axes = [axes]
1000
+
1001
+ for ax, (data_name, data) in zip(axes, datasets):
1002
+ pred1_col = f'w_pred_{name1}'
1003
+ pred2_col = f'w_pred_{name2}'
1004
+
1005
+ if pred1_col not in data.columns or pred2_col not in data.columns:
1006
+ print(
1007
+ f"Warning: missing prediction columns {pred1_col} or {pred2_col} in {data_name}. Skip plot.")
1008
+ continue
1009
+
1010
+ lift_data = pd.DataFrame({
1011
+ 'pred1': data[pred1_col].values,
1012
+ 'pred2': data[pred2_col].values,
1013
+ 'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
1014
+ 'act': data['w_act'].values,
1015
+ 'weight': data[self.weight_nme].values
1016
+ })
1017
+ plot_data = PlotUtils.split_data(
1018
+ lift_data, 'diff_ly', 'weight', n_bins)
1019
+ denom = np.maximum(plot_data['act'], EPS)
1020
+ plot_data['exp_v1'] = plot_data['pred1'] / denom
1021
+ plot_data['exp_v2'] = plot_data['pred2'] / denom
1022
+ plot_data['act_v'] = plot_data['act'] / denom
1023
+ plot_data.reset_index(inplace=True)
1024
+
1025
+ label1 = model_name_map[name1]
1026
+ label2 = model_name_map[name2]
1027
+
1028
+ PlotUtils.plot_dlift_ax(
1029
+ ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)
1030
+
1031
+ plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
1032
+ save_path = self.output_manager.plot_path(
1033
+ f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
1034
+ plt.savefig(save_path, dpi=300)
1035
+ plt.show()
1036
+ plt.close(fig)
1037
+ return
1038
+
1039
+ style = PlotStyle() if PlotStyle else None
1040
+ fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
1041
+ if len(datasets) == 1:
1042
+ axes = [axes]
1043
+
1044
+ label1 = model_name_map[name1]
1045
+ label2 = model_name_map[name2]
1046
+
1047
+ for ax, (data_name, data) in zip(axes, datasets):
1048
+ weight_vals = data[self.weight_nme].values
1049
+ pred1 = None
1050
+ pred2 = None
1051
+
1052
+ pred1_col = f"pred_{name1}"
1053
+ pred2_col = f"pred_{name2}"
1054
+ if pred1_col in data.columns:
1055
+ pred1 = data[pred1_col].values
1056
+ else:
1057
+ w_pred1_col = f"w_pred_{name1}"
1058
+ if w_pred1_col in data.columns:
1059
+ pred1 = data[w_pred1_col].values / np.maximum(weight_vals, EPS)
1060
+
1061
+ if pred2_col in data.columns:
1062
+ pred2 = data[pred2_col].values
1063
+ else:
1064
+ w_pred2_col = f"w_pred_{name2}"
1065
+ if w_pred2_col in data.columns:
1066
+ pred2 = data[w_pred2_col].values / np.maximum(weight_vals, EPS)
1067
+
1068
+ if pred1 is None or pred2 is None:
1069
+ print(
1070
+ f"Warning: missing pred_{name1}/pred_{name2} or w_pred columns in {data_name}. Skip plot.")
1071
+ continue
1072
+
1073
+ plot_curves.plot_double_lift_curve(
1074
+ pred1,
1075
+ pred2,
1076
+ data['w_act'].values,
1077
+ weight_vals,
1078
+ n_bins=n_bins,
1079
+ title=f"Double Lift Chart on {data_name}",
1080
+ label1=label1,
1081
+ label2=label2,
1082
+ pred1_weighted=False,
1083
+ pred2_weighted=False,
1084
+ actual_weighted=True,
1085
+ ax=ax,
1086
+ show=False,
1087
+ style=style,
1088
+ )
1089
+
1090
+ plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
1091
+ save_path = self.output_manager.plot_path(
1092
+ f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
1093
+ if finalize_figure:
1094
+ finalize_figure(fig, save_path=save_path, show=True, style=style)
1095
+ else:
1096
+ plt.savefig(save_path, dpi=300)
1097
+ plt.show()
1098
+ plt.close(fig)
1099
+
1100
+ # Conversion lift curve plot.
1101
+ def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
1102
+ if plt is None:
1103
+ _plot_skip("conversion lift plot")
1104
+ return
1105
+ if not self.binary_resp_nme:
1106
+ print("Error: `binary_resp_nme` not provided at BayesOptModel init; cannot plot conversion lift.")
1107
+ return
1108
+
1109
+ if plot_curves is None:
1110
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
1111
+ datasets = {
1112
+ 'Train Data': self.train_data,
1113
+ 'Test Data': self.test_data
1114
+ }
1115
+
1116
+ for ax, (data_name, data) in zip(axes, datasets.items()):
1117
+ if model_pred_col not in data.columns:
1118
+ print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
1119
+ continue
1120
+
1121
+ # Sort by model prediction and compute bins.
1122
+ plot_data = data.sort_values(by=model_pred_col).copy()
1123
+ plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
1124
+ total_weight = plot_data[self.weight_nme].sum()
1125
+
1126
+ if total_weight > EPS:
1127
+ plot_data['bin'] = pd.cut(
1128
+ plot_data['cum_weight'],
1129
+ bins=n_bins,
1130
+ labels=False,
1131
+ right=False
1132
+ )
1133
+ else:
1134
+ plot_data['bin'] = 0
1135
+
1136
+ # Aggregate by bins.
1137
+ lift_agg = plot_data.groupby('bin').agg(
1138
+ total_weight=(self.weight_nme, 'sum'),
1139
+ actual_conversions=(self.binary_resp_nme, 'sum'),
1140
+ weighted_conversions=('w_binary_act', 'sum'),
1141
+ avg_pred=(model_pred_col, 'mean')
1142
+ ).reset_index()
1143
+
1144
+ # Compute conversion rate.
1145
+ lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
1146
+ lift_agg['total_weight']
1147
+
1148
+ # Compute overall average conversion rate.
1149
+ overall_conversion_rate = data['w_binary_act'].sum(
1150
+ ) / data[self.weight_nme].sum()
1151
+ ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
1152
+ label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')
1153
+
1154
+ ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
1155
+ marker='o', linestyle='-', label='Actual Conversion Rate')
1156
+ ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
1157
+ ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
1158
+ ax.set_ylabel('Conversion Rate')
1159
+ ax.grid(True, linestyle='--', alpha=0.6)
1160
+ ax.legend()
1161
+
1162
+ plt.tight_layout()
1163
+ plt.show()
1164
+ return
1165
+
1166
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
1167
+ datasets = {
1168
+ 'Train Data': self.train_data,
1169
+ 'Test Data': self.test_data
1170
+ }
1171
+
1172
+ for ax, (data_name, data) in zip(axes, datasets.items()):
1173
+ if model_pred_col not in data.columns:
1174
+ print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
1175
+ continue
1176
+
1177
+ plot_curves.plot_conversion_lift(
1178
+ data[model_pred_col].values,
1179
+ data[self.binary_resp_nme].values,
1180
+ data[self.weight_nme].values,
1181
+ n_bins=n_bins,
1182
+ title=f'Conversion Rate Lift Chart on {data_name}',
1183
+ ax=ax,
1184
+ show=False,
1185
+ )
1186
+
1187
+ plt.tight_layout()
1188
+ plt.show()
1189
+
1190
+ # ========= Lightweight explainability: Permutation Importance =========
1191
+ def compute_permutation_importance(self,
1192
+ model_key: str,
1193
+ on_train: bool = True,
1194
+ metric: Any = "auto",
1195
+ n_repeats: int = 5,
1196
+ max_rows: int = 5000,
1197
+ random_state: Optional[int] = None):
1198
+ if explain_permutation is None:
1199
+ raise RuntimeError("explain.permutation is not available.")
1200
+
1201
+ model_key = str(model_key)
1202
+ data = self.train_data if on_train else self.test_data
1203
+ if self.resp_nme not in data.columns:
1204
+ raise RuntimeError("Missing response column for permutation importance.")
1205
+ y = data[self.resp_nme]
1206
+ w = data[self.weight_nme] if self.weight_nme in data.columns else None
1207
+
1208
+ if model_key == "resn":
1209
+ if self.resn_best is None:
1210
+ raise RuntimeError("ResNet model not trained.")
1211
+ X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
1212
+ if X is None:
1213
+ raise RuntimeError("Missing standardized features for ResNet.")
1214
+ X = X[self.var_nmes]
1215
+ predict_fn = lambda df: self.resn_best.predict(df)
1216
+ elif model_key == "ft":
1217
+ if self.ft_best is None:
1218
+ raise RuntimeError("FT model not trained.")
1219
+ if str(self.config.ft_role) != "model":
1220
+ raise RuntimeError("FT role is not 'model'; FT predictions unavailable.")
1221
+ X = data[self.factor_nmes]
1222
+ geo_tokens = self.train_geo_tokens if on_train else self.test_geo_tokens
1223
+ geo_np = None
1224
+ if geo_tokens is not None:
1225
+ geo_np = geo_tokens.to_numpy(dtype=np.float32, copy=False)
1226
+ predict_fn = lambda df, geo=geo_np: self.ft_best.predict(df, geo_tokens=geo)
1227
+ elif model_key == "xgb":
1228
+ if self.xgb_best is None:
1229
+ raise RuntimeError("XGB model not trained.")
1230
+ X = data[self.factor_nmes]
1231
+ predict_fn = lambda df: self.xgb_best.predict(df)
1232
+ else:
1233
+ raise ValueError("Unsupported model_key for permutation importance.")
1234
+
1235
+ return explain_permutation.permutation_importance(
1236
+ predict_fn,
1237
+ X,
1238
+ y,
1239
+ sample_weight=w,
1240
+ metric=metric,
1241
+ task_type=self.task_type,
1242
+ n_repeats=n_repeats,
1243
+ random_state=random_state,
1244
+ max_rows=max_rows,
1245
+ )
1246
+
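The permutation-importance wrapper only needs a model key and a split; everything else has defaults. A usage sketch, assuming a fitted pipeline instance named `ctx` (hypothetical) with an already-trained XGBoost model; the exact return structure comes from `explain_permutation.permutation_importance`, typically one row per feature with its mean importance across repeats:

    # Permutation importance on the test split for the trained XGBoost model.
    imp = ctx.compute_permutation_importance(
        model_key="xgb",     # "xgb", "resn", or "ft"
        on_train=False,      # score on test_data instead of train_data
        metric="auto",       # let the helper infer a metric from task_type
        n_repeats=5,
        max_rows=5000,
        random_state=42,
    )
    print(imp)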
1247
+ # ========= Deep explainability: Integrated Gradients =========
1248
+ def compute_integrated_gradients_resn(self,
1249
+ on_train: bool = True,
1250
+ baseline: Any = None,
1251
+ steps: int = 50,
1252
+ batch_size: int = 256,
1253
+ target: Optional[int] = None):
1254
+ if explain_gradients is None:
1255
+ raise RuntimeError("explain.gradients is not available.")
1256
+ if self.resn_best is None:
1257
+ raise RuntimeError("ResNet model not trained.")
1258
+ X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
1259
+ if X is None:
1260
+ raise RuntimeError("Missing standardized features for ResNet.")
1261
+ X = X[self.var_nmes]
1262
+ return explain_gradients.resnet_integrated_gradients(
1263
+ self.resn_best,
1264
+ X,
1265
+ baseline=baseline,
1266
+ steps=steps,
1267
+ batch_size=batch_size,
1268
+ target=target,
1269
+ )
1270
+
1271
+ def compute_integrated_gradients_ft(self,
1272
+ on_train: bool = True,
1273
+ geo_tokens: Optional[np.ndarray] = None,
1274
+ baseline_num: Any = None,
1275
+ baseline_geo: Any = None,
1276
+ steps: int = 50,
1277
+ batch_size: int = 256,
1278
+ target: Optional[int] = None):
1279
+ if explain_gradients is None:
1280
+ raise RuntimeError("explain.gradients is not available.")
1281
+ if self.ft_best is None:
1282
+ raise RuntimeError("FT model not trained.")
1283
+ if str(self.config.ft_role) != "model":
1284
+ raise RuntimeError("FT role is not 'model'; FT explanations unavailable.")
1285
+
1286
+ data = self.train_data if on_train else self.test_data
1287
+ X = data[self.factor_nmes]
1288
+
1289
+ if geo_tokens is None and getattr(self.ft_best, "num_geo", 0) > 0:
1290
+ tokens_df = self.train_geo_tokens if on_train else self.test_geo_tokens
1291
+ if tokens_df is not None:
1292
+ geo_tokens = tokens_df.to_numpy(dtype=np.float32, copy=False)
1293
+
1294
+ return explain_gradients.ft_integrated_gradients(
1295
+ self.ft_best,
1296
+ X,
1297
+ geo_tokens=geo_tokens,
1298
+ baseline_num=baseline_num,
1299
+ baseline_geo=baseline_geo,
1300
+ steps=steps,
1301
+ batch_size=batch_size,
1302
+ target=target,
1303
+ )
1304
+
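Both integrated-gradients helpers delegate to `explain.gradients` and pull features (and, for FT, geo tokens) from the stored splits. A usage sketch with the same hypothetical `ctx` instance, assuming ResNet and FT models have been trained and `ft_role` is "model":

    # ResNet attributions on the standardized training features.
    ig_resn = ctx.compute_integrated_gradients_resn(on_train=True, steps=50, batch_size=256)

    # FT-Transformer attributions on the test split; geo tokens are attached automatically
    # when the model consumes them and the matching token frame is available.
    ig_ft = ctx.compute_integrated_gradients_ft(on_train=False, steps=50, batch_size=256)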
1305
+ # Save model
1306
+ def save_model(self, model_name=None):
1307
+ keys = [model_name] if model_name else self.trainers.keys()
1308
+ for key in keys:
1309
+ if key in self.trainers:
1310
+ self.trainers[key].save()
1311
+ else:
1312
+ if model_name: # Only warn when the user specifies a model name.
1313
+ print(f"[save_model] Warning: Unknown model key {key}")
1314
+
1315
+ def load_model(self, model_name=None):
1316
+ keys = [model_name] if model_name else self.trainers.keys()
1317
+ for key in keys:
1318
+ if key in self.trainers:
1319
+ self.trainers[key].load()
1320
+ # Sync context fields.
1321
+ trainer = self.trainers[key]
1322
+ if trainer.model is not None:
1323
+ setattr(self, f"{key}_best", trainer.model)
1324
+ # For legacy compatibility, also update xxx_load.
1325
+ # Old versions only tracked xgb_load/resn_load/ft_load (not glm_load/gnn_load).
1326
+ if key in ['xgb', 'resn', 'ft', 'gnn']:
1327
+ setattr(self, f"{key}_load", trainer.model)
1328
+ else:
1329
+ if model_name:
1330
+ print(f"[load_model] Warning: Unknown model key {key}")
1331
+
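`save_model` and `load_model` fan out over the registered trainers: with no argument they iterate over every registered trainer, and with an explicit key they warn if that key is unknown; after a load, `{key}_best` and the legacy `{key}_load` attributes are refreshed. For example, with the hypothetical `ctx` instance:

    ctx.save_model()          # persist every registered trainer
    ctx.save_model("xgb")     # persist only the XGBoost trainer
    ctx.load_model("resn")    # reload ResNet and refresh ctx.resn_best / ctx.resn_load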
1332
+ def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
1333
+ if len(data) == 0:
1334
+ return data
1335
+ return data.sample(min(len(data), n), random_state=self.rand_seed)
1336
+
1337
+ @staticmethod
1338
+ def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
1339
+ min_needed = arr.shape[1] + 2
1340
+ return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
1341
+
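`_shap_nsamples` keeps the KernelSHAP sample budget between a feature-dependent floor and a hard cap. A worked example of the arithmetic for a hypothetical 1,000-row, 20-feature sample:

    import numpy as np

    arr = np.zeros((1000, 20))
    min_needed = arr.shape[1] + 2                                       # 22
    nsamples = max(min_needed, min(300, arr.shape[0] * arr.shape[1]))   # max(22, min(300, 20000)) = 300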
1342
+ def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
1343
+ matrices = []
1344
+ for col in self.factor_nmes:
1345
+ s = data[col]
1346
+ if col in self.cate_list:
1347
+ cats = pd.Categorical(
1348
+ s,
1349
+ categories=self.cat_categories_for_shap[col]
1350
+ )
1351
+ codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
1352
+ matrices.append(codes)
1353
+ else:
1354
+ vals = pd.to_numeric(s, errors="coerce")
1355
+ arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
1356
+ matrices.append(arr)
1357
+ X_mat = np.concatenate(matrices, axis=1) # Result shape (N, F)
1358
+ return X_mat
1359
+
1360
+ def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
1361
+ data_dict = {}
1362
+ for j, col in enumerate(self.factor_nmes):
1363
+ col_vals = X_mat[:, j]
1364
+ if col in self.cate_list:
1365
+ cats = self.cat_categories_for_shap[col]
1366
+ codes = np.round(col_vals).astype(int)
1367
+ codes = np.clip(codes, -1, len(cats) - 1)
1368
+ cat_series = pd.Categorical.from_codes(
1369
+ codes,
1370
+ categories=cats
1371
+ )
1372
+ data_dict[col] = cat_series
1373
+ else:
1374
+ data_dict[col] = col_vals.astype(float)
1375
+
1376
+ df = pd.DataFrame(data_dict, columns=self.factor_nmes)
1377
+ for col in self.cate_list:
1378
+ if col in df.columns:
1379
+ df[col] = df[col].astype("category")
1380
+ return df
1381
+
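The two helpers above form an encode/decode pair: categoricals become float codes so KernelSHAP can perturb a purely numeric matrix, and perturbed codes are rounded and clipped back onto the known categories before prediction. A minimal round trip with made-up column values:

    import numpy as np
    import pandas as pd

    cats = ["north", "south", "west"]                    # hypothetical category levels
    region = pd.Categorical(["south", "west"], categories=cats)

    # Encode: category -> float code (as _build_ft_shap_matrix does per column).
    codes = np.asarray(region.codes, dtype=np.float64)   # [1.0, 2.0]

    # Decode: round/clip possibly-perturbed codes back to categories
    # (as _decode_ft_shap_matrix_to_df does).
    perturbed = codes + 0.4                              # KernelSHAP may move codes off-integer
    back = np.clip(np.round(perturbed).astype(int), -1, len(cats) - 1)
    print(list(pd.Categorical.from_codes(back, categories=cats)))   # ['south', 'west']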
1382
+ def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
1383
+ X = data[self.var_nmes]
1384
+ return sm.add_constant(X, has_constant='add')
1385
+
1386
+ def _compute_shap_core(self,
1387
+ model_key: str,
1388
+ n_background: int,
1389
+ n_samples: int,
1390
+ on_train: bool,
1391
+ X_df: pd.DataFrame,
1392
+ prep_fn,
1393
+ predict_fn,
1394
+ cleanup_fn=None):
1395
+ if explain_shap is None:
1396
+ raise RuntimeError("explain.shap_utils is not available.")
1397
+ return explain_shap.compute_shap_core(
1398
+ self,
1399
+ model_key,
1400
+ n_background,
1401
+ n_samples,
1402
+ on_train,
1403
+ X_df=X_df,
1404
+ prep_fn=prep_fn,
1405
+ predict_fn=predict_fn,
1406
+ cleanup_fn=cleanup_fn,
1407
+ )
1408
+
1409
+ # ========= GLM SHAP explainability =========
1410
+ def compute_shap_glm(self, n_background: int = 500,
1411
+ n_samples: int = 200,
1412
+ on_train: bool = True):
1413
+ if explain_shap is None:
1414
+ raise RuntimeError("explain.shap_utils is not available.")
1415
+ self.shap_glm = explain_shap.compute_shap_glm(
1416
+ self,
1417
+ n_background=n_background,
1418
+ n_samples=n_samples,
1419
+ on_train=on_train,
1420
+ )
1421
+ return self.shap_glm
1422
+
1423
+ # ========= XGBoost SHAP explainability =========
1424
+ def compute_shap_xgb(self, n_background: int = 500,
1425
+ n_samples: int = 200,
1426
+ on_train: bool = True):
1427
+ if explain_shap is None:
1428
+ raise RuntimeError("explain.shap_utils is not available.")
1429
+ self.shap_xgb = explain_shap.compute_shap_xgb(
1430
+ self,
1431
+ n_background=n_background,
1432
+ n_samples=n_samples,
1433
+ on_train=on_train,
1434
+ )
1435
+ return self.shap_xgb
1436
+
1437
+ # ========= ResNet SHAP explainability =========
1438
+ def _resn_predict_wrapper(self, X_np):
1439
+ model = self.resn_best.resnet.to("cpu")
1440
+ with torch.no_grad():
1441
+ X_tensor = torch.tensor(X_np, dtype=torch.float32)
1442
+ y_pred = model(X_tensor).cpu().numpy()
1443
+ y_pred = np.clip(y_pred, 1e-6, None)
1444
+ return y_pred.reshape(-1)
1445
+
1446
+ def compute_shap_resn(self, n_background: int = 500,
1447
+ n_samples: int = 200,
1448
+ on_train: bool = True):
1449
+ if explain_shap is None:
1450
+ raise RuntimeError("explain.shap_utils is not available.")
1451
+ self.shap_resn = explain_shap.compute_shap_resn(
1452
+ self,
1453
+ n_background=n_background,
1454
+ n_samples=n_samples,
1455
+ on_train=on_train,
1456
+ )
1457
+ return self.shap_resn
1458
+
1459
+ # ========= FT-Transformer SHAP explainability =========
1460
+ def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
1461
+ df_input = self._decode_ft_shap_matrix_to_df(X_mat)
1462
+ y_pred = self.ft_best.predict(df_input)
1463
+ return np.asarray(y_pred, dtype=np.float64).reshape(-1)
1464
+
1465
+ def compute_shap_ft(self, n_background: int = 500,
1466
+ n_samples: int = 200,
1467
+ on_train: bool = True):
1468
+ if explain_shap is None:
1469
+ raise RuntimeError("explain.shap_utils is not available.")
1470
+ self.shap_ft = explain_shap.compute_shap_ft(
1471
+ self,
1472
+ n_background=n_background,
1473
+ n_samples=n_samples,
1474
+ on_train=on_train,
1475
+ )
1476
+ return self.shap_ft
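All four `compute_shap_*` wrappers delegate to `explain.shap_utils` and cache the result on the instance (`shap_glm`, `shap_xgb`, `shap_resn`, `shap_ft`). A usage sketch with the hypothetical `ctx` instance; smaller background and sample counts keep KernelSHAP runtimes manageable:

    shap_xgb = ctx.compute_shap_xgb(n_background=200, n_samples=100, on_train=False)
    shap_ft = ctx.compute_shap_ft(n_background=200, n_samples=100, on_train=False)
    # The same values remain available afterwards as ctx.shap_xgb / ctx.shap_ft.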