ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,2614 @@
+ # Moving data between CPU and GPU is costly; overlapping transfers and compute
+ # on multiple CUDA streams helps support larger datasets.
+
+ import copy
+ import gc
+ import math
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+ import csv
+
+ import joblib
+ import matplotlib.pyplot as plt
+ import numpy as np  # 1.26.2
+ import optuna  # 4.3.0
+ import pandas as pd  # 2.2.3
+ import shap
+ import statsmodels.api as sm
+
+ import torch  # version: 1.10.1+cu111
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import xgboost as xgb  # 1.7.0
+
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.utils import clip_grad_norm_
+ from sklearn.model_selection import ShuffleSplit, cross_val_score  # 1.2.2
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import log_loss, make_scorer, mean_tweedie_deviance
+
+ # Constants and utility module
+ # =============================================================================
+ torch.backends.cudnn.benchmark = True
+ EPS = 1e-8
+
+
+ class IOUtils:
+     # Small collection of file and path helpers.
+
+     @staticmethod
+     def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+         with open(file_path, mode='r', encoding='utf-8') as file:
+             reader = csv.DictReader(file)
+             return [
+                 dict(filter(lambda item: item[0] != '', row.items()))
+                 for row in reader
+             ]
+
+     @staticmethod
+     def ensure_parent_dir(file_path: str) -> None:
+         # Create the target file's parent directory if it does not exist.
+         directory = os.path.dirname(file_path)
+         if directory:
+             os.makedirs(directory, exist_ok=True)
+
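The item[0] != '' filter above is what drops an unnamed index column, since csv.DictReader maps an empty header cell to the key ''. An illustrative sketch (not part of the diff), using an in-memory stream instead of a file path:

    import csv
    import io

    # A CSV whose first header cell is empty, as pandas writes for an unnamed index.
    text = ",a,b\n0,1.5,x\n1,2.5,y\n"
    rows = list(csv.DictReader(io.StringIO(text)))
    # rows[0] == {'': '0', 'a': '1.5', 'b': 'x'}; the same filter drops the '' key:
    cleaned = [dict(filter(lambda item: item[0] != '', row.items())) for row in rows]
    print(cleaned[0])  # {'a': '1.5', 'b': 'x'}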
+
+ class TrainingUtils:
+     # Small helpers shared across the training phase.
+
+     @staticmethod
+     def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+         estimated = int((learning_rate / 1e-4) ** 0.5 *
+                         (data_size / max(batch_num, 1)))
+         return max(1, min(data_size, max(minimum, estimated)))
+
+     @staticmethod
+     def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+         # Clamp predictions to be positive first, for numerical stability.
+         pred_clamped = torch.clamp(pred, min=eps)
+         if p == 1:
+             term1 = target * torch.log(target / pred_clamped + eps)  # Poisson
+             term2 = target - pred_clamped  # sign fixed: the deviance is 2*(y*log(y/mu) - y + mu)
+             term3 = 0
+         elif p == 0:
+             term1 = 0.5 * torch.pow(target - pred_clamped, 2)  # Gaussian
+             term2 = 0
+             term3 = 0
+         elif p == 2:
+             term1 = torch.log(pred_clamped / target + eps)  # Gamma
+             term2 = -target / pred_clamped + 1
+             term3 = 0
+         else:
+             term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
+             term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
+             term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
+         return torch.nan_to_num(  # Tweedie negative log-likelihood (constant terms dropped)
+             2 * (term1 - term2 + term3),
+             nan=eps,
+             posinf=max_clip,
+             neginf=-max_clip
+         )
+
+     @staticmethod
+     def free_cuda() -> None:
+         print(">>> Moving all models to CPU...")
+         for obj in gc.get_objects():
+             try:
+                 if hasattr(obj, "to") and callable(obj.to):
+                     obj.to("cpu")
+             except Exception:
+                 pass
+
+         print(">>> Deleting tensors, optimizers, dataloaders...")
+         gc.collect()
+
+         print(">>> Emptying CUDA cache...")
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+         print(">>> CUDA memory freed.")
+
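A sanity check on the generic branch of tweedie_loss: for 1 < p < 2 it is scikit-learn's unit Tweedie deviance, so its unweighted mean should agree with mean_tweedie_deviance up to the eps/clamping used for stability. Illustrative sketch, assuming the module imports above:

    y_true = torch.tensor([1.0, 0.5, 2.0, 0.0])
    y_pred = torch.tensor([0.8, 0.6, 1.5, 0.1])

    per_sample = TrainingUtils.tweedie_loss(y_pred, y_true, p=1.5)
    print(per_sample.mean().item())
    print(mean_tweedie_deviance(y_true.numpy(), y_pred.numpy(), power=1.5))
    # The two values agree (the generic branch adds no eps inside the power terms).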
+
+ class PlotUtils:
+     # Plotting helpers shared across the models.
+
+     @staticmethod
+     def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
+         data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
+         data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
+         w_sum = data_sorted[wgt_nme].sum()
+         if w_sum <= EPS:
+             data_sorted.loc[:, 'bins'] = 0
+         else:
+             data_sorted.loc[:, 'bins'] = np.floor(
+                 data_sorted['cum_weight'] * float(n_bins) / w_sum
+             )
+             data_sorted.loc[(data_sorted['bins'] == n_bins),
+                             'bins'] = n_bins - 1
+         return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
+
+     @staticmethod
+     def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
+         ax.plot(plot_data.index, plot_data['act_v'],
+                 label=act_label, color='red')
+         ax.plot(plot_data.index, plot_data['exp_v'],
+                 label=pred_label, color='blue')
+         ax.set_title(title, fontsize=8)
+         ax.set_xticks(plot_data.index)
+         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+         ax.tick_params(axis='y', labelsize=6)
+         ax.legend(loc='upper left', fontsize=5, frameon=False)
+         ax.margins(0.05)
+         ax2 = ax.twinx()
+         ax2.bar(plot_data.index, plot_data['weight'],
+                 alpha=0.5, color='seagreen',
+                 label=weight_label)
+         ax2.tick_params(axis='y', labelsize=6)
+         ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+     @staticmethod
+     def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
+         ax.plot(plot_data.index, plot_data['act_v'],
+                 label=act_label, color='red')
+         ax.plot(plot_data.index, plot_data['exp_v1'],
+                 label=label1, color='blue')
+         ax.plot(plot_data.index, plot_data['exp_v2'],
+                 label=label2, color='black')
+         ax.set_title(title, fontsize=8)
+         ax.set_xticks(plot_data.index)
+         ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
+         ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
+         ax.tick_params(axis='y', labelsize=6)
+         ax.legend(loc='upper left', fontsize=5, frameon=False)
+         ax.margins(0.1)
+         ax2 = ax.twinx()
+         ax2.bar(plot_data.index, plot_data['weight'],
+                 alpha=0.5, color='seagreen',
+                 label=weight_label)
+         ax2.tick_params(axis='y', labelsize=6)
+         ax2.legend(loc='upper right', fontsize=5, frameon=False)
+
+     @staticmethod
+     def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                        weight_list, tgt_nme, n_bins: int = 10,
+                        fig_nme: str = 'Lift Chart'):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred'] = pred_model
+         lift_data.loc[:, 'w_pred'] = w_pred_list
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = weight_list
+         plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
+         plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+         plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+         plot_data.reset_index(inplace=True)
+
+         fig = plt.figure(figsize=(7, 5))
+         ax = fig.add_subplot(111)
+         PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
+         plt.subplots_adjust(wspace=0.3)
+
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+         IOUtils.ensure_parent_dir(save_path)
+         plt.savefig(save_path, dpi=300)
+         plt.close(fig)
+
+     @staticmethod
+     def plot_dlift_list(pred_model_1, pred_model_2,
+                         model_nme_1, model_nme_2,
+                         tgt_nme,
+                         w_list, w_act_list, n_bins: int = 10,
+                         fig_nme: str = 'Double Lift Chart'):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred1'] = pred_model_1
+         lift_data.loc[:, 'pred2'] = pred_model_2
+         lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = w_list
+         lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
+         lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
+         plot_data = PlotUtils.split_data(
+             lift_data, 'diff_ly', 'weight', n_bins)
+         plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
+         plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
+         plot_data['act_v'] = plot_data['act'] / plot_data['act']  # normalized to 1
+         plot_data.reset_index(inplace=True)
+
+         fig = plt.figure(figsize=(7, 5))
+         ax = fig.add_subplot(111)
+         PlotUtils.plot_dlift_ax(ax, plot_data,
+                                 f'Double Lift Chart of {tgt_nme}',
+                                 model_nme_1, model_nme_2)
+         plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+         IOUtils.ensure_parent_dir(save_path)
+         plt.savefig(save_path, dpi=300)
+         plt.close(fig)
+
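Illustrative usage of split_data (not part of the diff): it bins rows into roughly equal-weight buckets by cumulative weight, which is the binning the lift charts consume. Assumes the module imports above:

    df = pd.DataFrame({
        'pred': np.random.rand(1000),
        'weight': np.random.uniform(0.5, 1.5, size=1000),
    })
    binned = PlotUtils.split_data(df, col_nme='pred', wgt_nme='weight', n_bins=10)
    print(binned['weight'])  # ten bins with roughly equal total weight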
+
+ # Backward-compatible functional wrappers
+ def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
+     return IOUtils.csv_to_dict(file_path)
+
+
+ def ensure_parent_dir(file_path: str) -> None:
+     IOUtils.ensure_parent_dir(file_path)
+
+
+ def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
+     return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)
+
+
+ # Tweedie deviance loss for the PyTorch side.
+ # Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
+ def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
+     return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)
+
+
+ # Release CUDA memory.
+ def free_cuda():
+     TrainingUtils.free_cuda()
+
+
+ class TorchTrainerMixin:
+     # Shared utilities for the Torch tabular trainers.
+
+     def _device_type(self) -> str:
+         return getattr(self, "device", torch.device("cpu")).type
+
+     def _build_dataloader(self,
+                           dataset,
+                           N: int,
+                           base_bs_gpu: tuple,
+                           base_bs_cpu: tuple,
+                           min_bs: int = 64,
+                           target_effective_cuda: int = 8192,
+                           target_effective_cpu: int = 4096,
+                           large_threshold: int = 200_000,
+                           mid_threshold: int = 50_000):
+         gpu_large, gpu_mid, gpu_small = base_bs_gpu
+         cpu_mid, cpu_small = base_bs_cpu
+
+         if self._device_type() == 'cuda':
+             device_count = torch.cuda.device_count()
+             # With multiple GPUs, raise the minimum batch size so every card gets enough data.
+             if device_count > 1:
+                 min_bs = min_bs * device_count
+                 print(f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")
+
+             if N > large_threshold:
+                 base_bs = gpu_large * device_count
+             elif N > mid_threshold:
+                 base_bs = gpu_mid * device_count
+             else:
+                 base_bs = gpu_small * device_count
+         else:
+             base_bs = cpu_mid if N > mid_threshold else cpu_small
+
+         # Compute the batch size once min_bs is final, so it never falls below it.
+         batch_size = TrainingUtils.compute_batch_size(
+             data_size=len(dataset),
+             learning_rate=self.learning_rate,
+             batch_num=self.batch_num,
+             minimum=min_bs
+         )
+         batch_size = min(batch_size, base_bs, N)
+
+         target_effective_bs = target_effective_cuda if self._device_type(
+         ) == 'cuda' else target_effective_cpu
+         accum_steps = max(1, target_effective_bs // batch_size)
+
+         # Linux (posix) forks workers cheaply; Windows (nt) spawns them, which is costly.
+         if os.name == 'nt':
+             workers = 0
+         else:
+             workers = min(8, os.cpu_count() or 1)
+
+         print(f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, Workers={workers}")
+
+         dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=workers,
+             pin_memory=(self._device_type() == 'cuda')
+         )
+         return dataloader, accum_steps
+
+     def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
+         task = getattr(self, "task_type", "regression")
+         if task == 'classification':
+             loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+             losses = loss_fn(y_pred, y_true).view(-1)
+         else:
+             if apply_softplus:
+                 y_pred = F.softplus(y_pred)
+             y_pred = torch.clamp(y_pred, min=1e-6)
+             power = getattr(self, "tw_power", 1.5)
+             losses = tweedie_loss(y_pred, y_true, p=power).view(-1)
+         weighted_loss = (losses * weights.view(-1)).sum() / \
+             torch.clamp(weights.sum(), min=EPS)
+         return weighted_loss
+
+     def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model):
+         if val_loss < best_loss:
+             return val_loss, copy.deepcopy(model.state_dict()), 0, False
+         patience_counter += 1
+         should_stop = best_state is not None and patience_counter >= getattr(
+             self, "patience", 0)
+         return best_loss, best_state, patience_counter, should_stop
+
+     def _train_model(self,
+                      model,
+                      dataloader,
+                      accum_steps,
+                      optimizer,
+                      scaler,
+                      forward_fn,
+                      val_forward_fn=None,
+                      apply_softplus: bool = False,
+                      clip_fn=None,
+                      trial: Optional[optuna.trial.Trial] = None):
+         device_type = self._device_type()
+         best_loss = float('inf')
+         best_state = None
+         patience_counter = 0
+         stop_training = False
+
+         for epoch in range(1, getattr(self, "epochs", 1) + 1):
+             model.train()
+             optimizer.zero_grad()
+
+             for step, batch in enumerate(dataloader):
+                 with autocast(enabled=(device_type == 'cuda')):
+                     y_pred, y_true, w = forward_fn(batch)
+                     weighted_loss = self._compute_weighted_loss(
+                         y_pred, y_true, w, apply_softplus=apply_softplus)
+                     loss_for_backward = weighted_loss / accum_steps
+
+                 scaler.scale(loss_for_backward).backward()
+
+                 if ((step + 1) % accum_steps == 0) or ((step + 1) == len(dataloader)):
+                     if clip_fn is not None:
+                         clip_fn()
+                     scaler.step(optimizer)
+                     scaler.update()
+                     optimizer.zero_grad()
+
+             if val_forward_fn is not None:
+                 model.eval()
+                 with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
+                     val_result = val_forward_fn()
+                     if isinstance(val_result, tuple) and len(val_result) == 3:
+                         y_val_pred, y_val_true, w_val = val_result
+                         val_weighted_loss = self._compute_weighted_loss(
+                             y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
+                     else:
+                         val_weighted_loss = val_result
+
+                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                     val_weighted_loss, best_loss, best_state, patience_counter, model)
+
+                 # Optuna pruning: abort the trial early when it underperforms past trials.
+                 if trial is not None:
+                     trial.report(float(val_weighted_loss), epoch)  # cast: may be a tensor
+                     if trial.should_prune():
+                         raise optuna.TrialPruned()
+
+             if stop_training:
+                 break
+
+         return best_state
+
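A minimal end-to-end sketch of how a trainer consumes TorchTrainerMixin (illustrative only: TinyTrainer and every hyperparameter value here are invented for the demo, and the module-level imports above are assumed):

    class TinyTrainer(TorchTrainerMixin):
        def __init__(self):
            self.device = torch.device('cpu')
            self.learning_rate = 1e-3
            self.batch_num = 10
            self.epochs = 3
            self.patience = 2
            self.task_type = 'regression'
            self.tw_power = 1.5

    trainer = TinyTrainer()
    X = torch.randn(500, 3)
    y = F.softplus(X.sum(dim=1, keepdim=True))  # positive targets, as Tweedie requires
    w = torch.ones_like(y)
    loader, accum_steps = trainer._build_dataloader(
        TensorDataset(X, y, w), N=500,
        base_bs_gpu=(2048, 1024, 512), base_bs_cpu=(1024, 512))

    model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1), nn.Softplus())
    optimizer = torch.optim.Adam(model.parameters(), lr=trainer.learning_rate)
    scaler = GradScaler(enabled=False)  # CPU demo: AMP disabled, scaler is a no-op

    def forward_fn(batch):
        X_b, y_b, w_b = batch
        return model(X_b), y_b, w_b

    # No validation callback, so early stopping is skipped and all epochs run.
    best_state = trainer._train_model(model, loader, accum_steps,
                                      optimizer, scaler, forward_fn)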
+
+ # =============================================================================
+ # Plotting helper wrappers
+ # =============================================================================
+
+ def split_data(data, col_nme, wgt_nme, n_bins=10):
+     return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)
+
+ # Lift-chart plotting wrapper
+
+
+ def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                    weight_list, tgt_nme, n_bins=10,
+                    fig_nme='Lift Chart'):
+     return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
+                                     weight_list, tgt_nme, n_bins, fig_nme)
+
+ # Double-lift-chart plotting wrapper
+
+
+ def plot_dlift_list(pred_model_1, pred_model_2,
+                     model_nme_1, model_nme_2,
+                     tgt_nme,
+                     w_list, w_act_list, n_bins=10,
+                     fig_nme='Double Lift Chart'):
+     return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
+                                      model_nme_1, model_nme_2,
+                                      tgt_nme, w_list, w_act_list,
+                                      n_bins, fig_nme)
+
+
+ # =============================================================================
+ # ResNet model and sklearn-style wrapper
+ # =============================================================================
+
+ # Residual block: two linear layers + ReLU + a residual connection.
+ class ResBlock(nn.Module):
+     def __init__(self, dim: int, dropout: float = 0.1,
+                  use_layernorm: bool = False, residual_scale: float = 0.1
+                  ):
+         super().__init__()
+         self.use_layernorm = use_layernorm
+
+         if use_layernorm:
+             Norm = nn.LayerNorm  # normalizes over the last dimension
+         else:
+             def Norm(d): return nn.BatchNorm1d(d)  # keep a switch so BN can be tried too
+
+         self.norm1 = Norm(dim)
+         self.fc1 = nn.Linear(dim, dim, bias=True)
+         self.act = nn.ReLU(inplace=True)
+         self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+         self.norm2 = Norm(dim)
+         self.fc2 = nn.Linear(dim, dim, bias=True)
+
+         # Residual scaling keeps the trunk from blowing up early in training.
+         self.res_scale = nn.Parameter(
+             torch.tensor(residual_scale, dtype=torch.float32)
+         )
+
+     def forward(self, x):
+         # Pre-activation layout.
+         out = self.norm1(x)
+         out = self.fc1(out)
+         out = self.act(out)
+         out = self.dropout(out)
+         out = self.norm2(out)
+         out = self.fc2(out)
+         # Scale the residual branch, then add.
+         return F.relu(x + self.res_scale * out)
+
+
+ # ResNetSequential assembles the full network.
+ class ResNetSequential(nn.Module):
+     # Input tensor shape: (batch, input_dim).
+     # Structure: linear + norm + ReLU, a stack of residual blocks, then a Softplus output.
+
+     def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
+                  use_layernorm: bool = True, dropout: float = 0.1,
+                  residual_scale: float = 0.1, task_type: str = 'regression'):
+         super(ResNetSequential, self).__init__()
+
+         self.net = nn.Sequential()
+         self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))
+
+         if use_layernorm:
+             self.net.add_module('norm1', nn.LayerNorm(hidden_dim))
+         else:
+             self.net.add_module('norm1', nn.BatchNorm1d(hidden_dim))
+
+         self.net.add_module('relu1', nn.ReLU(inplace=True))
+
+         # Stack of residual blocks.
+         for i in range(block_num):
+             self.net.add_module(
+                 f'ResBlk_{i+1}',
+                 ResBlock(
+                     hidden_dim,
+                     dropout=dropout,
+                     use_layernorm=use_layernorm,
+                     residual_scale=residual_scale)
+             )
+
+         self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))
+
+         if task_type == 'classification':
+             self.net.add_module('softplus', nn.Identity())  # emit raw logits
+         else:
+             self.net.add_module('softplus', nn.Softplus())
+
+     def forward(self, x):
+         if self.training and not hasattr(self, '_printed_device'):
+             print(f">>> ResNetSequential executing on device: {x.device}")
+             self._printed_device = True
+         return self.net(x)
+
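Quick shape check for ResNetSequential (illustrative):

    net = ResNetSequential(input_dim=8, hidden_dim=32, block_num=2)
    out = net(torch.randn(4, 8))
    print(out.shape)        # torch.Size([4, 1])
    print((out > 0).all())  # tensor(True): Softplus keeps regression outputs positive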
+
+ # Scikit-learn style interface for the ResNet model.
+ class ResNetSklearn(TorchTrainerMixin, nn.Module):
+     def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
+                  block_num: int = 2, batch_num: int = 100, epochs: int = 100,
+                  task_type: str = 'regression',
+                  tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
+                  use_layernorm: bool = True, dropout: float = 0.1,
+                  residual_scale: float = 0.1,
+                  use_data_parallel: bool = True):
+         super(ResNetSklearn, self).__init__()
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.block_num = block_num
+         self.batch_num = batch_num
+         self.epochs = epochs
+         self.task_type = task_type
+         self.model_nme = model_nme
+         self.learning_rate = learning_rate
+         self.patience = patience
+         self.use_layernorm = use_layernorm
+         self.dropout = dropout
+         self.residual_scale = residual_scale
+
+         # Device selection: cuda > mps > cpu.
+         if torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         elif torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         else:
+             self.device = torch.device('cpu')
+
+         # Tweedie power selection (unused for classification).
+         if self.task_type == 'classification':
+             self.tw_power = None
+         elif 'f' in self.model_nme:
+             self.tw_power = 1
+         elif 's' in self.model_nme:
+             self.tw_power = 2
+         else:
+             self.tw_power = tweedie_power
+
+         # Build the network (on CPU first).
+         core = ResNetSequential(
+             self.input_dim,
+             self.hidden_dim,
+             self.block_num,
+             use_layernorm=self.use_layernorm,
+             dropout=self.dropout,
+             residual_scale=self.residual_scale,
+             task_type=self.task_type
+         )
+
+         # ===== Multi-GPU support: DataParallel =====
+         if use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
+             core = nn.DataParallel(core, device_ids=list(
+                 range(torch.cuda.device_count())))
+             # DataParallel scatters inputs across the cards, but the master device stays cuda:0.
+             self.device = torch.device('cuda')
+
+         self.resnet = core.to(self.device)
+
+     # ================ Internal helpers ================
+     def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
+         X_tensor = torch.tensor(X_train.values, dtype=torch.float32)
+         y_tensor = torch.tensor(
+             y_train.values, dtype=torch.float32).view(-1, 1)
+         w_tensor = torch.tensor(w_train.values, dtype=torch.float32).view(
+             -1, 1) if w_train is not None else torch.ones_like(y_tensor)
+
+         has_val = X_val is not None and y_val is not None
+         if has_val:
+             X_val_tensor = torch.tensor(X_val.values, dtype=torch.float32)
+             y_val_tensor = torch.tensor(
+                 y_val.values, dtype=torch.float32).view(-1, 1)
+             w_val_tensor = torch.tensor(w_val.values, dtype=torch.float32).view(
+                 -1, 1) if w_val is not None else torch.ones_like(y_val_tensor)
+         else:
+             X_val_tensor = y_val_tensor = w_val_tensor = None
+         return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val
+
+     def forward(self, x):
+         # Accept NumPy input (e.g. when called from SHAP).
+         if isinstance(x, np.ndarray):
+             x_tensor = torch.tensor(x, dtype=torch.float32)
+         else:
+             x_tensor = x
+
+         x_tensor = x_tensor.to(self.device)
+         y_pred = self.resnet(x_tensor)
+         return y_pred
+
+     # ---------------- Training ----------------
+
+     def fit(self, X_train, y_train, w_train=None,
+             X_val=None, y_val=None, w_val=None, trial=None):
+
+         X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
+             self._build_train_val_tensors(
+                 X_train, y_train, w_train, X_val, y_val, w_val)
+
+         dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
+         dataloader, accum_steps = self._build_dataloader(
+             dataset,
+             N=X_tensor.shape[0],
+             base_bs_gpu=(2048, 1024, 512),
+             base_bs_cpu=(1024, 512),
+             min_bs=64,
+             target_effective_cuda=8192,
+             target_effective_cpu=4096
+         )
+
+         # === Optimizer and AMP ===
+         self.optimizer = torch.optim.Adam(
+             self.resnet.parameters(), lr=self.learning_rate)
+         self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         val_dataloader = None
+         if has_val:
+             # Build the validation DataLoader.
+             val_dataset = TensorDataset(X_val_tensor, y_val_tensor, w_val_tensor)
+             # Validation needs no backward pass, so larger batches raise throughput.
+             val_bs = accum_steps * dataloader.batch_size
+
+             # Validation workers follow the same allocation logic.
+             if os.name == 'nt':
+                 val_workers = 0
+             else:
+                 val_workers = min(4, os.cpu_count() or 1)
+
+             val_dataloader = DataLoader(
+                 val_dataset,
+                 batch_size=val_bs,
+                 shuffle=False,
+                 num_workers=val_workers,
+                 pin_memory=(self.device.type == 'cuda')
+             )
+
+         def forward_fn(batch):
+             X_batch, y_batch, w_batch = batch
+             X_batch = X_batch.to(self.device, non_blocking=True)
+             y_batch = y_batch.to(self.device, non_blocking=True)
+             w_batch = w_batch.to(self.device, non_blocking=True)
+             y_pred = self.resnet(X_batch)
+             return y_pred, y_batch, w_batch
+
+         def val_forward_fn():
+             total_loss = 0.0
+             total_weight = 0.0
+             for batch in val_dataloader:
+                 X_b, y_b, w_b = batch
+                 X_b = X_b.to(self.device, non_blocking=True)
+                 y_b = y_b.to(self.device, non_blocking=True)
+                 w_b = w_b.to(self.device, non_blocking=True)
+
+                 y_pred = self.resnet(X_b)
+
+                 # Compute each batch's weighted loss by hand so the totals sum exactly.
+                 task = getattr(self, "task_type", "regression")
+                 if task == 'classification':
+                     loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+                     losses = loss_fn(y_pred, y_b).view(-1)
+                 else:
+                     # No extra softplus here: training uses apply_softplus=False,
+                     # and the model's forward output is already positive.
+                     y_pred_clamped = torch.clamp(y_pred, min=1e-6)
+                     power = getattr(self, "tw_power", 1.5)
+                     losses = tweedie_loss(y_pred_clamped, y_b, p=power).view(-1)
+
+                 batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                 batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                 total_loss += batch_weighted_loss_sum.item()
+                 total_weight += batch_weight_sum.item()
+
+             return total_loss / max(total_weight, EPS)
+
+         clip_fn = None
+         if self.device.type == 'cuda':
+             def clip_fn(): return (self.scaler.unscale_(self.optimizer),
+                                    clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))
+
+         best_state = self._train_model(
+             self.resnet,
+             dataloader,
+             accum_steps,
+             self.optimizer,
+             self.scaler,
+             forward_fn,
+             val_forward_fn if has_val else None,
+             apply_softplus=False,
+             clip_fn=clip_fn,
+             trial=trial
+         )
+
+         if has_val and best_state is not None:
+             self.resnet.load_state_dict(best_state)
+
+     # ---------------- Prediction ----------------
+
+     def predict(self, X_test):
+         self.resnet.eval()
+         if isinstance(X_test, pd.DataFrame):
+             X_np = X_test.values.astype(np.float32)
+         else:
+             X_np = X_test
+
+         with torch.no_grad():
+             y_pred = self(X_np).cpu().numpy()
+
+         if self.task_type == 'classification':
+             y_pred = 1 / (1 + np.exp(-y_pred))  # sigmoid turns logits into probabilities
+         else:
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.flatten()
+
+     # ---------------- Parameter setting ----------------
+
+     def set_params(self, params):
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         return self
+
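A hedged usage sketch of ResNetSklearn (column names, model name, and hyperparameters are all invented for the demo; note the model name avoids 'f' and 's' so tw_power falls back to tweedie_power):

    rng = np.random.default_rng(0)
    X = pd.DataFrame({'x1': rng.normal(size=400), 'x2': rng.normal(size=400)})
    y = pd.Series(np.exp(0.3 * X['x1'] - 0.2 * X['x2']))  # positive target
    w = pd.Series(np.ones(len(X)))

    model = ResNetSklearn(model_nme='demo_tweedie', input_dim=2,
                          hidden_dim=16, block_num=1, epochs=3, batch_num=10)
    model.fit(X, y, w)          # no validation set: runs all epochs, no early stop
    preds = model.predict(X)    # positive predictions, floored at 1e-6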
+
+ # =============================================================================
+ # FT-Transformer model and sklearn-style wrapper
+ # =============================================================================
+
+ class FeatureTokenizer(nn.Module):
+     # Maps numeric and categorical features to tokens of shape (batch, token_num, d_model).
+     # Conventions:
+     # - X_num: numeric features, shape=(batch, num_numeric)
+     # - X_cat: categorical features, shape=(batch, num_categorical); each column holds
+     #   label-encoded integers in [0, card-1]
+
+     def __init__(self, num_numeric: int, cat_cardinalities, d_model: int):
+         super().__init__()
+
+         self.num_numeric = num_numeric
+         self.has_numeric = num_numeric > 0
+
+         if self.has_numeric:
+             self.num_linear = nn.Linear(num_numeric, d_model)
+
+         self.embeddings = nn.ModuleList([
+             nn.Embedding(card, d_model) for card in cat_cardinalities
+         ])
+
+     def forward(self, X_num, X_cat):
+         tokens = []
+
+         if self.has_numeric:
+             # The numeric block is mapped to a single token.
+             num_token = self.num_linear(X_num)  # shape = (batch, d_model)
+             tokens.append(num_token)
+
+         # Each categorical feature contributes one embedding token.
+         for i, emb in enumerate(self.embeddings):
+             tok = emb(X_cat[:, i])  # shape = (batch, d_model)
+             tokens.append(tok)
+
+         # Stack into (batch, token_num, d_model).
+         x = torch.stack(tokens, dim=1)
+         return x
+
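Shape conventions of FeatureTokenizer in one sketch (illustrative): one token for the whole numeric block plus one token per categorical column:

    tok = FeatureTokenizer(num_numeric=3, cat_cardinalities=[5, 7], d_model=16)
    X_num = torch.randn(4, 3)
    X_cat = torch.stack([torch.randint(0, 5, (4,)),
                         torch.randint(0, 7, (4,))], dim=1)
    print(tok(X_num, X_cat).shape)  # torch.Size([4, 3, 16]): 1 numeric + 2 categorical tokens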
+
+ # Encoder layer with residual scaling.
+ class ScaledTransformerEncoderLayer(nn.Module):
+     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
+                  dropout: float = 0.1, residual_scale_attn: float = 1.0,
+                  residual_scale_ffn: float = 1.0, norm_first: bool = True,
+                  ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=nhead,
+             dropout=dropout,
+             batch_first=True
+         )
+
+         # Feed-forward part.
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         # Normalization and dropout.
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = nn.GELU()
+         # self.activation = nn.ReLU()
+         self.norm_first = norm_first
+
+         # Residual scaling coefficients.
+         self.res_scale_attn = residual_scale_attn
+         self.res_scale_ffn = residual_scale_ffn
+
+     def forward(self, src, src_mask=None, src_key_padding_mask=None):
+         # Input tensor shape: (batch, seq_len, d_model).
+         x = src
+
+         if self.norm_first:
+             # Pre-norm: normalize before attention.
+             x = x + self._sa_block(self.norm1(x), src_mask,
+                                    src_key_padding_mask)
+             x = x + self._ff_block(self.norm2(x))
+         else:
+             # Post-norm (normally disabled).
+             x = self.norm1(
+                 x + self._sa_block(x, src_mask, src_key_padding_mask))
+             x = self.norm2(x + self._ff_block(x))
+
+         return x
+
+     def _sa_block(self, x, attn_mask, key_padding_mask):
+         # Self-attention with residual scaling.
+         attn_out, _ = self.self_attn(
+             x, x, x,
+             attn_mask=attn_mask,
+             key_padding_mask=key_padding_mask,
+             need_weights=False
+         )
+         return self.res_scale_attn * self.dropout1(attn_out)
+
+     def _ff_block(self, x):
+         # Feed-forward with residual scaling.
+         x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
+         return self.res_scale_ffn * self.dropout2(x2)
+
+
+ # FT-Transformer core model.
+ class FTTransformerCore(nn.Module):
+     # A minimal working FT-Transformer in three parts:
+     # 1) FeatureTokenizer: turns numeric/categorical features into tokens;
+     # 2) TransformerEncoder: models interactions between features;
+     # 3) pooling + MLP + Softplus: outputs positive values, suiting Tweedie/Gamma tasks.
+
+     def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
+                  n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
+                  task_type: str = 'regression'
+                  ):
+         super().__init__()
+
+         self.tokenizer = FeatureTokenizer(
+             num_numeric=num_numeric,
+             cat_cardinalities=cat_cardinalities,
+             d_model=d_model
+         )
+         scale = 1.0 / math.sqrt(n_layers)  # a reasonable default
+         encoder_layer = ScaledTransformerEncoderLayer(
+             d_model=d_model,
+             nhead=n_heads,
+             dim_feedforward=d_model * 4,
+             dropout=dropout,
+             residual_scale_attn=scale,
+             residual_scale_ffn=scale,
+             norm_first=True,
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=n_layers
+         )
+         self.n_layers = n_layers
+
+         layers = [
+             nn.LayerNorm(d_model),
+             nn.Linear(d_model, d_model),
+             nn.GELU(),
+             nn.Linear(d_model, 1),
+         ]
+
+         if task_type == 'classification':
+             # Classification emits logits, matching BCEWithLogitsLoss.
+             layers.append(nn.Identity())
+         else:
+             # Regression output must stay positive for Tweedie/Gamma.
+             layers.append(nn.Softplus())
+
+         self.head = nn.Sequential(*layers)
+
+     def forward(self, X_num, X_cat):
+
+         # Inputs:
+         # X_num -> float32 tensor of shape (batch, num_numeric)
+         # X_cat -> long tensor of shape (batch, num_categorical)
+
+         if self.training and not hasattr(self, '_printed_device'):
+             print(f">>> FTTransformerCore executing on device: {X_num.device}")
+             self._printed_device = True
+
+         tokens = self.tokenizer(X_num, X_cat)  # => (batch, token_num, d_model)
+         x = self.encoder(tokens)               # => (batch, token_num, d_model)
+
+         # Mean-pool the tokens, then feed the head.
+         x = x.mean(dim=1)                      # => (batch, d_model)
+
+         out = self.head(x)                     # => (batch, 1), positive via Softplus
+         return out
+
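Illustrative forward pass through FTTransformerCore (hyperparameters invented; d_model must be divisible by n_heads):

    core = FTTransformerCore(num_numeric=3, cat_cardinalities=[5, 7],
                             d_model=32, n_heads=4, n_layers=2)
    X_num = torch.randn(4, 3)
    X_cat = torch.stack([torch.randint(0, 5, (4,)),
                         torch.randint(0, 7, (4,))], dim=1)
    out = core(X_num, X_cat)
    print(out.shape)  # torch.Size([4, 1]); pooled tokens reduced to one positive scalar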
+
+ # Dataset wrapper for mixed numeric/categorical tabular tensors.
+ class TabularDataset(Dataset):
+     def __init__(self, X_num, X_cat, y, w):
+
+         # Input tensors:
+         # X_num: torch.float32, shape=(N, num_numeric)
+         # X_cat: torch.long,    shape=(N, num_categorical)
+         # y:     torch.float32, shape=(N, 1)
+         # w:     torch.float32, shape=(N, 1)
+
+         self.X_num = X_num
+         self.X_cat = X_cat
+         self.y = y
+         self.w = w
+
+     def __len__(self):
+         return self.y.shape[0]
+
+     def __getitem__(self, idx):
+         return (
+             self.X_num[idx],
+             self.X_cat[idx],
+             self.y[idx],
+             self.w[idx],
+         )
+
+
+ # Scikit-learn style interface for the FT-Transformer.
+ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
+
+     # sklearn-style wrapper:
+     # - num_cols: list of numeric column names
+     # - cat_cols: list of categorical column names (label-encoded beforehand,
+     #   values in [0, n_classes-1])
+
+     def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
+                  n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
+                  task_type: str = 'regression',
+                  tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
+                  use_data_parallel: bool = True,
+                  ):
+         super().__init__()
+
+         self.model_nme = model_nme
+         self.num_cols = list(num_cols)
+         self.cat_cols = list(cat_cols)
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.dropout = dropout
+         self.batch_num = batch_num
+         self.epochs = epochs
+         self.learning_rate = learning_rate
+         self.task_type = task_type
+         self.patience = patience
+         if self.task_type == 'classification':
+             self.tw_power = None  # Tweedie power is unused for classification
+         elif 'f' in self.model_nme:
+             self.tw_power = 1.0
+         elif 's' in self.model_nme:
+             self.tw_power = 2.0
+         else:
+             self.tw_power = tweedie_power
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+         elif torch.backends.mps.is_available():
+             self.device = torch.device("mps")
+         else:
+             self.device = torch.device("cpu")
+         self.cat_cardinalities = None
+         self.cat_categories = {}
+         self.ft = None
+         self.use_data_parallel = torch.cuda.device_count() > 1 and use_data_parallel
+
+     def _build_model(self, X_train):
+         num_numeric = len(self.num_cols)
+         cat_cardinalities = []
+
+         for col in self.cat_cols:
+             cats = X_train[col].astype('category')
+             categories = cats.cat.categories
+             self.cat_categories[col] = categories  # remember the training category set
+
+             card = len(categories) + 1  # reserve one extra class for unknown/missing
+             cat_cardinalities.append(card)
+
+         self.cat_cardinalities = cat_cardinalities
+
+         core = FTTransformerCore(
+             num_numeric=num_numeric,
+             cat_cardinalities=cat_cardinalities,
+             d_model=self.d_model,
+             n_heads=self.n_heads,
+             n_layers=self.n_layers,
+             dropout=self.dropout,
+             task_type=self.task_type
+         )
+         if self.use_data_parallel:
+             core = nn.DataParallel(core, device_ids=list(
+                 range(torch.cuda.device_count())))
+             self.device = torch.device("cuda")
+         self.ft = core.to(self.device)
+
+     def _encode_cats(self, X):
+         # The input DataFrame must contain at least every categorical column.
+         # Returns an int64 array of shape (N, num_categorical).
+
+         if not self.cat_cols:
+             return np.zeros((len(X), 0), dtype='int64')
+
+         X_cat_list = []
+         for col in self.cat_cols:
+             # Use the category set recorded at fit time.
+             categories = self.cat_categories[col]
+             # Build a Categorical against the fixed categories.
+             cats = pd.Categorical(X[col], categories=categories)
+             codes = cats.codes.astype('int64', copy=True)  # -1 means unknown or missing
+             # Map unknown/missing to the extra "unknown" index len(categories).
+             codes[codes < 0] = len(categories)
+             X_cat_list.append(codes)
+
+         X_cat_np = np.stack(X_cat_list, axis=1)  # shape (N, num_categorical)
+         return X_cat_np
+
+     def _build_train_tensors(self, X_train, y_train, w_train):
+         return self._tensorize_split(X_train, y_train, w_train)
+
+     def _build_val_tensors(self, X_val, y_val, w_val):
+         return self._tensorize_split(X_val, y_val, w_val, allow_none=True)
+
+     def _tensorize_split(self, X, y, w, allow_none: bool = False):
+         if X is None:
+             if allow_none:
+                 return None, None, None, None, False
+             raise ValueError("Input features X must not be None.")
+
+         X_num = torch.tensor(
+             X[self.num_cols].to_numpy(dtype=np.float32, copy=True),
+             dtype=torch.float32
+         )
+         if self.cat_cols:
+             X_cat = torch.tensor(self._encode_cats(X), dtype=torch.long)
+         else:
+             X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)
+
+         y_tensor = torch.tensor(
+             y.values, dtype=torch.float32).view(-1, 1) if y is not None else None
+         if y_tensor is None:
+             w_tensor = None
+         elif w is not None:
+             w_tensor = torch.tensor(
+                 w.values, dtype=torch.float32).view(-1, 1)
+         else:
+             w_tensor = torch.ones_like(y_tensor)
+         return X_num, X_cat, y_tensor, w_tensor, y is not None
+
+     def fit(self, X_train, y_train, w_train=None,
+             X_val=None, y_val=None, w_val=None, trial=None):
+
+         # Build the underlying model on first fit.
+         if self.ft is None:
+             self._build_model(X_train)
+
+         X_num_train, X_cat_train, y_tensor, w_tensor, _ = self._build_train_tensors(
+             X_train, y_train, w_train)
+         X_num_val, X_cat_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
+             X_val, y_val, w_val)
+
+         # --- Build the DataLoader ---
+         dataset = TabularDataset(
+             X_num_train, X_cat_train, y_tensor, w_tensor
+         )
+
+         dataloader, accum_steps = self._build_dataloader(
+             dataset,
+             N=X_num_train.shape[0],
+             base_bs_gpu=(2048, 1024, 512),
+             base_bs_cpu=(256, 128),
+             min_bs=64,
+             target_effective_cuda=4096,
+             target_effective_cpu=2048
+         )
+
+         optimizer = torch.optim.Adam(
+             self.ft.parameters(), lr=self.learning_rate)
+         scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         val_dataloader = None
+         if has_val:
+             val_dataset = TabularDataset(
+                 X_num_val, X_cat_val, y_val_tensor, w_val_tensor
+             )
+             val_bs = accum_steps * dataloader.batch_size
+
+             if os.name == 'nt':
+                 val_workers = 0
+             else:
+                 val_workers = min(4, os.cpu_count() or 1)
+
+             val_dataloader = DataLoader(
+                 val_dataset,
+                 batch_size=val_bs,
+                 shuffle=False,
+                 num_workers=val_workers,
+                 pin_memory=(self.device.type == 'cuda')
+             )
+
+         def forward_fn(batch):
+             X_num_b, X_cat_b, y_b, w_b = batch
+             X_num_b = X_num_b.to(self.device, non_blocking=True)
+             X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+             y_b = y_b.to(self.device, non_blocking=True)
+             w_b = w_b.to(self.device, non_blocking=True)
+             y_pred = self.ft(X_num_b, X_cat_b)
+             return y_pred, y_b, w_b
+
+         def val_forward_fn():
+             total_loss = 0.0
+             total_weight = 0.0
+             for batch in val_dataloader:
+                 X_num_b, X_cat_b, y_b, w_b = batch
+                 X_num_b = X_num_b.to(self.device, non_blocking=True)
+                 X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                 y_b = y_b.to(self.device, non_blocking=True)
+                 w_b = w_b.to(self.device, non_blocking=True)
+
+                 y_pred = self.ft(X_num_b, X_cat_b)
+
+                 # Compute the validation loss by hand, batch by batch.
+                 task = getattr(self, "task_type", "regression")
+                 if task == 'classification':
+                     loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+                     losses = loss_fn(y_pred, y_b).view(-1)
+                 else:
+                     # The model output already passed through Softplus; no need to reapply.
+                     y_pred_clamped = torch.clamp(y_pred, min=1e-6)
+                     power = getattr(self, "tw_power", 1.5)
+                     losses = tweedie_loss(y_pred_clamped, y_b, p=power).view(-1)
+
+                 batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                 batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                 total_loss += batch_weighted_loss_sum.item()
+                 total_weight += batch_weight_sum.item()
+
+             return total_loss / max(total_weight, EPS)
+
+         clip_fn = None
+         if self.device.type == 'cuda':
+             def clip_fn(): return (scaler.unscale_(optimizer),
+                                    clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+         best_state = self._train_model(
+             self.ft,
+             dataloader,
+             accum_steps,
+             optimizer,
+             scaler,
+             forward_fn,
+             val_forward_fn if has_val else None,
+             apply_softplus=False,
+             clip_fn=clip_fn,
+             trial=trial
+         )
+
+         if has_val and best_state is not None:
+             self.ft.load_state_dict(best_state)
+
+     def predict(self, X_test):
+         # X_test must contain all numeric and categorical columns.
+
+         self.ft.eval()
+         X_num, X_cat, _, _, _ = self._tensorize_split(
+             X_test, None, None, allow_none=True)
+
+         with torch.no_grad():
+             X_num = X_num.to(self.device, non_blocking=True)
+             X_cat = X_cat.to(self.device, non_blocking=True)
+             y_pred = self.ft(X_num, X_cat).cpu().numpy()
+
+         if self.task_type == 'classification':
+             # Convert logits to probabilities.
+             y_pred = 1 / (1 + np.exp(-y_pred))
+         else:
+             # The model already ends in Softplus; if needed, a log-exp smoothing
+             # could be applied: y_pred = log(1 + exp(y_pred)).
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.ravel()
+
+     def set_params(self, params: dict):
+
+         # Mirrors the sklearn convention.
+         # Note: structural parameters (e.g. d_model/n_heads) only take effect after a re-fit.
+
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         return self
+
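The reserved unknown/missing bucket in _encode_cats rests on pd.Categorical assigning code -1 to values outside the fixed category set; a standalone sketch:

    categories = pd.Index(['A', 'B', 'C'])  # the category set remembered at fit time
    codes = pd.Categorical(['B', 'D', None],
                           categories=categories).codes.astype('int64', copy=True)
    codes[codes < 0] = len(categories)      # unseen 'D' and missing values -> index 3
    print(codes)                            # [1 3 3]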
+
+ # ===== Core components and training wrappers =================================
+
+ # =============================================================================
+ # Configuration, preprocessing, and trainer base classes
+ # =============================================================================
+ @dataclass
+ class BayesOptConfig:
+     model_nme: str
+     resp_nme: str
+     weight_nme: str
+     factor_nmes: List[str]
+     task_type: str = 'regression'
+     binary_resp_nme: Optional[str] = None
+     cate_list: Optional[List[str]] = None
+     prop_test: float = 0.25
+     rand_seed: Optional[int] = None
+     epochs: int = 100
+     use_gpu: bool = True
+     use_resn_data_parallel: bool = True
+     use_ft_data_parallel: bool = True
+
+
+ class OutputManager:
+     # Centralizes the output paths for results, plots, and models.
+
+     def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
+         self.root = Path(root or os.getcwd())
+         self.model_name = model_name
+         self.plot_dir = self.root / 'plot'
+         self.result_dir = self.root / 'Results'
+         self.model_dir = self.root / 'model'
+
+     def _prepare(self, path: Path) -> str:
+         ensure_parent_dir(str(path))
+         return str(path)
+
+     def plot_path(self, filename: str) -> str:
+         return self._prepare(self.plot_dir / filename)
+
+     def result_path(self, filename: str) -> str:
+         return self._prepare(self.result_dir / filename)
+
+     def model_path(self, filename: str) -> str:
+         return self._prepare(self.model_dir / filename)
+
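Illustrative OutputManager usage (the root path is invented); calling a *_path helper also creates the parent directory via ensure_parent_dir:

    om = OutputManager(root='/tmp/bayesopt_demo', model_name='demo')
    p = om.plot_path('lift.png')
    print(p)  # '/tmp/bayesopt_demo/plot/lift.png'; the plot/ directory now exists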
+
+ class DatasetPreprocessor:
+     # Prepares the shared train/test data views consumed by all trainers.
+
+     def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
+                  config: BayesOptConfig) -> None:
+         self.config = config
+         self.train_data = train_df.copy(deep=True)
+         self.test_data = test_df.copy(deep=True)
+         self.num_features: List[str] = []
+         self.train_oht_scl_data: Optional[pd.DataFrame] = None
+         self.test_oht_scl_data: Optional[pd.DataFrame] = None
+         self.var_nmes: List[str] = []
+         self.cat_categories_for_shap: Dict[str, List[Any]] = {}
+
+     def run(self) -> "DatasetPreprocessor":
+         cfg = self.config
+         # Precompute weighted actuals; plotting and validation both rely on this column.
+         self.train_data.loc[:, 'w_act'] = self.train_data[cfg.resp_nme] * \
+             self.train_data[cfg.weight_nme]
+         self.test_data.loc[:, 'w_act'] = self.test_data[cfg.resp_nme] * \
+             self.test_data[cfg.weight_nme]
+         if cfg.binary_resp_nme:
+             self.train_data.loc[:, 'w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
+                 self.train_data[cfg.weight_nme]
+             self.test_data.loc[:, 'w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
+                 self.test_data[cfg.weight_nme]
+         # Clip at a high quantile to absorb outliers; without it extreme points dominate the loss.
+         q999 = self.train_data[cfg.resp_nme].quantile(0.999)
+         self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
+             upper=q999)
+         cate_list = list(cfg.cate_list or [])
+         if cate_list:
+             for cate in cate_list:
+                 self.train_data[cate] = self.train_data[cate].astype(
+                     'category')
+                 self.test_data[cate] = self.test_data[cate].astype('category')
+                 cats = self.train_data[cate].cat.categories
+                 self.cat_categories_for_shap[cate] = list(cats)
+         self.num_features = [
+             nme for nme in cfg.factor_nmes if nme not in cate_list]
+         train_oht = self.train_data[cfg.factor_nmes +
+                                     [cfg.weight_nme] + [cfg.resp_nme]].copy()
+         test_oht = self.test_data[cfg.factor_nmes +
+                                   [cfg.weight_nme] + [cfg.resp_nme]].copy()
+         train_oht = pd.get_dummies(
+             train_oht,
+             columns=cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         test_oht = pd.get_dummies(
+             test_oht,
+             columns=cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         for num_chr in self.num_features:
+             # Standardize column by column so every feature is on the same scale;
+             # otherwise the neural networks struggle to converge.
+             scaler = StandardScaler()
+             train_oht[num_chr] = scaler.fit_transform(
+                 train_oht[num_chr].values.reshape(-1, 1))
+             test_oht[num_chr] = scaler.transform(
+                 test_oht[num_chr].values.reshape(-1, 1))
+         # Reindex pads dummy columns missing from the test set with zeros,
+         # keeping the test design matrix aligned with the training design.
+         test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
+         self.train_oht_scl_data = train_oht
+         self.test_oht_scl_data = test_oht
+         self.var_nmes = list(
+             set(list(train_oht.columns)) - set([cfg.weight_nme, cfg.resp_nme])
+         )
+         return self
+
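The closing reindex is what keeps the test design matrix aligned with the training design; a standalone sketch of the trick:

    train = pd.DataFrame({'color': ['red', 'blue', 'red']})
    test = pd.DataFrame({'color': ['blue', 'green']})  # 'green' unseen at fit time
    train_oht = pd.get_dummies(train, columns=['color'], drop_first=True, dtype=np.int8)
    test_oht = pd.get_dummies(test, columns=['color'], drop_first=True, dtype=np.int8)
    test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
    print(list(test_oht.columns))  # ['color_red']: unseen dummies dropped, missing ones padded with 0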
+ # =============================================================================
+ # Trainer hierarchy
+ # =============================================================================
+
+
+ class TrainerBase:
+     def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
+         self.ctx = context
+         self.label = label
+         self.model_name_prefix = model_name_prefix
+         self.model = None
+         self.best_params: Optional[Dict[str, Any]] = None
+         self.best_trial = None
+
+     @property
+     def config(self) -> BayesOptConfig:
+         return self.ctx.config
+
+     @property
+     def output(self) -> OutputManager:
+         return self.ctx.output_manager
+
+     def _get_model_filename(self) -> str:
+         ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
+         return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
+
+     def tune(self, max_evals: int, objective_fn=None) -> None:
+         # Generic Optuna tuning loop.
+         if objective_fn is None:
+             # If a subclass supplies no objective_fn, default to cross_val as the objective.
+             objective_fn = self.cross_val
+
+         def objective_wrapper(trial: optuna.trial.Trial) -> float:
+             try:
+                 result = objective_fn(trial)
+             finally:
+                 self._clean_gpu()
+             return result
+
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.ctx.rand_seed)
+         )
+         study.optimize(objective_wrapper, n_trials=max_evals)
+         self.best_params = study.best_params
+         self.best_trial = study.best_trial
+
+         # Persist the best parameters as CSV for reproducibility.
+         params_path = self.output.result_path(
+             f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+         )
+         pd.DataFrame(self.best_params, index=[0]).to_csv(params_path)
+
+     def train(self) -> None:
+         raise NotImplementedError
+
+     def save(self) -> None:
+         if self.model is None:
+             print(f"[save] Warning: No model to save for {self.label}")
+             return
+
+         path = self.output.model_path(self._get_model_filename())
+         if self.label in ['Xgboost', 'GLM']:
+             joblib.dump(self.model, path)
+         else:
+             # A Torch model can be saved as a bare state_dict or as a whole serialized object.
+             # For compatibility with historical behaviour: ResNetTrainer saves the state_dict,
+             # FTTrainer saves the full object.
+             if hasattr(self.model, 'resnet'):  # ResNetSklearn
+                 torch.save(self.model.resnet.state_dict(), path)
+             else:  # FTTransformerSklearn or others
+                 torch.save(self.model, path)
+
+     def load(self) -> None:
+         path = self.output.model_path(self._get_model_filename())
+         if not os.path.exists(path):
+             print(f"[load] Warning: Model file not found: {path}")
+             return
+
+         if self.label in ['Xgboost', 'GLM']:
+             self.model = joblib.load(path)
+         else:
+             # Loading Torch models depends on how they were saved.
+             if self.label == 'ResNet' or self.label == 'ResNetClassifier':
+                 # ResNet needs its skeleton rebuilt first; the structural parameters
+                 # live on ctx, so subclasses handle this case.
+                 pass
+             else:
+                 # The FT-Transformer was serialized whole: load it, then move it to the target device.
+                 loaded = torch.load(path, map_location='cpu')
+                 self._move_to_device(loaded)
+                 self.model = loaded
+
+     def _move_to_device(self, model_obj):
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         if hasattr(model_obj, 'device'):
+             model_obj.device = device
+         if hasattr(model_obj, 'to'):
+             model_obj.to(device)
+         # If the object nests ft/resnet submodules, move those too.
+         if hasattr(model_obj, 'ft'):
+             model_obj.ft.to(device)
+         if hasattr(model_obj, 'resnet'):
+             model_obj.resnet.to(device)
+
+     def _clean_gpu(self):
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Prediction + caching logic.
+     def _predict_and_cache(self,
+                            model,
+                            pred_prefix: str,
+                            use_oht: bool = False,
+                            design_fn=None) -> None:
+         if design_fn:
+             X_train = design_fn(train=True)
+             X_test = design_fn(train=False)
+         elif use_oht:
+             X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
+             X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
+         else:
+             X_train = self.ctx.train_data[self.ctx.factor_nmes]
+             X_test = self.ctx.test_data[self.ctx.factor_nmes]
+
+         preds_train = model.predict(X_train)
+         preds_test = model.predict(X_test)
+
+         self.ctx.train_data[f'pred_{pred_prefix}'] = preds_train
+         self.ctx.test_data[f'pred_{pred_prefix}'] = preds_test
+         self.ctx.train_data[f'w_pred_{pred_prefix}'] = (
+             self.ctx.train_data[f'pred_{pred_prefix}'] *
+             self.ctx.train_data[self.ctx.weight_nme]
+         )
+         self.ctx.test_data[f'w_pred_{pred_prefix}'] = (
+             self.ctx.test_data[f'pred_{pred_prefix}'] *
+             self.ctx.test_data[self.ctx.weight_nme]
+         )
+
+     def _fit_predict_cache(self,
+ def _fit_predict_cache(self,
1498
+ model,
1499
+ X_train,
1500
+ y_train,
1501
+ sample_weight,
1502
+ pred_prefix: str,
1503
+ use_oht: bool = False,
1504
+ design_fn=None,
1505
+ fit_kwargs: Optional[Dict[str, Any]] = None,
1506
+ sample_weight_arg: Optional[str] = 'sample_weight') -> None:
1507
+ fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
1508
+ if sample_weight is not None and sample_weight_arg:
1509
+ fit_kwargs.setdefault(sample_weight_arg, sample_weight)
1510
+ model.fit(X_train, y_train, **fit_kwargs)
1511
+ self.ctx.model_label.append(self.label)
1512
+ self._predict_and_cache(
1513
+ model, pred_prefix, use_oht=use_oht, design_fn=design_fn)
1514
+
1515
+
1516
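`TrainerBase.tune` wraps the standard Optuna minimize loop. Here is that pattern in isolation, with a toy quadratic objective (illustrative only, not the package's objective):

```python
import optuna

def toy_objective(trial: optuna.trial.Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2  # minimum at x = 2

study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),  # seeded, as in tune()
)
study.optimize(toy_objective, n_trials=30)
print(study.best_params, study.best_value)
```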
+ class XGBTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'Xgboost', 'Xgboost')
+         self.model: Optional[xgb.XGBRegressor] = None
+
+     def _build_estimator(self) -> xgb.XGBRegressor:
+         params = dict(
+             objective=self.ctx.obj,
+             random_state=self.ctx.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if self.ctx.use_gpu else 'hist',
+             enable_categorical=True,
+             predictor='gpu_predictor' if self.ctx.use_gpu else 'cpu_predictor'
+         )
+         if self.ctx.use_gpu:
+             params['gpu_id'] = 0
+             print(">>> XGBoost using GPU ID: 0 (Single GPU Mode)")
+         return xgb.XGBRegressor(**params)
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-5, 1e-1, log=True)
+         gamma = trial.suggest_float('gamma', 0, 10000)
+         max_depth = trial.suggest_int('max_depth', 3, 25)
+         n_estimators = trial.suggest_int('n_estimators', 10, 500, step=10)
+         min_child_weight = trial.suggest_int(
+             'min_child_weight', 100, 10000, step=100)
+         reg_alpha = trial.suggest_float('reg_alpha', 1e-10, 1, log=True)
+         reg_lambda = trial.suggest_float('reg_lambda', 1e-10, 1, log=True)
+         if self.ctx.obj == 'reg:tweedie':
+             tweedie_variance_power = trial.suggest_float(
+                 'tweedie_variance_power', 1, 2)
+         elif self.ctx.obj == 'count:poisson':
+             tweedie_variance_power = 1
+         elif self.ctx.obj == 'reg:gamma':
+             tweedie_variance_power = 2
+         else:
+             tweedie_variance_power = 1.5
+         clf = self._build_estimator()
+         params = {
+             'learning_rate': learning_rate,
+             'gamma': gamma,
+             'max_depth': max_depth,
+             'n_estimators': n_estimators,
+             'min_child_weight': min_child_weight,
+             'reg_alpha': reg_alpha,
+             'reg_lambda': reg_lambda
+         }
+         if self.ctx.obj == 'reg:tweedie':
+             params['tweedie_variance_power'] = tweedie_variance_power
+         clf.set_params(**params)
+         n_jobs = 1 if self.ctx.use_gpu else int(1 / self.ctx.prop_test)
+         acc = cross_val_score(
+             clf,
+             self.ctx.train_data[self.ctx.factor_nmes],
+             self.ctx.train_data[self.ctx.resp_nme].values,
+             fit_params=self.ctx.fit_params,
+             cv=self.ctx.cv,
+             scoring=make_scorer(
+                 mean_tweedie_deviance,
+                 power=tweedie_variance_power,
+                 greater_is_better=False),
+             error_score='raise',
+             n_jobs=n_jobs
+         ).mean()
+         return -acc
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best XGBoost parameters.')
+         self.model = self._build_estimator()
+         self.model.set_params(**self.best_params)
+         self._fit_predict_cache(
+             self.model,
+             self.ctx.train_data[self.ctx.factor_nmes],
+             self.ctx.train_data[self.ctx.resp_nme].values,
+             sample_weight=None,
+             pred_prefix='xgb',
+             fit_kwargs=self.ctx.fit_params,
+             sample_weight_arg=None  # sample weights already arrive via fit_kwargs
+         )
+         self.ctx.xgb_best = self.model
+
+
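The scorer wiring in `cross_val` above is a sign-flipped Tweedie deviance: `make_scorer(..., greater_is_better=False)` negates the metric, and `return -acc` flips it back so Optuna minimizes the deviance itself. A sketch of the same call shape on synthetic data (toy data and settings, not the package's):

```python
import numpy as np
import xgboost as xgb
from sklearn.metrics import make_scorer, mean_tweedie_deviance
from sklearn.model_selection import ShuffleSplit, cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = rng.gamma(shape=2.0, scale=np.exp(0.3 * X[:, 0]))  # strictly positive response

scorer = make_scorer(mean_tweedie_deviance, power=1.5, greater_is_better=False)
cv = ShuffleSplit(n_splits=4, test_size=0.25, random_state=0)
model = xgb.XGBRegressor(objective="reg:tweedie", n_estimators=50)

scores = cross_val_score(model, X, y, scoring=scorer, cv=cv)
print(-scores.mean())  # mean deviance; this is the quantity tune() minimizes
```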
+ class GLMTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'GLM', 'GLM')
+         self.model = None
+
+     def _select_family(self, tweedie_power: Optional[float] = None):
+         if self.ctx.task_type == 'classification':
+             return sm.families.Binomial()
+         if self.ctx.obj == 'count:poisson':
+             return sm.families.Poisson()
+         if self.ctx.obj == 'reg:gamma':
+             return sm.families.Gamma()
+         power = tweedie_power if tweedie_power is not None else 1.5
+         return sm.families.Tweedie(var_power=power, link=sm.families.links.log())
+
+     def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
+         # Add an intercept term to the statsmodels design matrix.
+         X = data[self.ctx.var_nmes]
+         return sm.add_constant(X, has_constant='add')
+
+     def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
+         if isinstance(family, sm.families.Poisson):
+             return 1.0
+         if isinstance(family, sm.families.Gamma):
+             return 2.0
+         if isinstance(family, sm.families.Tweedie):
+             return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
+         return 1.5
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         alpha = trial.suggest_float('alpha', 1e-6, 1e2, log=True)
+         l1_ratio = trial.suggest_float('l1_ratio', 0.0, 1.0)
+         tweedie_power = None
+         if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+             tweedie_power = trial.suggest_float('tweedie_power', 1.01, 1.99)
+
+         X_all = self._prepare_design(self.ctx.train_oht_scl_data)
+         y_all = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
+         w_all = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
+
+         scores = []
+         for train_idx, val_idx in self.ctx.cv.split(X_all):
+             X_train, X_val = X_all.iloc[train_idx], X_all.iloc[val_idx]
+             y_train, y_val = y_all.iloc[train_idx], y_all.iloc[val_idx]
+             w_train, w_val = w_all.iloc[train_idx], w_all.iloc[val_idx]
+
+             family = self._select_family(tweedie_power)
+             glm = sm.GLM(y_train, X_train, family=family,
+                          freq_weights=w_train)
+             result = glm.fit_regularized(
+                 alpha=alpha, L1_wt=l1_ratio, maxiter=200)
+
+             y_pred = result.predict(X_val)
+             if self.ctx.task_type == 'classification':
+                 y_pred = np.clip(y_pred, EPS, 1 - EPS)
+                 fold_score = log_loss(
+                     y_val, y_pred, sample_weight=w_val)
+             else:
+                 y_pred = np.maximum(y_pred, EPS)
+                 fold_score = mean_tweedie_deviance(
+                     y_val,
+                     y_pred,
+                     sample_weight=w_val,
+                     power=self._metric_power(family, tweedie_power)
+                 )
+             scores.append(fold_score)
+
+         return float(np.mean(scores))
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best GLM parameters.')
+         tweedie_power = self.best_params.get('tweedie_power')
+         family = self._select_family(tweedie_power)
+
+         X_train = self._prepare_design(self.ctx.train_oht_scl_data)
+         y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
+         w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
+
+         glm = sm.GLM(y_train, X_train, family=family,
+                      freq_weights=w_train)
+         self.model = glm.fit_regularized(
+             alpha=self.best_params['alpha'],
+             L1_wt=self.best_params['l1_ratio'],
+             maxiter=300
+         )
+
+         self.ctx.glm_best = self.model
+         self.ctx.model_label += [self.label]
+         self._predict_and_cache(
+             self.model,
+             'glm',
+             design_fn=lambda train: self._prepare_design(
+                 self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
+             )
+         )
+
+
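A minimal sketch of the elastic-net GLM fit that `GLMTrainer.train` performs, on synthetic data (all names and values here are illustrative):

```python
import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)
X = pd.DataFrame({"x1": rng.normal(size=200), "x2": rng.normal(size=200)})
y = np.exp(0.5 * X["x1"]) * rng.gamma(2.0, 0.5, size=200)  # positive target
w = np.ones(200)

design = sm.add_constant(X, has_constant="add")  # same intercept handling
family = sm.families.Tweedie(var_power=1.5, link=sm.families.links.log())
glm = sm.GLM(y, design, family=family, freq_weights=w)

# alpha = overall penalty strength, L1_wt = L1/L2 mix (elastic net)
result = glm.fit_regularized(alpha=0.01, L1_wt=0.5, maxiter=200)
print(result.predict(design)[:5])
```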
+ class ResNetTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         if context.task_type == 'classification':
+             super().__init__(context, 'ResNetClassifier', 'ResNet')
+         else:
+             super().__init__(context, 'ResNet', 'ResNet')
+         self.model: Optional[ResNetSklearn] = None
+
+     # ========= Cross-validation (used by BayesOpt) =========
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         # Cross-validation for ResNet, with GPU memory as the main concern:
+         # - each fold builds its own ResNetSklearn and frees it right away;
+         # - after a fold, move the model to CPU, delete it, and run gc/empty_cache;
+         # - optionally subsample the training set during BayesOpt to reduce
+         #   memory pressure.
+
+         # 1. Hyperparameter space (largely the earlier settings)
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-6, 1e-2, log=True
+         )
+         # hidden_dim = trial.suggest_int('hidden_dim', 32, 256, step=32)  # keep this modest
+         hidden_dim = trial.suggest_int('hidden_dim', 8, 32, step=2)
+         block_num = trial.suggest_int('block_num', 2, 10)
+         # batch_num = trial.suggest_int(
+         #     'batch_num',
+         #     10 if self.ctx.obj == 'reg:gamma' else 100,
+         #     100 if self.ctx.obj == 'reg:gamma' else 1000,
+         #     step=10 if self.ctx.obj == 'reg:gamma' else 100
+         # )
+
+         if self.ctx.task_type == 'regression':
+             if self.ctx.obj == 'reg:tweedie':
+                 tw_power = trial.suggest_float('tw_power', 1.0, 2.0)
+             elif self.ctx.obj == 'count:poisson':
+                 tw_power = 1.0
+             elif self.ctx.obj == 'reg:gamma':
+                 tw_power = 2.0
+             else:
+                 tw_power = 1.5
+         else:  # classification
+             tw_power = None  # not used
+
+         fold_losses = []
+
+         # 2. (Optional) Run CV on a subsample during BayesOpt to cut GPU
+         #    memory use and wall-clock time.
+         data_for_cv = self.ctx.train_oht_scl_data
+         max_rows_for_resnet_bo = min(100000, int(
+             len(data_for_cv) / 5))  # tune down for a smaller GPU (e.g. an A30): 50_000
+         if len(data_for_cv) > max_rows_for_resnet_bo:
+             data_for_cv = data_for_cv.sample(
+                 max_rows_for_resnet_bo,
+                 random_state=self.ctx.rand_seed
+             )
+
+         X_all = data_for_cv[self.ctx.var_nmes]
+         y_all = data_for_cv[self.ctx.resp_nme]
+         w_all = data_for_cv[self.ctx.weight_nme]
+
+         # Use a local ShuffleSplit so indices stay consistent on the subsample.
+         cv_local = ShuffleSplit(
+             n_splits=int(1 / self.ctx.prop_test),
+             test_size=self.ctx.prop_test,
+             random_state=self.ctx.rand_seed
+         )
+
+         # Use a single hold-out split instead of K-fold CV for speed:
+         # take only the first split.
+         train_idx, val_idx = next(cv_local.split(X_all))
+
+         X_train_fold = X_all.iloc[train_idx]
+         y_train_fold = y_all.iloc[train_idx]
+         w_train_fold = w_all.iloc[train_idx]
+
+         X_val_fold = X_all.iloc[val_idx]
+         y_val_fold = y_all.iloc[val_idx]
+         w_val_fold = w_all.iloc[val_idx]
+
+         # 3. Build the ResNet model
+         cv_net = ResNetSklearn(
+             model_nme=self.ctx.model_nme,
+             input_dim=X_all.shape[1],
+             hidden_dim=hidden_dim,
+             block_num=block_num,
+             task_type=self.ctx.task_type,
+             # batch_num=batch_num,
+             epochs=self.ctx.epochs,
+             tweedie_power=tw_power,
+             learning_rate=learning_rate,
+             patience=5,
+             use_layernorm=True,
+             dropout=0.1,
+             residual_scale=0.1,
+             use_data_parallel=self.ctx.config.use_resn_data_parallel
+         )
+
+         try:
+             # 4. Train
+             cv_net.fit(
+                 X_train_fold,
+                 y_train_fold,
+                 w_train_fold,
+                 X_val_fold,
+                 y_val_fold,
+                 w_val_fold,
+                 trial=trial
+             )
+
+             # 5. Predict on the validation fold
+             y_pred_fold = cv_net.predict(X_val_fold)
+
+             # 6. Evaluate with Tweedie deviance (evaluation only; the
+             #    training loss is left untouched)
+             if self.ctx.task_type == 'regression':
+                 loss = mean_tweedie_deviance(
+                     y_val_fold,
+                     y_pred_fold,
+                     sample_weight=w_val_fold,
+                     power=tw_power
+                 )
+             else:  # classification
+                 from sklearn.metrics import log_loss
+                 loss = log_loss(
+                     y_val_fold,
+                     y_pred_fold,
+                     sample_weight=w_val_fold,
+                 )
+             fold_losses.append(loss)
+         finally:
+             # 7. Release GPU resources when the fold is done
+             try:
+                 if hasattr(cv_net, "resnet"):
+                     cv_net.resnet.to("cpu")
+             except Exception:
+                 pass
+             del cv_net
+             self._clean_gpu()
+
+         return np.mean(fold_losses)
+
+     # ========= Train the final ResNet with the best hyperparameters =========
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best ResNet parameters.')
+
+         self.model = ResNetSklearn(
+             model_nme=self.ctx.model_nme,
+             input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+             task_type=self.ctx.task_type,
+             use_data_parallel=self.ctx.config.use_resn_data_parallel
+         )
+         self.model.set_params(self.best_params)
+
+         self._fit_predict_cache(
+             self.model,
+             self.ctx.train_oht_scl_data[self.ctx.var_nmes],
+             self.ctx.train_oht_scl_data[self.ctx.resp_nme],
+             sample_weight=self.ctx.train_oht_scl_data[self.ctx.weight_nme],
+             pred_prefix='resn',
+             use_oht=True,
+             sample_weight_arg='w_train'
+         )
+
+         # Expose for external callers
+         self.ctx.resn_best = self.model
+
+     # ========= Save / load =========
+     # ResNet is saved as a state_dict and needs bespoke load logic, so load()
+     # stays here; save() is already handled in TrainerBase (it checks for a
+     # .resnet attribute).
+
+     def load(self) -> None:
+         # Load the ResNet weights from disk onto the current device, keeping
+         # them consistent with the context.
+         path = self.output.model_path(self._get_model_filename())
+         if os.path.exists(path):
+             resn_loaded = ResNetSklearn(
+                 model_nme=self.ctx.model_nme,
+                 input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+                 task_type=self.ctx.task_type,
+                 use_data_parallel=self.ctx.config.use_resn_data_parallel
+             )
+             state_dict = torch.load(path, map_location='cpu')
+             resn_loaded.resnet.load_state_dict(state_dict)
+
+             self._move_to_device(resn_loaded)
+             self.model = resn_loaded
+             self.ctx.resn_best = self.model
+         else:
+             print(f"[ResNetTrainer.load] Model file not found: {path}")
+
+
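The state_dict round trip that `save()`/`load()` perform for ResNet, in isolation. A toy `nn.Sequential` stands in for the real ResNet, and `demo.pth` is a hypothetical filename:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
torch.save(net.state_dict(), "demo.pth")  # weights only, as TrainerBase.save does

# load() must rebuild the same skeleton before the weights can be restored
rebuilt = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
rebuilt.load_state_dict(torch.load("demo.pth", map_location="cpu"))
```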
+ class FTTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         if context.task_type == 'classification':
+             super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
+         else:
+             super().__init__(context, 'FTTransformer', 'FTTransformer')
+         self.model: Optional[FTTransformerSklearn] = None
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         # Cross-validation for the FT-Transformer, again focused on GPU memory:
+         # - shrink the search space to rule out needlessly large models;
+         # - free GPU memory right after each fold so the next trial can run.
+         # The hyperparameter space is kept slightly narrow to avoid very
+         # large models.
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-5, 5e-4, log=True
+         )
+         d_model = trial.suggest_int('d_model', 32, 256, step=32)
+         # n_heads = trial.suggest_categorical('n_heads', [2, 4])  # widened to avoid underfitting
+         n_heads = trial.suggest_categorical('n_heads', [2, 4, 8])
+         # n_layers = trial.suggest_int('n_layers', 2, 4)  # widened to avoid underfitting
+         n_layers = trial.suggest_int('n_layers', 2, 8)
+         dropout = trial.suggest_float('dropout', 0.0, 0.2)
+         # batch_num = trial.suggest_int(
+         #     'batch_num',
+         #     5 if self.ctx.obj == 'reg:gamma' else 10,
+         #     10 if self.ctx.obj == 'reg:gamma' else 50,
+         #     step=1 if self.ctx.obj == 'reg:gamma' else 10
+         # )
+
+         if self.ctx.task_type == 'regression':
+             if self.ctx.obj == 'reg:tweedie':
+                 tw_power = trial.suggest_float('tw_power', 1.0, 2.0)
+             elif self.ctx.obj == 'count:poisson':
+                 tw_power = 1.0
+             elif self.ctx.obj == 'reg:gamma':
+                 tw_power = 2.0
+             else:
+                 tw_power = 1.5
+         else:  # classification
+             tw_power = None  # not used
+
+         fold_losses = []
+
+         # Optional: run BO on a subsample so large datasets do not overwhelm
+         # GPU memory.
+         data_for_cv = self.ctx.train_data
+         max_rows_for_ft_bo = min(1000000, int(
+             len(data_for_cv) / 2))  # adjust up or down for the available GPU memory
+         if len(data_for_cv) > max_rows_for_ft_bo:
+             data_for_cv = data_for_cv.sample(
+                 max_rows_for_ft_bo,
+                 random_state=self.ctx.rand_seed
+             )
+
+         # Use a local ShuffleSplit so indices stay consistent on the subsample.
+         cv_local = ShuffleSplit(
+             n_splits=int(1 / self.ctx.prop_test),
+             test_size=self.ctx.prop_test,
+             random_state=self.ctx.rand_seed
+         )
+
+         # Use a single hold-out split instead of K-fold CV for speed:
+         # take only the first split.
+         train_idx, val_idx = next(
+             cv_local.split(data_for_cv[self.ctx.factor_nmes]))
+
+         X_train_fold = data_for_cv.iloc[train_idx][self.ctx.factor_nmes]
+         y_train_fold = data_for_cv.iloc[train_idx][self.ctx.resp_nme]
+         w_train_fold = data_for_cv.iloc[train_idx][self.ctx.weight_nme]
+         X_val_fold = data_for_cv.iloc[val_idx][self.ctx.factor_nmes]
+         y_val_fold = data_for_cv.iloc[val_idx][self.ctx.resp_nme]
+         w_val_fold = data_for_cv.iloc[val_idx][self.ctx.weight_nme]
+
+         cv_ft = FTTransformerSklearn(
+             model_nme=self.ctx.model_nme,
+             num_cols=self.ctx.num_features,
+             cat_cols=self.ctx.cate_list,
+             d_model=d_model,
+             n_heads=n_heads,
+             n_layers=n_layers,
+             dropout=dropout,
+             task_type=self.ctx.task_type,
+             # batch_num=batch_num,
+             epochs=self.ctx.epochs,
+             tweedie_power=tw_power,
+             learning_rate=learning_rate,
+             patience=5,
+             use_data_parallel=self.ctx.config.use_ft_data_parallel
+         )
+
+         try:
+             cv_ft.fit(
+                 X_train_fold, y_train_fold, w_train_fold,
+                 X_val_fold, y_val_fold, w_val_fold,
+                 trial=trial
+             )
+             y_pred_fold = cv_ft.predict(X_val_fold)
+             if self.ctx.task_type == 'regression':
+                 loss = mean_tweedie_deviance(
+                     y_val_fold,
+                     y_pred_fold,
+                     sample_weight=w_val_fold,
+                     power=tw_power
+                 )
+             else:  # classification
+                 from sklearn.metrics import log_loss
+                 loss = log_loss(
+                     y_val_fold,
+                     y_pred_fold,
+                     sample_weight=w_val_fold,
+                 )
+             fold_losses.append(loss)
+         finally:
+             # Release GPU resources immediately afterwards
+             try:
+                 # If the model sits on the GPU, move it back to CPU first
+                 if hasattr(cv_ft, "ft"):
+                     cv_ft.ft.to("cpu")
+             except Exception:
+                 pass
+             del cv_ft
+             self._clean_gpu()
+
+         return np.mean(fold_losses)
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best FT-Transformer parameters.')
+         self.model = FTTransformerSklearn(
+             model_nme=self.ctx.model_nme,
+             num_cols=self.ctx.num_features,
+             cat_cols=self.ctx.cate_list,
+             task_type=self.ctx.task_type,
+             use_data_parallel=self.ctx.config.use_ft_data_parallel
+         )
+         self.model.set_params(self.best_params)
+         self._fit_predict_cache(
+             self.model,
+             self.ctx.train_data[self.ctx.factor_nmes],
+             self.ctx.train_data[self.ctx.resp_nme],
+             sample_weight=self.ctx.train_data[self.ctx.weight_nme],
+             pred_prefix='ft',
+             sample_weight_arg='w_train'
+         )
+         self.ctx.ft_best = self.model
+
+
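Both neural trainers above replace K-fold CV with a single hold-out split by taking only the first split from a `ShuffleSplit`. The pattern in isolation (toy array, illustrative only):

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

X = np.arange(100).reshape(-1, 1)
cv_local = ShuffleSplit(n_splits=4, test_size=0.25, random_state=0)

# next(...) consumes only the first of the four splits, so one model is
# trained per trial instead of four.
train_idx, val_idx = next(cv_local.split(X))
print(len(train_idx), len(val_idx))  # 75 25
```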
+ # =============================================================================
+ # BayesOpt orchestration & SHAP utilities
+ # =============================================================================
+ class BayesOptModel:
+     def __init__(self, train_data, test_data,
+                  model_nme, resp_nme, weight_nme, factor_nmes, task_type='regression',
+                  binary_resp_nme=None,
+                  cate_list=None, prop_test=0.25, rand_seed=None,
+                  epochs=100, use_gpu=True,
+                  use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False):
+         cfg = BayesOptConfig(
+             model_nme=model_nme,
+             task_type=task_type,
+             resp_nme=resp_nme,
+             weight_nme=weight_nme,
+             factor_nmes=list(factor_nmes),
+             binary_resp_nme=binary_resp_nme,
+             cate_list=list(cate_list) if cate_list else None,
+             prop_test=prop_test,
+             rand_seed=rand_seed,
+             epochs=epochs,
+             use_gpu=use_gpu,
+             use_resn_data_parallel=use_resn_data_parallel,
+             use_ft_data_parallel=use_ft_data_parallel
+         )
+         self.config = cfg
+         self.model_nme = cfg.model_nme
+         self.task_type = cfg.task_type
+         self.resp_nme = cfg.resp_nme
+         self.weight_nme = cfg.weight_nme
+         self.factor_nmes = cfg.factor_nmes
+         self.binary_resp_nme = cfg.binary_resp_nme
+         self.cate_list = list(cfg.cate_list or [])
+         self.prop_test = cfg.prop_test
+         self.epochs = cfg.epochs
+         self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
+             1, 10000)
+         self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
+         self.output_manager = OutputManager(os.getcwd(), self.model_nme)
+
+         preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
+         self.train_data = preprocessor.train_data
+         self.test_data = preprocessor.test_data
+         self.train_oht_scl_data = preprocessor.train_oht_scl_data
+         self.test_oht_scl_data = preprocessor.test_oht_scl_data
+         self.var_nmes = preprocessor.var_nmes
+         self.num_features = preprocessor.num_features
+         self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
+
+         self.cv = ShuffleSplit(n_splits=int(1 / self.prop_test),
+                                test_size=self.prop_test,
+                                random_state=self.rand_seed)
+         if self.task_type == 'classification':
+             self.obj = 'binary:logistic'
+         else:  # regression: infer the objective from substrings of the model name
+             if 'f' in self.model_nme:
+                 self.obj = 'count:poisson'
+             elif 's' in self.model_nme:
+                 self.obj = 'reg:gamma'
+             elif 'bc' in self.model_nme:
+                 self.obj = 'reg:tweedie'
+             else:
+                 self.obj = 'reg:tweedie'
+         self.fit_params = {
+             'sample_weight': self.train_data[self.weight_nme].values
+         }
+         self.model_label: List[str] = []
+
+         # Register every trainer; they are accessed by key from here on,
+         # which makes adding new models straightforward.
+         self.trainers: Dict[str, TrainerBase] = {
+             'glm': GLMTrainer(self),
+             'xgb': XGBTrainer(self),
+             'resn': ResNetTrainer(self),
+             'ft': FTTrainer(self)
+         }
+         self.xgb_best = None
+         self.resn_best = None
+         self.glm_best = None
+         self.ft_best = None
+         self.best_xgb_params = None
+         self.best_resn_params = None
+         self.best_ft_params = None
+         self.best_xgb_trial = None
+         self.best_resn_trial = None
+         self.best_ft_trial = None
+         self.best_glm_params = None
+         self.best_glm_trial = None
+         self.xgb_load = None
+         self.resn_load = None
+         self.ft_load = None
+
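A hypothetical usage sketch of the constructor above; `train_df`/`test_df` and all column names are invented for illustration:

```python
# train_df / test_df: pandas DataFrames prepared upstream (hypothetical)
model = BayesOptModel(
    train_data=train_df,
    test_data=test_df,
    model_nme="bc_demo",           # contains 'bc' -> reg:tweedie objective
    resp_nme="loss_cost",
    weight_nme="exposure",
    factor_nmes=["age", "region", "veh_value"],
    cate_list=["region"],
    task_type="regression",
)
model.bayesopt_xgb(max_evals=50)   # tune + train; result cached as model.xgb_best
model.plot_lift("Xgboost", "pred_xgb")
```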
+     # One-way plotting helper
+     def plot_oneway(self, n_bins=10):
+         for c in self.factor_nmes:
+             fig = plt.figure(figsize=(7, 5))
+             if c in self.cate_list:
+                 group_col = c
+                 plot_source = self.train_data
+             else:
+                 group_col = f'{c}_bins'
+                 bins = pd.qcut(
+                     self.train_data[c],
+                     n_bins,
+                     duplicates='drop'  # duplicate quantile edges drop bins instead of raising
+                 )
+                 plot_source = self.train_data.assign(**{group_col: bins})
+             plot_data = plot_source.groupby(
+                 [group_col], observed=True).sum(numeric_only=True)
+             plot_data.reset_index(inplace=True)
+             plot_data['act_v'] = plot_data['w_act'] / \
+                 plot_data[self.weight_nme]
+             ax = fig.add_subplot(111)
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.set_title(
+                 'Analysis of %s : Train Data' % group_col,
+                 fontsize=8)
+             plt.xticks(plot_data.index,
+                        list(plot_data[group_col].astype(str)),
+                        rotation=90)
+             if len(list(plot_data[group_col].astype(str))) > 50:
+                 plt.xticks(fontsize=3)
+             else:
+                 plt.xticks(fontsize=6)
+             plt.yticks(fontsize=6)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index,
+                     plot_data[self.weight_nme],
+                     alpha=0.5, color='seagreen')
+             plt.yticks(fontsize=6)
+             plt.margins(0.05)
+             plt.subplots_adjust(wspace=0.3)
+             save_path = self.output_manager.plot_path(
+                 f'00_{self.model_nme}_{group_col}_oneway.png')
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+
+     # Generic optimization entry point
+     def optimize_model(self, model_key: str, max_evals: int = 100):
+         if model_key not in self.trainers:
+             print(f"Warning: Unknown model key: {model_key}")
+             return
+
+         trainer = self.trainers[model_key]
+         trainer.tune(max_evals)
+         trainer.train()
+
+         # Update context attributes for backward compatibility
+         setattr(self, f"{model_key}_best", trainer.model)
+         setattr(self, f"best_{model_key}_params", trainer.best_params)
+         setattr(self, f"best_{model_key}_trial", trainer.best_trial)
+
+     # GLM Bayesian-optimization wrapper
+     def bayesopt_glm(self, max_evals=50):
+         self.optimize_model('glm', max_evals)
+
+     # XGBoost Bayesian-optimization wrapper
+     def bayesopt_xgb(self, max_evals=100):
+         self.optimize_model('xgb', max_evals)
+
+     # ResNet Bayesian-optimization wrapper
+     def bayesopt_resnet(self, max_evals=100):
+         self.optimize_model('resn', max_evals)
+
+     # FT-Transformer Bayesian-optimization wrapper
+     def bayesopt_ft(self, max_evals=50):
+         self.optimize_model('ft', max_evals)
+
+     # Lift-curve plot
+     def plot_lift(self, model_label, pred_nme, n_bins=10):
+         model_map = {
+             'Xgboost': 'pred_xgb',
+             'ResNet': 'pred_resn',
+             'ResNetClassifier': 'pred_resn',
+             'FTTransformer': 'pred_ft',
+             'FTTransformerClassifier': 'pred_ft',
+             'GLM': 'pred_glm'
+         }
+         for k, v in model_map.items():
+             if model_label.startswith(k):
+                 pred_nme = v
+                 break
+
+         fig = plt.figure(figsize=(11, 5))
+         for pos, (title, data) in zip([121, 122],
+                                       [('Lift Chart on Train Data', self.train_data),
+                                        ('Lift Chart on Test Data', self.test_data)]):
+             lift_df = pd.DataFrame({
+                 'pred': data[pred_nme].values,
+                 'w_pred': data[f'w_{pred_nme}'].values,
+                 'act': data['w_act'].values,
+                 'weight': data[self.weight_nme].values
+             })
+             plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
+             denom = np.maximum(plot_data['weight'], EPS)
+             plot_data['exp_v'] = plot_data['w_pred'] / denom
+             plot_data['act_v'] = plot_data['act'] / denom
+             plot_data = plot_data.reset_index()
+
+             ax = fig.add_subplot(pos)
+             PlotUtils.plot_lift_ax(ax, plot_data, title)
+
+         plt.subplots_adjust(wspace=0.3)
+         save_path = self.output_manager.plot_path(
+             f'01_{self.model_nme}_{model_label}_lift.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Double-lift plot
+     def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
+         # Plot a double lift chart comparing two models across bins.
+         # Args:
+         #     model_comp: short names of the two models to compare
+         #         (e.g. ['xgb', 'resn']; 'xgb'/'resn'/'ft'/'glm' are supported).
+         #     n_bins: number of bins, controlling the granularity of the curves.
+         if len(model_comp) != 2:
+             raise ValueError("`model_comp` must contain exactly two models.")
+
+         model_name_map = {
+             'xgb': 'Xgboost',
+             'resn': 'ResNet',
+             'ft': 'FTTransformer',
+             'glm': 'GLM'
+         }
+
+         name1, name2 = model_comp
+         if name1 not in model_name_map or name2 not in model_name_map:
+             raise ValueError(
+                 f"Unsupported model key. Choose from {list(model_name_map.keys())}.")
+
+         fig, axes = plt.subplots(1, 2, figsize=(11, 5))
+         datasets = {
+             'Train Data': self.train_data,
+             'Test Data': self.test_data
+         }
+
+         for ax, (data_name, data) in zip(axes, datasets.items()):
+             pred1_col = f'w_pred_{name1}'
+             pred2_col = f'w_pred_{name2}'
+
+             if pred1_col not in data.columns or pred2_col not in data.columns:
+                 print(
+                     f"Warning: prediction column {pred1_col} or {pred2_col} "
+                     f"not found in {data_name}. Skipping plot.")
+                 continue
+
+             lift_data = pd.DataFrame({
+                 'pred1': data[pred1_col].values,
+                 'pred2': data[pred2_col].values,
+                 'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
+                 'act': data['w_act'].values,
+                 'weight': data[self.weight_nme].values
+             })
+             plot_data = PlotUtils.split_data(
+                 lift_data, 'diff_ly', 'weight', n_bins)
+             denom = np.maximum(plot_data['act'], EPS)
+             plot_data['exp_v1'] = plot_data['pred1'] / denom
+             plot_data['exp_v2'] = plot_data['pred2'] / denom
+             plot_data['act_v'] = plot_data['act'] / denom
+             plot_data.reset_index(inplace=True)
+
+             label1 = model_name_map[name1]
+             label2 = model_name_map[name2]
+
+             PlotUtils.plot_dlift_ax(
+                 ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)
+
+         plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
+         save_path = self.output_manager.plot_path(
+             f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Conversion-rate lift plot
+     def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
+         if not self.binary_resp_nme:
+             print("Error: `binary_resp_nme` was not provided when BayesOptModel "
+                   "was initialized; the conversion-rate curve cannot be drawn.")
+             return
+
+         fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
+         datasets = {
+             'Train Data': self.train_data,
+             'Test Data': self.test_data
+         }
+
+         for ax, (data_name, data) in zip(axes, datasets.items()):
+             if model_pred_col not in data.columns:
+                 print(f"Warning: prediction column '{model_pred_col}' not found "
+                       f"in {data_name}. Skipping plot.")
+                 continue
+
+             # Sort by model score and bin by cumulative weight
+             plot_data = data.sort_values(by=model_pred_col).copy()
+             plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
+             total_weight = plot_data[self.weight_nme].sum()
+
+             if total_weight > EPS:
+                 plot_data['bin'] = pd.cut(
+                     plot_data['cum_weight'],
+                     bins=n_bins,
+                     labels=False,
+                     right=False
+                 )
+             else:
+                 plot_data['bin'] = 0
+
+             # Aggregate by bin
+             lift_agg = plot_data.groupby('bin').agg(
+                 total_weight=(self.weight_nme, 'sum'),
+                 actual_conversions=(self.binary_resp_nme, 'sum'),
+                 weighted_conversions=('w_binary_act', 'sum'),
+                 avg_pred=(model_pred_col, 'mean')
+             ).reset_index()
+
+             # Conversion rate per bin
+             lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
+                 lift_agg['total_weight']
+
+             # Overall average conversion rate
+             overall_conversion_rate = data['w_binary_act'].sum(
+             ) / data[self.weight_nme].sum()
+             ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
+                        label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')
+
+             ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
+                     marker='o', linestyle='-', label='Actual Conversion Rate')
+             ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
+             ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
+             ax.set_ylabel('Conversion Rate')
+             ax.grid(True, linestyle='--', alpha=0.6)
+             ax.legend()
+
+         plt.tight_layout()
+         plt.show()
+
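The binning above is exposure-weighted rather than row-weighted: cumulative weight is cut into equal-width slices, so each bin carries a similar total exposure. A standalone sketch with toy data:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({"score": rng.random(1000),
                   "exposure": rng.uniform(0.5, 1.5, 1000)})

df = df.sort_values("score")
df["cum_w"] = df["exposure"].cumsum()
df["bin"] = pd.cut(df["cum_w"], bins=10, labels=False, right=False)
print(df.groupby("bin")["exposure"].sum())  # roughly equal totals per bin
```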
+     # Save / load models
+     def save_model(self, model_name=None):
+         keys = [model_name] if model_name else self.trainers.keys()
+         for key in keys:
+             if key in self.trainers:
+                 self.trainers[key].save()
+             else:
+                 if model_name:  # only warn when a specific model was requested
+                     print(f"[save_model] Warning: Unknown model key {key}")
+
+     def load_model(self, model_name=None):
+         keys = [model_name] if model_name else self.trainers.keys()
+         for key in keys:
+             if key in self.trainers:
+                 self.trainers[key].load()
+                 # Update context attributes
+                 trainer = self.trainers[key]
+                 if trainer.model is not None:
+                     setattr(self, f"{key}_best", trainer.model)
+                     # Also refresh xxx_load for backward compatibility; the
+                     # original code had xgb_load, resn_load and ft_load but
+                     # no glm_load.
+                     if key in ['xgb', 'resn', 'ft']:
+                         setattr(self, f"{key}_load", trainer.model)
+             else:
+                 if model_name:
+                     print(f"[load_model] Warning: Unknown model key {key}")
+
+     def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
+         if len(data) == 0:
+             return data
+         return data.sample(min(len(data), n), random_state=self.rand_seed)
+
+     @staticmethod
+     def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
+         min_needed = arr.shape[1] + 2
+         return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
+
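The arithmetic in `_shap_nsamples`, worked through: the helper enforces a floor of `n_features + 2` evaluations and caps the budget at `min(300, rows * cols)`. For a 200-row, 25-feature explain set:

```python
import numpy as np

arr = np.zeros((200, 25))                        # 200 rows to explain, 25 features
min_needed = arr.shape[1] + 2                    # 27
capped = min(300, arr.shape[0] * arr.shape[1])   # min(300, 5000) = 300
print(max(min_needed, capped))                   # 300
```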
+     def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
+         # Convert the raw feature DataFrame (containing self.factor_nmes)
+         # into a purely numeric matrix: numeric columns as float64,
+         # categorical columns as integer codes (stored as float64).
+         # Column order follows self.factor_nmes.
+         matrices = []
+
+         for col in self.factor_nmes:
+             s = data[col]
+
+             if col in self.cate_list:
+                 # Categorical column: encode against the full training-time
+                 # category set
+                 cats = pd.Categorical(
+                     s,
+                     categories=self.cat_categories_for_shap[col]
+                 )
+                 # cats.codes is an Index / ndarray; wrap it in np.asarray,
+                 # then reshape
+                 codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
+                 matrices.append(codes)
+             else:
+                 # Numeric column: Series -> numpy -> reshape
+                 vals = pd.to_numeric(s, errors="coerce")
+                 arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
+                 matrices.append(arr)
+
+         X_mat = np.concatenate(matrices, axis=1)  # (N, F)
+         return X_mat
+
+     def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
+         # Decode the SHAP numeric matrix (N, F) back into a raw feature
+         # DataFrame: numeric columns as float, categorical columns restored
+         # to pandas category dtype, to stay compatible with XGBoost
+         # (enable_categorical=True) and the FT-Transformer inputs.
+         # Column order = self.factor_nmes
+         data_dict = {}
+
+         for j, col in enumerate(self.factor_nmes):
+             col_vals = X_mat[:, j]
+
+             if col in self.cate_list:
+                 cats = self.cat_categories_for_shap[col]
+
+                 # SHAP perturbs the codes into fractions; round back to
+                 # integer codes
+                 codes = np.round(col_vals).astype(int)
+                 # Clamp to [-1, len(cats) - 1]
+                 codes = np.clip(codes, -1, len(cats) - 1)
+
+                 # pandas.Categorical.from_codes:
+                 # - code -1 is treated as missing (NaN)
+                 # - other codes map to the matching category in cats
+                 cat_series = pd.Categorical.from_codes(
+                     codes,
+                     categories=cats
+                 )
+                 # Store the Categorical itself, not object dtype
+                 data_dict[col] = cat_series
+             else:
+                 # Numeric column: plain float
+                 data_dict[col] = col_vals.astype(float)
+
+         df = pd.DataFrame(data_dict, columns=self.factor_nmes)
+
+         # Belt and braces: make sure every categorical column really has
+         # category dtype
+         for col in self.cate_list:
+             if col in df.columns:
+                 df[col] = df[col].astype("category")
+         return df
+
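A round-trip sketch of the encode/decode pair above: categories become float codes for the SHAP explainer, and `Categorical.from_codes` maps them back, with `-1` restored as NaN (toy category set, illustrative only):

```python
import numpy as np
import pandas as pd

cats = ["north", "south", "west"]                  # training-time category set
s = pd.Series(["south", "north", "east"])          # "east" is unseen -> code -1

codes = pd.Categorical(s, categories=cats).codes   # [1, 0, -1]
perturbed = codes.astype(np.float64) + 0.3         # SHAP-style perturbation
restored = np.clip(np.round(perturbed).astype(int), -1, len(cats) - 1)
print(pd.Categorical.from_codes(restored, categories=cats))
# ['south', 'north', NaN]
```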
+     def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
+         # Match the GLM training stage: add an intercept on top of the
+         # one-hot + standardized features
+         X = data[self.var_nmes]
+         return sm.add_constant(X, has_constant='add')
+
+     def _compute_shap_core(self,
+                            model_key: str,
+                            n_background: int,
+                            n_samples: int,
+                            on_train: bool,
+                            X_df: pd.DataFrame,
+                            prep_fn,
+                            predict_fn,
+                            cleanup_fn=None):
+         # Core SHAP computation shared by all models: sample the background
+         # set, build the explainer, and return the results.
+         if model_key not in self.trainers or self.trainers[model_key].model is None:
+             raise RuntimeError(f"Model {model_key} not trained.")
+
+         if cleanup_fn:
+             cleanup_fn()
+
+         # Background
+         bg_df = self._sample_rows(X_df, n_background)
+         bg_mat = prep_fn(bg_df)
+
+         # Explainer
+         explainer = shap.KernelExplainer(predict_fn, bg_mat)
+
+         # Explain data
+         ex_df = self._sample_rows(X_df, n_samples)
+         ex_mat = prep_fn(ex_df)
+
+         nsample_eff = self._shap_nsamples(ex_mat)
+         shap_values = explainer.shap_values(ex_mat, nsamples=nsample_eff)
+
+         # Base value
+         bg_pred = predict_fn(bg_mat)
+         base_value = float(np.asarray(bg_pred).mean())
+
+         return {
+             "explainer": explainer,
+             "X_explain": ex_df,
+             "shap_values": shap_values,
+             "base_value": base_value
+         }
+
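A minimal `KernelExplainer` sketch mirroring `_compute_shap_core` on a toy linear model (the real code wraps the trained models behind `predict_fn`; everything here is illustrative):

```python
import numpy as np
import shap

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))    # plays the role of bg_mat
explain_rows = rng.normal(size=(10, 3))  # plays the role of ex_mat

def predict_fn(x: np.ndarray) -> np.ndarray:
    return x @ np.array([0.5, -1.0, 2.0])  # stand-in for model.predict

explainer = shap.KernelExplainer(predict_fn, background)
shap_values = explainer.shap_values(explain_rows, nsamples=100)
base_value = float(predict_fn(background).mean())
print(np.asarray(shap_values).shape, base_value)
```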
+     # ========= XGBoost SHAP =========
+     def compute_shap_xgb(self, n_background: int = 500,
+                          n_samples: int = 200,
+                          on_train: bool = True):
+         data = self.train_data if on_train else self.test_data
+         X_raw = data[self.factor_nmes]
+
+         def predict_wrapper(x_mat):
+             df_input = self._decode_ft_shap_matrix_to_df(x_mat)
+             return self.xgb_best.predict(df_input)
+
+         self.shap_xgb = self._compute_shap_core(
+             'xgb', n_background, n_samples, on_train,
+             X_df=X_raw,
+             prep_fn=lambda df: self._build_ft_shap_matrix(df).astype(np.float64),
+             predict_fn=predict_wrapper
+         )
+         return self.shap_xgb
+
+     # ========= ResNet SHAP =========
+     def _resn_predict_wrapper(self, X_np):
+         # Force the forward pass onto the CPU
+         model = self.resn_best.resnet.to("cpu")
+         with torch.no_grad():
+             X_tensor = torch.tensor(X_np, dtype=torch.float32)
+             y_pred = model(X_tensor).cpu().numpy()
+         y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.reshape(-1)
+
+     def compute_shap_resn(self, n_background: int = 500,
+                           n_samples: int = 200,
+                           on_train: bool = True):
+         data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+         X = data[self.var_nmes]
+
+         def cleanup():
+             self.resn_best.device = torch.device("cpu")
+             self.resn_best.resnet.to("cpu")
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         self.shap_resn = self._compute_shap_core(
+             'resn', n_background, n_samples, on_train,
+             X_df=X,
+             prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+             predict_fn=lambda x: self._resn_predict_wrapper(x),
+             cleanup_fn=cleanup
+         )
+         return self.shap_resn
+
+     # ========= FT-Transformer SHAP =========
+     def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
+         df_input = self._decode_ft_shap_matrix_to_df(X_mat)
+         y_pred = self.ft_best.predict(df_input)
+         return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+     def compute_shap_ft(self, n_background: int = 500,
+                         n_samples: int = 200,
+                         on_train: bool = True):
+         data = self.train_data if on_train else self.test_data
+         X_raw = data[self.factor_nmes]
+
+         def cleanup():
+             self.ft_best.device = torch.device("cpu")
+             self.ft_best.ft.to("cpu")
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         self.shap_ft = self._compute_shap_core(
+             'ft', n_background, n_samples, on_train,
+             X_df=X_raw,
+             prep_fn=lambda df: self._build_ft_shap_matrix(df).astype(np.float64),
+             predict_fn=self._ft_shap_predict_wrapper,
+             cleanup_fn=cleanup
+         )
+         return self.shap_ft
+
+     # ========= GLM SHAP =========
+     def compute_shap_glm(self, n_background: int = 500,
+                          n_samples: int = 200,
+                          on_train: bool = True):
+         data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+         design_all = self._build_glm_design(data)
+         design_cols = list(design_all.columns)
+
+         def predict_wrapper(x_np):
+             x_df = pd.DataFrame(x_np, columns=design_cols)
+             y_pred = self.glm_best.predict(x_df)
+             return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+         res = self._compute_shap_core(
+             'glm', n_background, n_samples, on_train,
+             X_df=design_all,
+             prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+             predict_fn=predict_wrapper
+         )
+         res['design_columns'] = design_cols
+         self.shap_glm = res
+         return self.shap_glm