ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
user_packages legacy/Try/BayesOpt lagecy251218.py (new file)
@@ -0,0 +1,3992 @@
from sklearn.metrics import log_loss, make_scorer, mean_tweedie_deviance
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import ShuffleSplit, cross_val_score  # 1.2.2
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, TensorDataset, DistributedSampler
import xgboost as xgb  # 1.7.0
import torch.nn.functional as F
import torch.nn as nn
import torch  # version: 1.10.1+cu111
import statsmodels.api as sm
import shap
import pandas as pd  # 2.2.3
import optuna  # 4.3.0
import numpy as np  # 1.26.2
try:
    from torch_geometric.nn import knn_graph
    from torch_geometric.utils import add_self_loops, to_undirected
    _PYG_AVAILABLE = True
except Exception:
    knn_graph = None  # type: ignore
    add_self_loops = None  # type: ignore
    to_undirected = None  # type: ignore
    _PYG_AVAILABLE = False
try:
    import pynndescent
    _PYNN_AVAILABLE = True
except Exception:
    pynndescent = None  # type: ignore
    _PYNN_AVAILABLE = False
import matplotlib
import joblib
import csv
import json
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from pathlib import Path
from dataclasses import dataclass, asdict
from contextlib import nullcontext
from datetime import datetime
import math
import gc
import copy

if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
    matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Reduce fragmentation in the PyTorch CUDA allocator by limiting how large
# a memory block it will split; this eases GPU OOM pressure.
# To customize, override the default via the PYTORCH_CUDA_ALLOC_CONF env var.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")


# =============================================================================
# Constants and shared utilities
# =============================================================================
torch.backends.cudnn.benchmark = True
EPS = 1e-8


class IOUtils:
    # Small helpers for files and paths.

    @staticmethod
    def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
        with open(file_path, mode='r', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            return [
                dict(filter(lambda item: item[0] != '', row.items()))
                for row in reader
            ]

    @staticmethod
    def ensure_parent_dir(file_path: str) -> None:
        # Create the target file's parent directory if it does not exist.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)


class TrainingUtils:
    # Small helper functions used throughout training.

    @staticmethod
    def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
        estimated = int((learning_rate / 1e-4) ** 0.5 *
                        (data_size / max(batch_num, 1)))
        return max(1, min(data_size, max(minimum, estimated)))
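    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file): how
    # the batch-size heuristic behaves. The estimate grows with
    # sqrt(learning_rate / 1e-4) and with data_size / batch_num, then is
    # clamped into [minimum, data_size]. The numbers below are hypothetical.
    #
    # >>> TrainingUtils.compute_batch_size(
    # ...     data_size=100_000, learning_rate=1e-3, batch_num=100, minimum=64)
    # 3162    # int(sqrt(10) * 1000), well inside [64, 100_000]
    # -------------------------------------------------------------------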

    @staticmethod
    def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
        # Clamp predictions to positive values first for numerical stability.
        pred_clamped = torch.clamp(pred, min=eps)
        if p == 1:
            term1 = target * torch.log(target / pred_clamped + eps)  # Poisson
            term2 = -target + pred_clamped
            term3 = 0
        elif p == 0:
            term1 = 0.5 * torch.pow(target - pred_clamped, 2)  # Gaussian
            term2 = 0
            term3 = 0
        elif p == 2:
            term1 = torch.log(pred_clamped / target + eps)  # Gamma
            term2 = -target / pred_clamped + 1
            term3 = 0
        else:
            term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
            term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
            term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
        return torch.nan_to_num(  # Tweedie negative log-likelihood (constants dropped)
            2 * (term1 - term2 + term3),
            nan=eps,
            posinf=max_clip,
            neginf=-max_clip
        )
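    # -------------------------------------------------------------------
    # Note (editor's sketch, not part of the packaged file): for p outside
    # {0, 1, 2} the branch above computes the unit Tweedie deviance
    #
    #     d(y, mu) = 2 * ( y^(2-p) / ((1-p)(2-p))
    #                      - y * mu^(1-p) / (1-p)
    #                      + mu^(2-p) / (2-p) )
    #
    # i.e. the same per-sample quantity behind sklearn's
    # mean_tweedie_deviance, which is already imported above. A quick
    # numerical cross-check:
    #
    # >>> mu = torch.tensor([1.5, 1.5]); y = torch.tensor([1.0, 2.0])
    # >>> TrainingUtils.tweedie_loss(mu, y, p=1.5).mean()
    # >>> mean_tweedie_deviance(y.numpy(), mu.numpy(), power=1.5)  # ~equal up to eps
    # -------------------------------------------------------------------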

    @staticmethod
    def free_cuda() -> None:
        print(">>> Moving all models to CPU...")
        for obj in gc.get_objects():
            try:
                if hasattr(obj, "to") and callable(obj.to):
                    obj.to("cpu")
            except Exception:
                pass

        print(">>> Deleting tensors, optimizers, dataloaders...")
        gc.collect()

        print(">>> Emptying CUDA cache...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print(">>> CUDA memory freed.")


class DistributedUtils:
    _cached_state: Optional[tuple] = None

    @staticmethod
    def setup_ddp():
        """Initialize the DDP process group for distributed training."""
        if dist.is_initialized():
            if DistributedUtils._cached_state is None:
                rank = dist.get_rank()
                world_size = dist.get_world_size()
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                DistributedUtils._cached_state = (
                    True,
                    local_rank,
                    rank,
                    world_size,
                )
            return DistributedUtils._cached_state

        if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            local_rank = int(os.environ["LOCAL_RANK"])

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            dist.init_process_group(backend="nccl", init_method="env://")
            print(
                f">>> DDP Initialized: Rank {rank}/{world_size}, Local Rank {local_rank}")
            DistributedUtils._cached_state = (
                True,
                local_rank,
                rank,
                world_size,
            )
            return DistributedUtils._cached_state
        else:
            print(
                f">>> DDP Setup Failed: RANK or WORLD_SIZE not found in env. Keys found: {list(os.environ.keys())}")
            return False, 0, 0, 1

    @staticmethod
    def cleanup_ddp():
        """Destroy the DDP process group and clear the cached local state."""
        if dist.is_initialized():
            dist.destroy_process_group()
        DistributedUtils._cached_state = None

    @staticmethod
    def is_main_process():
        return not dist.is_initialized() or dist.get_rank() == 0

    @staticmethod
    def world_size() -> int:
        return dist.get_world_size() if dist.is_initialized() else 1
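
    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file):
    # setup_ddp() only reads RANK / WORLD_SIZE / LOCAL_RANK, which torchrun
    # populates. A typical single-node, 4-GPU launch of a hypothetical
    # entry script would be:
    #
    #   torchrun --standalone --nproc_per_node=4 train_script.py
    #
    # and inside the script:
    #
    # >>> is_ddp, local_rank, rank, world = DistributedUtils.setup_ddp()
    # >>> if DistributedUtils.is_main_process():
    # ...     print(f"world size = {world}")
    # -------------------------------------------------------------------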


class PlotUtils:
    # Plotting helpers shared across model types.

    @staticmethod
    def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
        data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
        data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
        w_sum = data_sorted[wgt_nme].sum()
        if w_sum <= EPS:
            data_sorted.loc[:, 'bins'] = 0
        else:
            data_sorted.loc[:, 'bins'] = np.floor(
                data_sorted['cum_weight'] * float(n_bins) / w_sum
            )
            data_sorted.loc[(data_sorted['bins'] == n_bins),
                            'bins'] = n_bins - 1
        return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)
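    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file):
    # split_data bins rows by cumulative weight, so each bin carries
    # roughly equal total weight rather than an equal row count.
    #
    # >>> df = pd.DataFrame({'pred': np.arange(10) / 10.0,
    # ...                    'wgt':  np.ones(10)})
    # >>> PlotUtils.split_data(df, 'pred', 'wgt', n_bins=5)['wgt'].tolist()
    # [1.0, 2.0, 2.0, 2.0, 3.0]
    # -------------------------------------------------------------------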

    @staticmethod
    def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
        ax.plot(plot_data.index, plot_data['act_v'],
                label=act_label, color='red')
        ax.plot(plot_data.index, plot_data['exp_v'],
                label=pred_label, color='blue')
        ax.set_title(title, fontsize=8)
        ax.set_xticks(plot_data.index)
        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
        ax.tick_params(axis='y', labelsize=6)
        ax.legend(loc='upper left', fontsize=5, frameon=False)
        ax.margins(0.05)
        ax2 = ax.twinx()
        ax2.bar(plot_data.index, plot_data['weight'],
                alpha=0.5, color='seagreen',
                label=weight_label)
        ax2.tick_params(axis='y', labelsize=6)
        ax2.legend(loc='upper right', fontsize=5, frameon=False)

    @staticmethod
    def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
        ax.plot(plot_data.index, plot_data['act_v'],
                label=act_label, color='red')
        ax.plot(plot_data.index, plot_data['exp_v1'],
                label=label1, color='blue')
        ax.plot(plot_data.index, plot_data['exp_v2'],
                label=label2, color='black')
        ax.set_title(title, fontsize=8)
        ax.set_xticks(plot_data.index)
        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
        ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
        ax.tick_params(axis='y', labelsize=6)
        ax.legend(loc='upper left', fontsize=5, frameon=False)
        ax.margins(0.1)
        ax2 = ax.twinx()
        ax2.bar(plot_data.index, plot_data['weight'],
                alpha=0.5, color='seagreen',
                label=weight_label)
        ax2.tick_params(axis='y', labelsize=6)
        ax2.legend(loc='upper right', fontsize=5, frameon=False)

    @staticmethod
    def plot_lift_list(pred_model, w_pred_list, w_act_list,
                       weight_list, tgt_nme, n_bins: int = 10,
                       fig_nme: str = 'Lift Chart'):
        lift_data = pd.DataFrame()
        lift_data.loc[:, 'pred'] = pred_model
        lift_data.loc[:, 'w_pred'] = w_pred_list
        lift_data.loc[:, 'act'] = w_act_list
        lift_data.loc[:, 'weight'] = weight_list
        plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
        plot_data.reset_index(inplace=True)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)
        PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
        plt.subplots_adjust(wspace=0.3)

        save_path = os.path.join(
            os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
        IOUtils.ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300)
        plt.close(fig)

    @staticmethod
    def plot_dlift_list(pred_model_1, pred_model_2,
                        model_nme_1, model_nme_2,
                        tgt_nme,
                        w_list, w_act_list, n_bins: int = 10,
                        fig_nme: str = 'Double Lift Chart'):
        lift_data = pd.DataFrame()
        lift_data.loc[:, 'pred1'] = pred_model_1
        lift_data.loc[:, 'pred2'] = pred_model_2
        lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
        lift_data.loc[:, 'act'] = w_act_list
        lift_data.loc[:, 'weight'] = w_list
        lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
        lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
        plot_data = PlotUtils.split_data(
            lift_data, 'diff_ly', 'weight', n_bins)
        plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
        plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
        plot_data['act_v'] = plot_data['act'] / plot_data['act']
        plot_data.reset_index(inplace=True)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)
        PlotUtils.plot_dlift_ax(
            ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)

        save_path = os.path.join(
            os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
        IOUtils.ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300)
        plt.close(fig)


# Backward-compatible functional wrappers
def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
    return IOUtils.csv_to_dict(file_path)


def ensure_parent_dir(file_path: str) -> None:
    IOUtils.ensure_parent_dir(file_path)


def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
    return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)


# Tweedie deviance loss for PyTorch.
# Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
    return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)


# Release CUDA memory.
def free_cuda():
    TrainingUtils.free_cuda()


class TorchTrainerMixin:
    # Shared utility methods for Torch tabular trainers.

    def _device_type(self) -> str:
        return getattr(self, "device", torch.device("cpu")).type

    def _build_dataloader(self,
                          dataset,
                          N: int,
                          base_bs_gpu: tuple,
                          base_bs_cpu: tuple,
                          min_bs: int = 64,
                          target_effective_cuda: int = 1024,
                          target_effective_cpu: int = 512,
                          large_threshold: int = 200_000,
                          mid_threshold: int = 50_000):
        batch_size = TrainingUtils.compute_batch_size(
            data_size=len(dataset),
            learning_rate=self.learning_rate,
            batch_num=self.batch_num,
            minimum=min_bs
        )
        gpu_large, gpu_mid, gpu_small = base_bs_gpu
        cpu_mid, cpu_small = base_bs_cpu

        if self._device_type() == 'cuda':
            device_count = torch.cuda.device_count()
            if getattr(self, "is_ddp_enabled", False):
                device_count = 1
            # With multiple GPUs, raise the minimum batch size so every card
            # receives enough data.
            if device_count > 1:
                min_bs = min_bs * device_count
                print(
                    f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")

            if N > large_threshold:
                base_bs = gpu_large * device_count
            elif N > mid_threshold:
                base_bs = gpu_mid * device_count
            else:
                base_bs = gpu_small * device_count
        else:
            base_bs = cpu_mid if N > mid_threshold else cpu_small

        # Recompute batch_size so it is no smaller than the adjusted min_bs.
        batch_size = TrainingUtils.compute_batch_size(
            data_size=len(dataset),
            learning_rate=self.learning_rate,
            batch_num=self.batch_num,
            minimum=min_bs
        )
        batch_size = min(batch_size, base_bs, N)

        target_effective_bs = target_effective_cuda if self._device_type(
        ) == 'cuda' else target_effective_cpu
        if getattr(self, "is_ddp_enabled", False):
            world_size = max(1, DistributedUtils.world_size())
            target_effective_bs = max(1, target_effective_bs // world_size)

        world_size = getattr(self, "world_size", 1) if getattr(
            self, "is_ddp_enabled", False) else 1
        samples_per_rank = math.ceil(
            N / max(1, world_size)) if world_size > 1 else N
        steps_per_epoch = max(
            1, math.ceil(samples_per_rank / max(1, batch_size)))
        # Clamp accumulation so we do not divide the loss by a factor larger
        # than the number of batches we actually run in an epoch.
        desired_accum = max(1, target_effective_bs // max(1, batch_size))
        accum_steps = max(1, min(desired_accum, steps_per_epoch))

        # Linux (posix) forks workers cheaply; Windows (nt) spawns them,
        # which costs more, so keep workers at 0 there.
        if os.name == 'nt':
            workers = 0
        else:
            workers = min(8, os.cpu_count() or 1)
        if getattr(self, "is_ddp_enabled", False):
            workers = 0  # Disable extra workers under DDP to avoid child-process conflicts.
        print(
            f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, Workers={workers}")
        sampler = None
        if dist.is_initialized():
            sampler = DistributedSampler(dataset, shuffle=True)
            shuffle = False  # DistributedSampler handles the shuffling.
        else:
            shuffle = True

        persistent = workers > 0
        if getattr(self, "is_ddp_enabled", False):
            persistent = False  # Avoid stale worker state if DDP/pruning exits early.
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=workers,
            pin_memory=(self._device_type() == 'cuda'),
            persistent_workers=persistent,
        )
        return dataloader, accum_steps
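    # -------------------------------------------------------------------
    # Note (editor's sketch, not part of the packaged file): the optimizer
    # sees an effective batch of batch_size * accum_steps per rank. For
    # example, with batch_size=256, target_effective_cuda=2048 and at
    # least 8 steps per epoch, desired_accum = 2048 // 256 = 8, so one
    # optimizer step integrates eight micro-batches of 256 rows.
    # -------------------------------------------------------------------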

    def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
        task = getattr(self, "task_type", "regression")
        if task == 'classification':
            loss_fn = nn.BCEWithLogitsLoss(reduction='none')
            losses = loss_fn(y_pred, y_true).view(-1)
        else:
            if apply_softplus:
                y_pred = F.softplus(y_pred)
            y_pred = torch.clamp(y_pred, min=1e-6)
            power = getattr(self, "tw_power", 1.5)
            losses = tweedie_loss(y_pred, y_true, p=power).view(-1)
        weighted_loss = (losses * weights.view(-1)).sum() / \
            torch.clamp(weights.sum(), min=EPS)
        return weighted_loss
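    # -------------------------------------------------------------------
    # Note (editor's sketch, not part of the packaged file): the reduction
    # above is the weighted average
    #
    #     L = sum_i(w_i * loss_i) / max(sum_i(w_i), EPS)
    #
    # so exposure-weighted batches recover the portfolio-level deviance
    # no matter how the rows are batched.
    # -------------------------------------------------------------------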

    def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
                           ignore_keys: Optional[List[str]] = None):
        if val_loss < best_loss:
            ignore_keys = ignore_keys or []
            state_dict = {
                k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
                for k, v in model.state_dict().items()
                if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
            }
            return val_loss, state_dict, 0, False
        patience_counter += 1
        should_stop = best_state is not None and patience_counter >= getattr(
            self, "patience", 0)
        return best_loss, best_state, patience_counter, should_stop

    def _train_model(self,
                     model,
                     dataloader,
                     accum_steps,
                     optimizer,
                     scaler,
                     forward_fn,
                     val_forward_fn=None,
                     apply_softplus: bool = False,
                     clip_fn=None,
                     trial: Optional[optuna.trial.Trial] = None,
                     loss_curve_path: Optional[str] = None):
        device_type = self._device_type()
        best_loss = float('inf')
        best_state = None
        patience_counter = 0
        stop_training = False
        train_history: List[float] = []
        val_history: List[float] = []

        is_ddp_model = isinstance(model, DDP)

        for epoch in range(1, getattr(self, "epochs", 1) + 1):
            if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
                self.dataloader_sampler.set_epoch(epoch)

            model.train()
            optimizer.zero_grad()

            epoch_loss_sum = 0.0
            epoch_weight_sum = 0.0
            for step, batch in enumerate(dataloader):
                is_update_step = ((step + 1) % accum_steps == 0) or \
                    ((step + 1) == len(dataloader))
                sync_cm = model.no_sync if (
                    is_ddp_model and not is_update_step) else nullcontext

                with sync_cm():
                    with autocast(enabled=(device_type == 'cuda')):
                        y_pred, y_true, w = forward_fn(batch)
                        weighted_loss = self._compute_weighted_loss(
                            y_pred, y_true, w, apply_softplus=apply_softplus)
                        loss_for_backward = weighted_loss / accum_steps

                    batch_weight = torch.clamp(w.sum(), min=EPS).item()
                    epoch_loss_sum += float(weighted_loss.item()
                                            * batch_weight)
                    epoch_weight_sum += float(batch_weight)
                    scaler.scale(loss_for_backward).backward()

                if is_update_step:
                    if clip_fn is not None:
                        clip_fn()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

            train_epoch_loss = epoch_loss_sum / max(epoch_weight_sum, EPS)
            train_history.append(float(train_epoch_loss))

            if val_forward_fn is not None:
                should_compute_val = (not dist.is_initialized()
                                      or DistributedUtils.is_main_process())
                val_device = getattr(self, "device", torch.device("cpu"))
                if not isinstance(val_device, torch.device):
                    val_device = torch.device(val_device)
                loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
                    "cpu")
                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)

                if should_compute_val:
                    model.eval()
                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
                        val_result = val_forward_fn()
                        if isinstance(val_result, tuple) and len(val_result) == 3:
                            y_val_pred, y_val_true, w_val = val_result
                            val_weighted_loss = self._compute_weighted_loss(
                                y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
                        else:
                            val_weighted_loss = val_result
                    val_loss_tensor[0] = float(val_weighted_loss)

                if dist.is_initialized():
                    dist.broadcast(val_loss_tensor, src=0)
                val_weighted_loss = float(val_loss_tensor.item())

                val_history.append(val_weighted_loss)

                best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
                    val_weighted_loss, best_loss, best_state, patience_counter, model)

                prune_flag = False
                is_main_rank = DistributedUtils.is_main_process()
                if trial is not None and (not dist.is_initialized() or is_main_rank):
                    trial.report(val_weighted_loss, epoch)
                    prune_flag = trial.should_prune()

                if dist.is_initialized():
                    prune_device = getattr(self, "device", torch.device("cpu"))
                    if not isinstance(prune_device, torch.device):
                        prune_device = torch.device(prune_device)
                    prune_tensor = torch.zeros(1, device=prune_device)
                    if is_main_rank:
                        prune_tensor.fill_(1 if prune_flag else 0)
                    dist.broadcast(prune_tensor, src=0)
                    prune_flag = bool(prune_tensor.item())

                if prune_flag:
                    raise optuna.TrialPruned()

            if stop_training:
                break

        history = {"train": train_history, "val": val_history}
        self._plot_loss_curve(history, loss_curve_path)
        return best_state, history

    def _plot_loss_curve(self, history: Dict[str, List[float]], save_path: Optional[str]) -> None:
        if not save_path:
            return
        if dist.is_initialized() and not DistributedUtils.is_main_process():
            return
        train_hist = history.get("train", []) if history else []
        val_hist = history.get("val", []) if history else []
        if not train_hist and not val_hist:
            return
        ensure_parent_dir(save_path)
        fig = plt.figure(figsize=(8, 4))
        ax = fig.add_subplot(111)
        if train_hist:
            ax.plot(range(1, len(train_hist) + 1), train_hist,
                    label='Train Loss', color='tab:blue')
        if val_hist:
            ax.plot(range(1, len(val_hist) + 1), val_hist,
                    label='Validation Loss', color='tab:orange')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Weighted Loss')
        ax.set_title('Loss vs. Epoch')
        ax.grid(True, linestyle='--', alpha=0.3)
        ax.legend()
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close(fig)
        print(f"[Training] Loss curve saved to {save_path}")


# =============================================================================
# Plotting helpers
# =============================================================================

def split_data(data, col_nme, wgt_nme, n_bins=10):
    return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)


# Lift-chart plotting wrapper
def plot_lift_list(pred_model, w_pred_list, w_act_list,
                   weight_list, tgt_nme, n_bins=10,
                   fig_nme='Lift Chart'):
    return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
                                    weight_list, tgt_nme, n_bins, fig_nme)


# Double-lift-chart plotting wrapper
def plot_dlift_list(pred_model_1, pred_model_2,
                    model_nme_1, model_nme_2,
                    tgt_nme,
                    w_list, w_act_list, n_bins=10,
                    fig_nme='Double Lift Chart'):
    return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
                                     model_nme_1, model_nme_2,
                                     tgt_nme, w_list, w_act_list,
                                     n_bins, fig_nme)


# =============================================================================
# ResNet model and sklearn-style wrapper
# =============================================================================

# Residual block: two linear layers + ReLU + a skip connection.
class ResBlock(nn.Module):
    def __init__(self, dim: int, dropout: float = 0.1,
                 use_layernorm: bool = False, residual_scale: float = 0.1
                 ):
        super().__init__()
        self.use_layernorm = use_layernorm

        if use_layernorm:
            Norm = nn.LayerNorm  # normalizes over the last dimension
        else:
            def Norm(d): return nn.BatchNorm1d(d)  # switch kept so BatchNorm can be tried too

        self.norm1 = Norm(dim)
        self.fc1 = nn.Linear(dim, dim, bias=True)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        # To normalize after the second layer as well, enable: self.norm2 = Norm(dim)
        self.fc2 = nn.Linear(dim, dim, bias=True)

        # Residual scaling keeps the trunk from blowing up early in training.
        self.res_scale = nn.Parameter(
            torch.tensor(residual_scale, dtype=torch.float32)
        )

    def forward(self, x):
        # Pre-activation layout.
        out = self.norm1(x)
        out = self.fc1(out)
        out = self.act(out)
        out = self.dropout(out)
        # If the second normalization is enabled, call it here: out = self.norm2(out)
        out = self.fc2(out)
        # Scale the residual branch, then add the skip connection.
        return x + self.res_scale * out
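
    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file): a
    # ResBlock preserves shape, and with residual_scale=0.0 it starts out
    # as the identity map:
    #
    # >>> blk = ResBlock(dim=8, dropout=0.0, residual_scale=0.0)
    # >>> x = torch.randn(4, 8)
    # >>> torch.allclose(blk(x), x)
    # True
    # -------------------------------------------------------------------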

# ResNetSequential subclasses nn.Module and defines the full network.
class ResNetSequential(nn.Module):
    # Input tensor shape: (batch, input_dim).
    # Architecture: linear + normalization + ReLU, a stack of residual
    # blocks, and a Softplus output head.

    def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1, task_type: str = 'regression'):
        super(ResNetSequential, self).__init__()

        self.net = nn.Sequential()
        self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))

        # To add an explicit normalization after the first layer, enable one of:
        # with LayerNorm:
        #     self.net.add_module('norm1', nn.LayerNorm(hidden_dim))
        # or with BatchNorm:
        #     self.net.add_module('norm1', nn.BatchNorm1d(hidden_dim))

        # To insert a ReLU before the residual blocks, enable:
        #     self.net.add_module('relu1', nn.ReLU(inplace=True))

        # Stack of residual blocks.
        for i in range(block_num):
            self.net.add_module(
                f'ResBlk_{i+1}',
                ResBlock(
                    hidden_dim,
                    dropout=dropout,
                    use_layernorm=use_layernorm,
                    residual_scale=residual_scale)
            )

        self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))

        if task_type == 'classification':
            self.net.add_module('softplus', nn.Identity())
        else:
            self.net.add_module('softplus', nn.Softplus())

    def forward(self, x):
        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> ResNetSequential executing on device: {x.device}")
            self._printed_device = True
        return self.net(x)


# sklearn-style interface for the ResNet model.
class ResNetSklearn(TorchTrainerMixin, nn.Module):
    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 block_num: int = 2, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False):
        super(ResNetSklearn, self).__init__()

        self.use_ddp = use_ddp
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)

        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.block_num = block_num
        self.batch_num = batch_num
        self.epochs = epochs
        self.task_type = task_type
        self.model_nme = model_nme
        self.learning_rate = learning_rate
        self.patience = patience
        self.use_layernorm = use_layernorm
        self.dropout = dropout
        self.residual_scale = residual_scale
        self.loss_curve_path: Optional[str] = None
        self.training_history: Dict[str, List[float]] = {
            "train": [], "val": []}

        # Device selection: cuda > mps > cpu.
        if self.is_ddp_enabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        # Tweedie power selection (unused for classification).
        if self.task_type == 'classification':
            self.tw_power = None
        elif 'f' in self.model_nme:
            self.tw_power = 1
        elif 's' in self.model_nme:
            self.tw_power = 2
        else:
            self.tw_power = tweedie_power

        # Build the network (on CPU first).
        core = ResNetSequential(
            self.input_dim,
            self.hidden_dim,
            self.block_num,
            use_layernorm=self.use_layernorm,
            dropout=self.dropout,
            residual_scale=self.residual_scale,
            task_type=self.task_type
        )

        # ===== Multi-GPU support: DataParallel vs DistributedDataParallel =====
        if self.is_ddp_enabled:
            core = core.to(self.device)
            core = DDP(core, device_ids=[
                       self.local_rank], output_device=self.local_rank)
        elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            # DataParallel scatters inputs across the cards, but the master
            # device is still cuda:0.
            self.device = torch.device('cuda')

        self.resnet = core.to(self.device)

    # ================ Internal helpers ================
    def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
        X_tensor = torch.tensor(X_train.values, dtype=torch.float32)
        y_tensor = torch.tensor(
            y_train.values, dtype=torch.float32).view(-1, 1)
        w_tensor = torch.tensor(w_train.values, dtype=torch.float32).view(
            -1, 1) if w_train is not None else torch.ones_like(y_tensor)

        has_val = X_val is not None and y_val is not None
        if has_val:
            X_val_tensor = torch.tensor(X_val.values, dtype=torch.float32)
            y_val_tensor = torch.tensor(
                y_val.values, dtype=torch.float32).view(-1, 1)
            w_val_tensor = torch.tensor(w_val.values, dtype=torch.float32).view(
                -1, 1) if w_val is not None else torch.ones_like(y_val_tensor)
        else:
            X_val_tensor = y_val_tensor = w_val_tensor = None
        return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val

    def forward(self, x):
        # Accept NumPy input (e.g. from SHAP).
        if isinstance(x, np.ndarray):
            x_tensor = torch.tensor(x, dtype=torch.float32)
        else:
            x_tensor = x

        x_tensor = x_tensor.to(self.device)
        y_pred = self.resnet(x_tensor)
        return y_pred

    # ---------------- Training ----------------

    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None, trial=None):

        X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
            self._build_train_val_tensors(
                X_train, y_train, w_train, X_val, y_val, w_val)

        dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
        dataloader, accum_steps = self._build_dataloader(
            dataset,
            N=X_tensor.shape[0],
            base_bs_gpu=(2048, 1024, 512),
            base_bs_cpu=(256, 128),
            min_bs=64,
            target_effective_cuda=2048,
            target_effective_cpu=1024
        )

        # Keep the sampler so each epoch can be reseeded for proper shuffling.
        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
            self.dataloader_sampler = dataloader.sampler
        else:
            self.dataloader_sampler = None

        # === Optimizer and AMP ===
        self.optimizer = torch.optim.Adam(
            self.resnet.parameters(), lr=self.learning_rate)
        self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        X_val_dev = y_val_dev = w_val_dev = None
        val_dataloader = None
        if has_val:
            # Build the validation DataLoader.
            val_dataset = TensorDataset(
                X_val_tensor, y_val_tensor, w_val_tensor)
            # Validation needs no backprop, so a larger batch improves throughput.
            val_bs = accum_steps * dataloader.batch_size

            # Validation workers follow the same allocation logic.
            if os.name == 'nt':
                val_workers = 0
            else:
                val_workers = min(4, os.cpu_count() or 1)
            if getattr(self, "is_ddp_enabled", False):
                val_workers = 0  # Disable extra workers under DDP to avoid child-process conflicts.

            val_dataloader = DataLoader(
                val_dataset,
                batch_size=val_bs,
                shuffle=False,
                num_workers=val_workers,
                pin_memory=(self.device.type == 'cuda'),
                persistent_workers=val_workers > 0,
            )
            # Validation usually needs no DDP sampler: results are checked on
            # the main process (or a single card), which keeps things simple.

        is_data_parallel = isinstance(self.resnet, nn.DataParallel)

        def forward_fn(batch):
            X_batch, y_batch, w_batch = batch

            if not is_data_parallel:
                X_batch = X_batch.to(self.device, non_blocking=True)
            # Targets and weights always live on the master device so the
            # loss computation that follows stays consistent.
            y_batch = y_batch.to(self.device, non_blocking=True)
            w_batch = w_batch.to(self.device, non_blocking=True)

            y_pred = self.resnet(X_batch)
            return y_pred, y_batch, w_batch

        def val_forward_fn():
            total_loss = 0.0
            total_weight = 0.0
            for batch in val_dataloader:
                X_b, y_b, w_b = batch
                if not is_data_parallel:
                    X_b = X_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

                y_pred = self.resnet(X_b)

                # Compute each batch's weighted loss by hand so the totals
                # can be summed exactly.
                task = getattr(self, "task_type", "regression")
                if task == 'classification':
                    loss_fn = nn.BCEWithLogitsLoss(reduction='none')
                    losses = loss_fn(y_pred, y_b).view(-1)
                else:
                    # No softplus needed here: training ran with
                    # apply_softplus=False because the forward pass already
                    # produces positive outputs.
                    y_pred_clamped = torch.clamp(y_pred, min=1e-6)
                    power = getattr(self, "tw_power", 1.5)
                    losses = tweedie_loss(
                        y_pred_clamped, y_b, p=power).view(-1)

                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

                total_loss += batch_weighted_loss_sum.item()
                total_weight += batch_weight_sum.item()

            return total_loss / max(total_weight, EPS)

        clip_fn = None
        if self.device.type == 'cuda':
            def clip_fn():
                self.scaler.unscale_(self.optimizer)
                clip_grad_norm_(self.resnet.parameters(), max_norm=1.0)

        # Under DDP, only the main process should log and save; printing in
        # the validation callback is controlled inside _train_model.

        best_state, history = self._train_model(
            self.resnet,
            dataloader,
            accum_steps,
            self.optimizer,
            self.scaler,
            forward_fn,
            val_forward_fn if has_val else None,
            apply_softplus=False,
            clip_fn=clip_fn,
            trial=trial,
            loss_curve_path=getattr(self, "loss_curve_path", None)
        )

        if has_val and best_state is not None:
            self.resnet.load_state_dict(best_state)
        self.training_history = history

    # ---------------- Prediction ----------------

    def predict(self, X_test):
        self.resnet.eval()
        if isinstance(X_test, pd.DataFrame):
            X_np = X_test.values.astype(np.float32)
        else:
            X_np = X_test

        with torch.no_grad():
            y_pred = self(X_np).cpu().numpy()

        if self.task_type == 'classification':
            y_pred = 1 / (1 + np.exp(-y_pred))  # sigmoid turns logits into probabilities
        else:
            y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.flatten()

    # ---------------- Parameter setting ----------------

    def set_params(self, params):
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Parameter {key} not found in model.")
        return self
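
    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file): a
    # minimal fit/predict round trip on synthetic data. Names and sizes
    # are hypothetical. Note that 'f' or 's' in model_nme selects the
    # Tweedie power (1 = frequency/Poisson, 2 = severity/Gamma); any other
    # name falls back to the tweedie_power argument.
    #
    # >>> X = pd.DataFrame(np.random.rand(512, 10))
    # >>> y = pd.Series(np.random.rand(512))
    # >>> net = ResNetSklearn(model_nme='bi', input_dim=10, epochs=2,
    # ...                     use_data_parallel=False)
    # >>> net.fit(X, y)
    # >>> preds = net.predict(X)   # positive floats, shape (512,)
    # -------------------------------------------------------------------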


# =============================================================================
# FT-Transformer model and sklearn-style wrapper
# =============================================================================

class FeatureTokenizer(nn.Module):
    """Map numeric/categorical/geo features to a unified set of Transformer input tokens."""

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int, num_geo: int = 0):
        super().__init__()

        self.num_numeric = num_numeric
        self.has_numeric = num_numeric > 0
        self.num_geo = num_geo
        self.has_geo = num_geo > 0

        if self.has_numeric:
            self.num_linear = nn.Linear(num_numeric, d_model)

        self.embeddings = nn.ModuleList([
            nn.Embedding(card, d_model) for card in cat_cardinalities
        ])

        if self.has_geo:
            # Geo tokens get a direct linear projection; upstream code has
            # already encoded/standardized them, so no one-hot on raw strings.
            self.geo_linear = nn.Linear(num_geo, d_model)

    def forward(self, X_num, X_cat, X_geo=None):
        tokens = []

        if self.has_numeric:
            num_token = self.num_linear(X_num)
            tokens.append(num_token)

        for i, emb in enumerate(self.embeddings):
            tok = emb(X_cat[:, i])
            tokens.append(tok)

        if self.has_geo:
            if X_geo is None:
                raise RuntimeError("Geo tokens are enabled but X_geo was not provided.")
            geo_token = self.geo_linear(X_geo)
            tokens.append(geo_token)

        x = torch.stack(tokens, dim=1)
        return x
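
    # -------------------------------------------------------------------
    # Note (editor's sketch, not part of the packaged file): the token
    # sequence length is
    #
    #     token_num = (1 if num_numeric > 0 else 0)
    #                 + len(cat_cardinalities)
    #                 + (1 if num_geo > 0 else 0)
    #
    # e.g. 12 numeric features, 5 categoricals and geo tokens enabled give
    # x.shape == (batch, 7, d_model): one pooled numeric token, five
    # embedding tokens and one geo token.
    # -------------------------------------------------------------------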

# Encoder layer with residual scaling.
class ScaledTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, residual_scale_attn: float = 1.0,
                 residual_scale_ffn: float = 1.0, norm_first: bool = True,
                 ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()
        # If ReLU is preferred, swap in: self.activation = nn.ReLU()
        self.norm_first = norm_first

        # Residual scaling factors.
        self.res_scale_attn = residual_scale_attn
        self.res_scale_ffn = residual_scale_ffn

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Input tensor shape: (batch, sequence length, d_model).
        x = src

        if self.norm_first:
            # Normalize first, then attend.
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm layout (normally disabled).
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Self-attention with residual scaling.
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )
        return self.res_scale_attn * self.dropout1(attn_out)

    def _ff_block(self, x):
        # Feed-forward block with residual scaling.
        x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.res_scale_ffn * self.dropout2(x2)

# FT-Transformer core model.
class FTTransformerCore(nn.Module):
    # A minimal working FT-Transformer in three parts:
    # 1) FeatureTokenizer: turns numeric/categorical features into tokens;
    # 2) TransformerEncoder: models interactions between features;
    # 3) pooling + MLP + Softplus: keeps outputs positive for Tweedie/Gamma tasks.

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
                 n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
                 task_type: str = 'regression', num_geo: int = 0
                 ):
        super().__init__()

        self.tokenizer = FeatureTokenizer(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=d_model,
            num_geo=num_geo
        )
        scale = 1.0 / math.sqrt(n_layers)  # a sensible default
        encoder_layer = ScaledTransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            residual_scale_attn=scale,
            residual_scale_ffn=scale,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )
        self.n_layers = n_layers

        layers = [
            # For a deeper output head, enable the sample layers below:
            # nn.LayerNorm(d_model),        # extra normalization
            # nn.Linear(d_model, d_model),  # extra fully connected layer
            # nn.GELU(),                    # matching activation
            nn.Linear(d_model, 1),
        ]

        if task_type == 'classification':
            # Classification emits logits, matching BCEWithLogitsLoss.
            layers.append(nn.Identity())
        else:
            # Regression must stay positive for Tweedie/Gamma.
            layers.append(nn.Softplus())

        self.head = nn.Sequential(*layers)

    def forward(self, X_num, X_cat, X_geo=None):

        # Inputs:
        #   X_num -> float32 tensor of shape (batch, num numeric features)
        #   X_cat -> long tensor of shape (batch, num categorical features)
        #   X_geo -> float32 tensor of shape (batch, geo token dim)

        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> FTTransformerCore executing on device: {X_num.device}")
            self._printed_device = True

        # => tensor shape (batch, token_num, d_model)
        tokens = self.tokenizer(X_num, X_cat, X_geo)
        # => tensor shape (batch, token_num, d_model)
        x = self.encoder(tokens)

        # Mean-pool the tokens, then feed the regression head.
        x = x.mean(dim=1)  # => tensor shape (batch, d_model)

        # => tensor shape (batch, 1), constrained positive by Softplus
        out = self.head(x)
        return out
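
    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file, run
    # against the torch version pinned at the top of this module): shape
    # flow for a hypothetical batch of 32 rows with 3 numeric and 2
    # categorical features (cardinalities 10 and 4), geo tokens disabled:
    #
    # >>> core = FTTransformerCore(num_numeric=3, cat_cardinalities=[10, 4],
    # ...                          d_model=16, n_heads=4, n_layers=2)
    # >>> X_num = torch.randn(32, 3)
    # >>> X_cat = torch.randint(0, 4, (32, 2))
    # >>> core(X_num, X_cat).shape   # tokens (32, 3, 16) -> pooled (32, 16)
    # torch.Size([32, 1])
    # -------------------------------------------------------------------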

# Dataset wrapper for the tabular tensors.
class TabularDataset(Dataset):
    def __init__(self, X_num, X_cat, X_geo, y, w):

        # Input tensors:
        #   X_num: torch.float32, shape=(N, num numeric features)
        #   X_cat: torch.long,    shape=(N, num categorical features)
        #   X_geo: torch.float32, shape=(N, geo token dim); may be empty
        #   y:     torch.float32, shape=(N, 1)
        #   w:     torch.float32, shape=(N, 1)

        self.X_num = X_num
        self.X_cat = X_cat
        self.X_geo = X_geo
        self.y = y
        self.w = w

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.X_num[idx],
            self.X_cat[idx],
            self.X_geo[idx],
            self.y[idx],
            self.w[idx],
        )


# sklearn-style interface for the FT-Transformer.
class FTTransformerSklearn(TorchTrainerMixin, nn.Module):

    # sklearn-style wrapper:
    # - num_cols: list of numeric feature column names
    # - cat_cols: list of categorical feature column names (label-encoded
    #   upstream, values in [0, n_classes-1])

    def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
                 n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False
                 ):
        super().__init__()

        self.use_ddp = use_ddp
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)
        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        self.model_nme = model_nme
        self.num_cols = list(num_cols)
        self.cat_cols = list(cat_cols)
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.batch_num = batch_num
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.task_type = task_type
        self.patience = patience
        if self.task_type == 'classification':
            self.tw_power = None  # Tweedie power is unused for classification.
        elif 'f' in self.model_nme:
            self.tw_power = 1.0
        elif 's' in self.model_nme:
            self.tw_power = 2.0
        else:
            self.tw_power = tweedie_power

        if self.is_ddp_enabled:
            self.device = torch.device(f"cuda:{self.local_rank}")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.cat_cardinalities = None
        self.cat_categories = {}
        self.ft = None
        self.use_data_parallel = torch.cuda.device_count() > 1 and use_data_parallel
        self.num_geo = 0
        self._geo_params: Dict[str, Any] = {}
        self.loss_curve_path: Optional[str] = None
        self.training_history: Dict[str, List[float]] = {
            "train": [], "val": []}

    def _build_model(self, X_train):
        num_numeric = len(self.num_cols)
        cat_cardinalities = []

        for col in self.cat_cols:
            cats = X_train[col].astype('category')
            categories = cats.cat.categories
            self.cat_categories[col] = categories  # remember the full training category set

            card = len(categories) + 1  # reserve one extra class for unknown/missing
            cat_cardinalities.append(card)

        self.cat_cardinalities = cat_cardinalities

        core = FTTransformerCore(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=self.d_model,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            dropout=self.dropout,
            task_type=self.task_type,
            num_geo=self.num_geo
        )
        if self.is_ddp_enabled:
            core = core.to(self.device)
            core = DDP(core, device_ids=[
                       self.local_rank], output_device=self.local_rank)
        elif self.use_data_parallel:
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            self.device = torch.device("cuda")
        self.ft = core.to(self.device)

    def _encode_cats(self, X):
        # The input DataFrame must contain every categorical feature column.
        # Returns an int64 array of shape (N, num categorical features).

        if not self.cat_cols:
            return np.zeros((len(X), 0), dtype='int64')

        X_cat_list = []
        for col in self.cat_cols:
            # Use the category set recorded during training.
            categories = self.cat_categories[col]
            # Build a Categorical with those fixed categories.
            cats = pd.Categorical(X[col], categories=categories)
            codes = cats.codes.astype('int64', copy=True)  # -1 marks unknown or missing
            # Map unknown/missing to the extra "unknown" index len(categories).
            codes[codes < 0] = len(categories)
            X_cat_list.append(codes)

        X_cat_np = np.stack(X_cat_list, axis=1)  # shape (N, num categorical features)
        return X_cat_np
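
    # -------------------------------------------------------------------
    # Illustration (editor's sketch, not part of the packaged file): codes
    # follow the category set captured at fit time; unseen or missing
    # values map to the reserved extra index len(categories).
    #
    # >>> # training saw categories ['A', 'B'], so the cardinality is 3
    # >>> pd.Categorical(['B', 'C', None], categories=['A', 'B']).codes
    # array([ 1, -1, -1], dtype=int8)   # the -1s are then remapped to 2
    # -------------------------------------------------------------------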

    def _build_train_tensors(self, X_train, y_train, w_train, geo_train=None):
        return self._tensorize_split(X_train, y_train, w_train, geo_tokens=geo_train)

    def _build_val_tensors(self, X_val, y_val, w_val, geo_val=None):
        return self._tensorize_split(X_val, y_val, w_val, geo_tokens=geo_val, allow_none=True)

    def _tensorize_split(self, X, y, w, geo_tokens=None, allow_none: bool = False):
        if X is None:
            if allow_none:
                return None, None, None, None, None, False
            raise ValueError("Input features X must not be None.")

        X_num = torch.tensor(
            X[self.num_cols].to_numpy(dtype=np.float32, copy=True),
            dtype=torch.float32
        )
        if self.cat_cols:
            X_cat = torch.tensor(self._encode_cats(X), dtype=torch.long)
        else:
            X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)

        if geo_tokens is not None:
            geo_np = np.asarray(geo_tokens, dtype=np.float32)
            if geo_np.ndim == 1:
                geo_np = geo_np.reshape(-1, 1)
        elif self.num_geo > 0:
            raise RuntimeError("geo_tokens must be provided when geo tokens are enabled.")
        else:
            geo_np = np.zeros((X_num.shape[0], 0), dtype=np.float32)
        X_geo = torch.tensor(geo_np, dtype=torch.float32)

        y_tensor = torch.tensor(
            y.values, dtype=torch.float32).view(-1, 1) if y is not None else None
        if y_tensor is None:
            w_tensor = None
        elif w is not None:
            w_tensor = torch.tensor(
                w.values, dtype=torch.float32).view(-1, 1)
        else:
            w_tensor = torch.ones_like(y_tensor)
        return X_num, X_cat, X_geo, y_tensor, w_tensor, y is not None
1389
+
+     def fit(self, X_train, y_train, w_train=None,
+             X_val=None, y_val=None, w_val=None, trial=None,
+             geo_train=None, geo_val=None):
+
+         # The underlying network is built on the first fit
+         self.num_geo = geo_train.shape[1] if geo_train is not None else 0
+         if self.ft is None:
+             self._build_model(X_train)
+
+         X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor, _ = self._build_train_tensors(
+             X_train, y_train, w_train, geo_train=geo_train)
+         X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
+             X_val, y_val, w_val, geo_val=geo_val)
+
+         # --- Build the DataLoader ---
+         dataset = TabularDataset(
+             X_num_train, X_cat_train, X_geo_train, y_tensor, w_tensor
+         )
+
+         dataloader, accum_steps = self._build_dataloader(
+             dataset,
+             N=X_num_train.shape[0],
+             base_bs_gpu=(2048, 1024, 512),
+             base_bs_cpu=(256, 128),
+             min_bs=64,
+             target_effective_cuda=2048,
+             target_effective_cpu=1024
+         )
+
+         if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
+             self.dataloader_sampler = dataloader.sampler
+         else:
+             self.dataloader_sampler = None
+
+         optimizer = torch.optim.Adam(
+             self.ft.parameters(), lr=self.learning_rate)
+         scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
+         val_dataloader = None
+         if has_val:
+             val_dataset = TabularDataset(
+                 X_num_val, X_cat_val, X_geo_val, y_val_tensor, w_val_tensor
+             )
+             val_bs = accum_steps * dataloader.batch_size
+
+             if os.name == 'nt':
+                 val_workers = 0
+             else:
+                 val_workers = min(4, os.cpu_count() or 1)
+             if getattr(self, "is_ddp_enabled", False):
+                 val_workers = 0  # disable workers under DDP to avoid subprocess clashes
+
+             val_dataloader = DataLoader(
+                 val_dataset,
+                 batch_size=val_bs,
+                 shuffle=False,
+                 num_workers=val_workers,
+                 pin_memory=(self.device.type == 'cuda'),
+                 persistent_workers=val_workers > 0,
+             )
+
+         is_data_parallel = isinstance(self.ft, nn.DataParallel)
+
+         def forward_fn(batch):
+             X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+
+             if not is_data_parallel:
+                 X_num_b = X_num_b.to(self.device, non_blocking=True)
+                 X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                 X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                 y_b = y_b.to(self.device, non_blocking=True)
+                 w_b = w_b.to(self.device, non_blocking=True)
+
+             y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+             return y_pred, y_b, w_b
+
+         def val_forward_fn():
+             total_loss = 0.0
+             total_weight = 0.0
+             for batch in val_dataloader:
+                 X_num_b, X_cat_b, X_geo_b, y_b, w_b = batch
+                 if not is_data_parallel:
+                     X_num_b = X_num_b.to(self.device, non_blocking=True)
+                     X_cat_b = X_cat_b.to(self.device, non_blocking=True)
+                     X_geo_b = X_geo_b.to(self.device, non_blocking=True)
+                     y_b = y_b.to(self.device, non_blocking=True)
+                     w_b = w_b.to(self.device, non_blocking=True)
+
+                 y_pred = self.ft(X_num_b, X_cat_b, X_geo_b)
+
+                 # Compute the validation loss manually
+                 task = getattr(self, "task_type", "regression")
+                 if task == 'classification':
+                     loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+                     losses = loss_fn(y_pred, y_b).view(-1)
+                 else:
+                     # The model output already went through Softplus; do not reapply it
+                     y_pred_clamped = torch.clamp(y_pred, min=1e-6)
+                     power = getattr(self, "tw_power", 1.5)
+                     losses = tweedie_loss(
+                         y_pred_clamped, y_b, p=power).view(-1)
+
+                 batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
+                 batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()
+
+                 total_loss += batch_weighted_loss_sum.item()
+                 total_weight += batch_weight_sum.item()
+
+             return total_loss / max(total_weight, EPS)
+
+         clip_fn = None
+         if self.device.type == 'cuda':
+             def clip_fn(): return (scaler.unscale_(optimizer),
+                                    clip_grad_norm_(self.ft.parameters(), max_norm=1.0))
+
+         best_state, history = self._train_model(
+             self.ft,
+             dataloader,
+             accum_steps,
+             optimizer,
+             scaler,
+             forward_fn,
+             val_forward_fn if has_val else None,
+             apply_softplus=False,
+             clip_fn=clip_fn,
+             trial=trial,
+             loss_curve_path=getattr(self, "loss_curve_path", None)
+         )
+
+         if has_val and best_state is not None:
+             self.ft.load_state_dict(best_state)
+         self.training_history = history
+
+     def predict(self, X_test, geo_tokens=None, batch_size: Optional[int] = None):
+         # X_test must contain every numeric and categorical column; geo_tokens
+         # carries the optional geo tokens.
+
+         self.ft.eval()
+         X_num, X_cat, X_geo, _, _, _ = self._tensorize_split(
+             X_test, None, None, geo_tokens=geo_tokens, allow_none=True)
+
+         num_rows = X_num.shape[0]
+         if num_rows == 0:
+             return np.empty(0, dtype=np.float32)
+
+         device = self.device if isinstance(
+             self.device, torch.device) else torch.device(self.device)
+
+         def resolve_batch_size(n_rows: int) -> int:
+             if batch_size is not None:
+                 return max(1, min(int(batch_size), n_rows))
+             # Estimate a safe batch size from model scale so multi-head attention
+             # does not exhaust GPU memory at inference time
+             token_cnt = len(self.num_cols) + len(self.cat_cols)
+             if self.num_geo > 0:
+                 token_cnt += 1
+             approx_units = max(1, token_cnt * max(1, self.d_model))
+             if device.type == 'cuda':
+                 if approx_units >= 8192:
+                     base = 512
+                 elif approx_units >= 4096:
+                     base = 1024
+                 else:
+                     base = 2048
+             else:
+                 base = 512
+             return max(1, min(base, n_rows))
+
+         eff_batch = resolve_batch_size(num_rows)
+         preds: List[torch.Tensor] = []
+
+         with torch.no_grad():
+             for start in range(0, num_rows, eff_batch):
+                 end = min(num_rows, start + eff_batch)
+                 X_num_b = X_num[start:end].to(device, non_blocking=True)
+                 X_cat_b = X_cat[start:end].to(device, non_blocking=True)
+                 X_geo_b = X_geo[start:end].to(device, non_blocking=True)
+                 pred_chunk = self.ft(X_num_b, X_cat_b, X_geo_b)
+                 preds.append(pred_chunk.cpu())
+
+         y_pred = torch.cat(preds, dim=0).numpy()
+
+         if self.task_type == 'classification':
+             # Convert logits to probabilities
+             y_pred = 1 / (1 + np.exp(-y_pred))
+         else:
+             # Softplus is already applied inside the model; if needed, apply
+             # log-exp smoothing on demand: y_pred = log(1 + exp(y_pred))
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.ravel()
+
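+     # Rough shape of the heuristic above (illustrative numbers): with 30
+     # numeric and 10 categorical columns, one geo token and d_model=192,
+     # approx_units = 41 * 192 = 7872, which falls in the >=4096 band, so a
+     # CUDA device would start from base = 1024.
+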
+     def set_params(self, params: dict):
+
+         # Mirrors the sklearn-style interface.
+         # Note: structural parameters (e.g. d_model/n_heads) only take effect
+         # after the model is fit again.
+
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         return self
+
+
+ # =============================================================================
+ # Simplified Graph Neural Network (GNN) implementation
+ # =============================================================================
+
+
+ class SimpleGraphLayer(nn.Module):
+     def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
+         super().__init__()
+         self.linear = nn.Linear(in_dim, out_dim)
+         self.activation = nn.ReLU(inplace=True)
+         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
+         # Message passing over the normalized sparse adjacency: A_hat * X * W
+         h = torch.sparse.mm(adj, x)
+         h = self.linear(h)
+         h = self.activation(h)
+         return self.dropout(h)
+
+
+ class SimpleGNN(nn.Module):
+     def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
+                  dropout: float = 0.1, task_type: str = 'regression'):
+         super().__init__()
+         layers = []
+         dim_in = input_dim
+         for _ in range(max(1, num_layers)):
+             layers.append(SimpleGraphLayer(
+                 dim_in, hidden_dim, dropout=dropout))
+             dim_in = hidden_dim
+         self.layers = nn.ModuleList(layers)
+         self.output = nn.Linear(hidden_dim, 1)
+         if task_type == 'classification':
+             self.output_act = nn.Identity()
+         else:
+             self.output_act = nn.Softplus()
+         self.task_type = task_type
+         # Keep the adjacency in a buffer so DataParallel can replicate it
+         self.register_buffer("adj_buffer", torch.empty(0))
+
+     def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
+         adj_used = adj if adj is not None else getattr(
+             self, "adj_buffer", None)
+         if adj_used is None or adj_used.numel() == 0:
+             raise RuntimeError("Adjacency is not set for GNN forward.")
+         h = x
+         for layer in self.layers:
+             h = layer(h, adj_used)
+         h = torch.sparse.mm(adj_used, h)
+         out = self.output(h)
+         return self.output_act(out)
+
+
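+ # Each SimpleGraphLayer computes a GCN-style update H' = ReLU(A_hat H W); the
+ # normalized adjacency A_hat = D^{-1/2} (A + I) D^{-1/2} is produced later by
+ # GraphNeuralNetSklearn._normalized_adj (a sketch follows that method).
+
+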
+ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
+     def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
+                  num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
+                  learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
+                  task_type: str = 'regression', tweedie_power: float = 1.5,
+                  use_data_parallel: bool = False, use_ddp: bool = False,
+                  use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
+                  graph_cache_path: Optional[str] = None,
+                  max_gpu_knn_nodes: Optional[int] = None,
+                  knn_gpu_mem_ratio: float = 0.9,
+                  knn_gpu_mem_overhead: float = 2.0) -> None:
+         super().__init__()
+         self.model_nme = model_nme
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.num_layers = num_layers
+         self.k_neighbors = max(1, k_neighbors)
+         self.dropout = dropout
+         self.learning_rate = learning_rate
+         self.epochs = epochs
+         self.patience = patience
+         self.task_type = task_type
+         self.use_approx_knn = use_approx_knn
+         self.approx_knn_threshold = approx_knn_threshold
+         self.graph_cache_path = Path(
+             graph_cache_path) if graph_cache_path else None
+         self.max_gpu_knn_nodes = max_gpu_knn_nodes
+         self.knn_gpu_mem_ratio = max(0.0, min(1.0, knn_gpu_mem_ratio))
+         self.knn_gpu_mem_overhead = max(1.0, knn_gpu_mem_overhead)
+         self._knn_warning_emitted = False
+
+         if self.task_type == 'classification':
+             self.tw_power = None
+         elif 'f' in self.model_nme:
+             self.tw_power = 1.0
+         elif 's' in self.model_nme:
+             self.tw_power = 2.0
+         else:
+             self.tw_power = tweedie_power
+
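+         # Note: 'f'/'s' in model_nme presumably flag frequency/severity models,
+         # which map to Poisson-like (p=1.0) and Gamma-like (p=2.0) Tweedie
+         # powers respectively.
+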
+         self.ddp_enabled = False
+         self.local_rank = 0
+         self.data_parallel_enabled = False
+
+         # DDP is only effective on CUDA; fall back to a single device when
+         # initialization does not succeed
+         if use_ddp and torch.cuda.is_available():
+             ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
+             if ddp_ok:
+                 self.ddp_enabled = True
+                 self.local_rank = local_rank
+                 self.device = torch.device(f'cuda:{local_rank}')
+             else:
+                 self.device = torch.device('cuda')
+         elif torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         elif torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         else:
+             self.device = torch.device('cpu')
+         self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE
+
+         self.gnn = SimpleGNN(
+             input_dim=self.input_dim,
+             hidden_dim=self.hidden_dim,
+             num_layers=self.num_layers,
+             dropout=self.dropout,
+             task_type=self.task_type
+         ).to(self.device)
+
+         # DataParallel replicates the full graph on every GPU and scatters the
+         # node features; suitable for medium-sized graphs
+         if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
+             self.data_parallel_enabled = True
+             self.gnn = nn.DataParallel(
+                 self.gnn, device_ids=list(range(torch.cuda.device_count())))
+             self.device = torch.device('cuda')
+
+         if self.ddp_enabled:
+             self.gnn = DDP(
+                 self.gnn,
+                 device_ids=[self.local_rank],
+                 output_device=self.local_rank,
+                 find_unused_parameters=False
+             )
+
+     def _unwrap_gnn(self) -> nn.Module:
+         return self.gnn.module if isinstance(self.gnn, DDP) else self.gnn
+
+     def _set_adj_buffer(self, adj: torch.Tensor) -> None:
+         base = self._unwrap_gnn()
+         if hasattr(base, "adj_buffer"):
+             base.adj_buffer = adj
+         else:
+             base.register_buffer("adj_buffer", adj)
+
+     def _load_cached_adj(self) -> Optional[torch.Tensor]:
+         if self.graph_cache_path and self.graph_cache_path.exists():
+             try:
+                 adj = torch.load(self.graph_cache_path,
+                                  map_location=self.device)
+                 return adj.to(self.device)
+             except Exception as exc:
+                 print(
+                     f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
+         return None
+
+     def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
+         n_samples = X_np.shape[0]
+         k = min(self.k_neighbors, max(1, n_samples - 1))
+         n_neighbors = min(k + 1, n_samples)
+         use_approx = (self.use_approx_knn or n_samples >=
+                       self.approx_knn_threshold) and _PYNN_AVAILABLE
+         indices = None
+         if use_approx:
+             try:
+                 nn_index = pynndescent.NNDescent(
+                     X_np,
+                     n_neighbors=n_neighbors,
+                     random_state=0
+                 )
+                 indices, _ = nn_index.neighbor_graph
+             except Exception as exc:
+                 print(
+                     f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
+                 use_approx = False
+
+         if indices is None:
+             nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto")
+             nbrs.fit(X_np)
+             _, indices = nbrs.kneighbors(X_np)
+
+         indices = np.asarray(indices)
+
+         rows = []
+         cols = []
+         for i in range(n_samples):
+             for j in indices[i]:
+                 if i == j:
+                     continue
+                 rows.append(i)
+                 cols.append(j)
+                 rows.append(j)
+                 cols.append(i)
+
+         # Add self-loops so no node is left with degree zero
+         rows.extend(range(n_samples))
+         cols.extend(range(n_samples))
+
+         edge_index = torch.tensor(
+             [rows, cols], dtype=torch.long, device=self.device)
+         return edge_index
+
+     def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
+         if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
+             # Defensive check: callers are expected to test use_pyg_knn first
+             raise RuntimeError(
+                 "GPU graph builder requested but PyG is unavailable.")
+
+         n_samples = X_tensor.size(0)
+         k = min(self.k_neighbors, max(1, n_samples - 1))
+
+         # knn_graph runs on the GPU so CPU graph construction does not become
+         # the bottleneck
+         edge_index = knn_graph(
+             X_tensor,
+             k=k,
+             loop=False
+         )
+         edge_index = to_undirected(edge_index, num_nodes=n_samples)
+         edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
+         return edge_index
+
+     def _log_knn_fallback(self, reason: str) -> None:
+         if self._knn_warning_emitted:
+             return
+         if (not self.ddp_enabled) or self.local_rank == 0:
+             print(f"[GNN] Falling back to CPU kNN builder: {reason}")
+         self._knn_warning_emitted = True
+
+     def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
+         if not self.use_pyg_knn:
+             return False
+
+         reason = None
+         if self.max_gpu_knn_nodes is not None and n_samples > self.max_gpu_knn_nodes:
+             reason = f"node count {n_samples} exceeds max_gpu_knn_nodes={self.max_gpu_knn_nodes}"
+         elif self.device.type == 'cuda' and torch.cuda.is_available():
+             try:
+                 device_index = self.device.index
+                 if device_index is None:
+                     device_index = torch.cuda.current_device()
+                 free_mem, total_mem = torch.cuda.mem_get_info(device_index)
+                 feature_bytes = X_tensor.element_size() * X_tensor.nelement()
+                 required = int(feature_bytes * self.knn_gpu_mem_overhead)
+                 budget = int(free_mem * self.knn_gpu_mem_ratio)
+                 if required > budget:
+                     required_gb = required / (1024 ** 3)
+                     budget_gb = budget / (1024 ** 3)
+                     reason = (f"requires ~{required_gb:.2f} GiB temporary GPU memory "
+                               f"but only {budget_gb:.2f} GiB free on cuda:{device_index}")
+             except Exception:
+                 # mem_get_info may be unavailable on older builds or in some
+                 # environments; keep trying the GPU in that case
+                 reason = None
+
+         if reason:
+             self._log_knn_fallback(reason)
+             return False
+         return True
+
+     def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
+         values = torch.ones(edge_index.shape[1], device=self.device)
+         adj = torch.sparse_coo_tensor(
+             edge_index.to(self.device), values, (num_nodes, num_nodes))
+         adj = adj.coalesce()
+
+         deg = torch.sparse.sum(adj, dim=1).to_dense()
+         deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
+         row, col = adj.indices()
+         norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
+         adj_norm = torch.sparse_coo_tensor(
+             adj.indices(), norm_values, size=adj.shape)
+         return adj_norm
+
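+     # A tiny check of the symmetric normalization above (standalone,
+     # illustrative: a 2-node graph with one edge plus self-loops):
+     #
+     # >>> ei = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])
+     # >>> vals = torch.ones(4)
+     # >>> A = torch.sparse_coo_tensor(ei, vals, (2, 2)).coalesce()
+     # >>> deg = torch.sparse.sum(A, dim=1).to_dense()   # tensor([2., 2.])
+     # >>> # every entry of D^-1/2 A D^-1/2 is then 0.5, so each row sums to 1
+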
+     def _tensorize_split(self, X, y, w, allow_none: bool = False):
+         if X is None and allow_none:
+             return None, None, None
+         X_np = X.values.astype(np.float32)
+         X_tensor = torch.tensor(X_np, dtype=torch.float32, device=self.device)
+         if y is None:
+             y_tensor = None
+         else:
+             y_tensor = torch.tensor(
+                 y.values, dtype=torch.float32, device=self.device).view(-1, 1)
+         if w is None:
+             w_tensor = torch.ones(
+                 (len(X), 1), dtype=torch.float32, device=self.device)
+         else:
+             w_tensor = torch.tensor(
+                 w.values, dtype=torch.float32, device=self.device).view(-1, 1)
+         return X_tensor, y_tensor, w_tensor
+
+     def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if X_tensor is None:
+             X_tensor = torch.tensor(
+                 X_df.values.astype(np.float32),
+                 dtype=torch.float32,
+                 device=self.device
+             )
+         if self.graph_cache_path:
+             cached = self._load_cached_adj()
+             if cached is not None:
+                 return cached
+         use_gpu_knn = self._should_use_gpu_knn(X_df.shape[0], X_tensor)
+         if use_gpu_knn:
+             edge_index = self._build_edge_index_gpu(X_tensor)
+         else:
+             edge_index = self._build_edge_index_cpu(
+                 X_df.values.astype(np.float32))
+         adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
+         if self.graph_cache_path:
+             try:
+                 IOUtils.ensure_parent_dir(str(self.graph_cache_path))
+                 torch.save(adj_norm.cpu(), self.graph_cache_path)
+             except Exception as exc:
+                 print(
+                     f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
+         return adj_norm
+
+     def fit(self, X_train, y_train, w_train=None,
+             X_val=None, y_val=None, w_val=None,
+             trial: Optional[optuna.trial.Trial] = None):
+
+         X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
+             X_train, y_train, w_train, allow_none=False)
+         has_val = X_val is not None and y_val is not None
+         if has_val:
+             X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
+                 X_val, y_val, w_val, allow_none=False)
+         else:
+             X_val_tensor = y_val_tensor = w_val_tensor = None
+
+         adj_train = self._build_graph_from_df(X_train, X_train_tensor)
+         adj_val = self._build_graph_from_df(
+             X_val, X_val_tensor) if has_val else None
+         # DataParallel needs the adjacency cached on the model so it is not
+         # scattered with the batch
+         self._set_adj_buffer(adj_train)
+
+         base_gnn = self._unwrap_gnn()
+         optimizer = torch.optim.Adam(
+             base_gnn.parameters(), lr=self.learning_rate)
+         scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         best_loss = float('inf')
+         best_state = None
+         patience_counter = 0
+
+         for epoch in range(1, self.epochs + 1):
+             self.gnn.train()
+             optimizer.zero_grad()
+             with autocast(enabled=(self.device.type == 'cuda')):
+                 if self.data_parallel_enabled:
+                     y_pred = self.gnn(X_train_tensor)
+                 else:
+                     y_pred = self.gnn(X_train_tensor, adj_train)
+                 loss = self._compute_weighted_loss(
+                     y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
+             scaler.step(optimizer)
+             scaler.update()
+
+             if has_val:
+                 self.gnn.eval()
+                 if self.data_parallel_enabled and adj_val is not None:
+                     self._set_adj_buffer(adj_val)
+                 with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
+                     if self.data_parallel_enabled:
+                         y_val_pred = self.gnn(X_val_tensor)
+                     else:
+                         y_val_pred = self.gnn(X_val_tensor, adj_val)
+                     val_loss = self._compute_weighted_loss(
+                         y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
+                 if self.data_parallel_enabled:
+                     # Restore the training adjacency
+                     self._set_adj_buffer(adj_train)
+
+                 best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
+                     val_loss, best_loss, best_state, patience_counter, base_gnn,
+                     ignore_keys=["adj_buffer"])
+
+                 if trial is not None:
+                     trial.report(val_loss, epoch)
+                     if trial.should_prune():
+                         raise optuna.TrialPruned()
+                 if stop_training:
+                     break
+
+         if best_state is not None:
+             base_gnn.load_state_dict(best_state, strict=False)
+
+     def predict(self, X: pd.DataFrame) -> np.ndarray:
+         self.gnn.eval()
+         X_tensor, _, _ = self._tensorize_split(
+             X, None, None, allow_none=False)
+         adj = self._build_graph_from_df(X, X_tensor)
+         if self.data_parallel_enabled:
+             self._set_adj_buffer(adj)
+         with torch.no_grad():
+             if self.data_parallel_enabled:
+                 y_pred = self.gnn(X_tensor).cpu().numpy()
+             else:
+                 y_pred = self.gnn(X_tensor, adj).cpu().numpy()
+         if self.task_type == 'classification':
+             y_pred = 1 / (1 + np.exp(-y_pred))
+         else:
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.ravel()
+
+     def set_params(self, params: Dict[str, Any]):
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in GNN model.")
+         # Rebuild the backbone after structural parameters change
+         self.gnn = SimpleGNN(
+             input_dim=self.input_dim,
+             hidden_dim=self.hidden_dim,
+             num_layers=self.num_layers,
+             dropout=self.dropout,
+             task_type=self.task_type
+         ).to(self.device)
+         return self
+
+
+ # ===== Core components and training wrappers ==================================
+
+ # =============================================================================
+ # Configuration, preprocessing, and trainer base classes
+ # =============================================================================
+ @dataclass
+ class BayesOptConfig:
+     model_nme: str
+     resp_nme: str
+     weight_nme: str
+     factor_nmes: List[str]
+     task_type: str = 'regression'
+     binary_resp_nme: Optional[str] = None
+     cate_list: Optional[List[str]] = None
+     prop_test: float = 0.25
+     rand_seed: Optional[int] = None
+     epochs: int = 100
+     use_gpu: bool = True
+     use_resn_data_parallel: bool = False
+     use_ft_data_parallel: bool = False
+     use_resn_ddp: bool = False
+     use_ft_ddp: bool = False
+     use_gnn_data_parallel: bool = False
+     use_gnn_ddp: bool = False
+     gnn_use_approx_knn: bool = True
+     gnn_approx_knn_threshold: int = 50000
+     gnn_graph_cache: Optional[str] = None
+     gnn_max_gpu_knn_nodes: Optional[int] = 200000
+     gnn_knn_gpu_mem_ratio: float = 0.9
+     gnn_knn_gpu_mem_overhead: float = 2.0
+     region_province_col: Optional[str] = None  # province-level column for hierarchical smoothing
+     region_city_col: Optional[str] = None  # city-level column for hierarchical smoothing
+     region_effect_alpha: float = 50.0  # hierarchical smoothing strength (pseudo sample size)
+     geo_feature_nmes: Optional[List[str]] = None  # columns used to build geo tokens; empty disables the GNN
+     geo_token_hidden_dim: int = 32
+     geo_token_layers: int = 2
+     geo_token_dropout: float = 0.1
+     geo_token_k_neighbors: int = 10
+     geo_token_learning_rate: float = 1e-3
+     geo_token_epochs: int = 50
+     output_dir: Optional[str] = None
+     optuna_storage: Optional[str] = None
+     optuna_study_prefix: Optional[str] = None
+
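+ # A minimal configuration sketch (all values here are hypothetical):
+ #
+ # >>> cfg = BayesOptConfig(
+ # ...     model_nme='auto_f',
+ # ...     resp_nme='claim_freq',
+ # ...     weight_nme='exposure',
+ # ...     factor_nmes=['age', 'region'],
+ # ...     cate_list=['region'],
+ # ... )
+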
+
+ class OutputManager:
+     # Centralizes output paths for results, plots, and models
+
+     def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
+         self.root = Path(root or os.getcwd())
+         self.model_name = model_name
+         self.plot_dir = self.root / 'plot'
+         self.result_dir = self.root / 'Results'
+         self.model_dir = self.root / 'model'
+
+     def _prepare(self, path: Path) -> str:
+         ensure_parent_dir(str(path))
+         return str(path)
+
+     def plot_path(self, filename: str) -> str:
+         return self._prepare(self.plot_dir / filename)
+
+     def result_path(self, filename: str) -> str:
+         return self._prepare(self.result_dir / filename)
+
+     def model_path(self, filename: str) -> str:
+         return self._prepare(self.model_dir / filename)
+
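+ # Usage sketch (paths are illustrative):
+ # >>> out = OutputManager(root='/tmp/run', model_name='auto_f')
+ # >>> out.plot_path('loss.png')   # '/tmp/run/plot/loss.png', parent dir ensured
+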
+
+ class VersionManager:
+     """Lightweight version tracking: snapshots configs and best parameters for traceability."""
+
+     def __init__(self, output: OutputManager) -> None:
+         self.output = output
+         self.version_dir = Path(self.output.result_dir) / "versions"
+         IOUtils.ensure_parent_dir(str(self.version_dir))
+
+     def save(self, tag: str, payload: Dict[str, Any]) -> str:
+         safe_tag = tag.replace(" ", "_")
+         ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+         path = self.version_dir / f"{ts}_{safe_tag}.json"
+         IOUtils.ensure_parent_dir(str(path))
+         with open(path, "w", encoding="utf-8") as f:
+             json.dump(payload, f, ensure_ascii=False, indent=2, default=str)
+         print(f"[Version] Saved snapshot to {path}")
+         return str(path)
+
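+ # Snapshot sketch (tag and payload keys are illustrative):
+ # >>> vm = VersionManager(out)
+ # >>> vm.save('xgb best', {'params': {'max_depth': 7}, 'seed': 42})
+ # writes Results/versions/<timestamp>_xgb_best.json
+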
+
+ class DatasetPreprocessor:
+     # Prepares the shared train/test data views used by the trainers
+
+     def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
+                  config: BayesOptConfig) -> None:
+         self.config = config
+         self.train_data = train_df.copy(deep=True)
+         self.test_data = test_df.copy(deep=True)
+         self.num_features: List[str] = []
+         self.train_oht_data: Optional[pd.DataFrame] = None
+         self.test_oht_data: Optional[pd.DataFrame] = None
+         self.train_oht_scl_data: Optional[pd.DataFrame] = None
+         self.test_oht_scl_data: Optional[pd.DataFrame] = None
+         self.var_nmes: List[str] = []
+         self.cat_categories_for_shap: Dict[str, List[Any]] = {}
+
+     def run(self) -> "DatasetPreprocessor":
+         """Run preprocessing: categorical encoding, response clipping, and numeric scaling."""
+         cfg = self.config
+         # Precompute weighted actuals; plotting and validation downstream rely
+         # on this field
+         self.train_data.loc[:, 'w_act'] = self.train_data[cfg.resp_nme] * \
+             self.train_data[cfg.weight_nme]
+         self.test_data.loc[:, 'w_act'] = self.test_data[cfg.resp_nme] * \
+             self.test_data[cfg.weight_nme]
+         if cfg.binary_resp_nme:
+             self.train_data.loc[:, 'w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
+                 self.train_data[cfg.weight_nme]
+             self.test_data.loc[:, 'w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
+                 self.test_data[cfg.weight_nme]
+         # Clipping at a high quantile (here 99.9%) absorbs outliers; without
+         # it, extreme points dominate the loss
+         q99 = self.train_data[cfg.resp_nme].quantile(0.999)
+         self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
+             upper=q99)
+         cate_list = list(cfg.cate_list or [])
+         if cate_list:
+             for cate in cate_list:
+                 self.train_data[cate] = self.train_data[cate].astype(
+                     'category')
+                 self.test_data[cate] = self.test_data[cate].astype('category')
+                 cats = self.train_data[cate].cat.categories
+                 self.cat_categories_for_shap[cate] = list(cats)
+         self.num_features = [
+             nme for nme in cfg.factor_nmes if nme not in cate_list]
+         train_oht = self.train_data[cfg.factor_nmes +
+                                     [cfg.weight_nme] + [cfg.resp_nme]].copy()
+         test_oht = self.test_data[cfg.factor_nmes +
+                                   [cfg.weight_nme] + [cfg.resp_nme]].copy()
+         train_oht = pd.get_dummies(
+             train_oht,
+             columns=cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         test_oht = pd.get_dummies(
+             test_oht,
+             columns=cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+
+         # Reindex and zero-fill missing dummy columns so the test set matches
+         # the training columns
+         test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
+
+         # Keep the unscaled one-hot data so cross-validation can standardize
+         # within folds and avoid leakage
+         self.train_oht_data = train_oht.copy(deep=True)
+         self.test_oht_data = test_oht.copy(deep=True)
+
+         train_oht_scaled = train_oht.copy(deep=True)
+         test_oht_scaled = test_oht.copy(deep=True)
+         for num_chr in self.num_features:
+             # Per-column standardization keeps every feature on the same scale;
+             # otherwise the neural networks struggle to converge
+             scaler = StandardScaler()
+             train_oht_scaled[num_chr] = scaler.fit_transform(
+                 train_oht_scaled[num_chr].values.reshape(-1, 1))
+             test_oht_scaled[num_chr] = scaler.transform(
+                 test_oht_scaled[num_chr].values.reshape(-1, 1))
+         # Reindex again so the scaled test columns stay aligned with training
+         test_oht_scaled = test_oht_scaled.reindex(
+             columns=train_oht_scaled.columns, fill_value=0)
+         self.train_oht_scl_data = train_oht_scaled
+         self.test_oht_scl_data = test_oht_scaled
+         self.var_nmes = list(
+             set(list(train_oht_scaled.columns)) -
+             set([cfg.weight_nme, cfg.resp_nme])
+         )
+         return self
+
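+ # Typical flow (sketch; train_df/test_df are hypothetical frames):
+ # >>> prep = DatasetPreprocessor(train_df, test_df, cfg).run()
+ # >>> X = prep.train_oht_scl_data[prep.var_nmes]   # scaled one-hot design matrix
+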
+ # =============================================================================
+ # Trainer hierarchy
+ # =============================================================================
+
+
+ class TrainerBase:
+     def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
+         self.ctx = context
+         self.label = label
+         self.model_name_prefix = model_name_prefix
+         self.model = None
+         self.best_params: Optional[Dict[str, Any]] = None
+         self.best_trial = None
+         self.study_name: Optional[str] = None
+         self.enable_distributed_optuna: bool = False
+         self._distributed_forced_params: Optional[Dict[str, Any]] = None
+
+     @property
+     def config(self) -> BayesOptConfig:
+         return self.ctx.config
+
+     @property
+     def output(self) -> OutputManager:
+         return self.ctx.output_manager
+
+     def _get_model_filename(self) -> str:
+         ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
+         return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'
+
+     def tune(self, max_evals: int, objective_fn=None) -> None:
+         # Generic Optuna tuning loop.
+         if objective_fn is None:
+             # Default to cross_val as the objective when a subclass does not
+             # supply one explicitly
+             objective_fn = self.cross_val
+
+         if self._should_use_distributed_optuna():
+             self._distributed_tune(max_evals, objective_fn)
+             return
+
+         total_trials = max(1, int(max_evals))
+         progress_counter = {"count": 0}
+
+         def objective_wrapper(trial: optuna.trial.Trial) -> float:
+             should_log = DistributedUtils.is_main_process()
+             if should_log:
+                 current_idx = progress_counter["count"] + 1
+                 print(
+                     f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                     f"(trial_id={trial.number})."
+                 )
+             try:
+                 result = objective_fn(trial)
+             except RuntimeError as exc:
+                 if "out of memory" in str(exc).lower():
+                     print(
+                         f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                     )
+                     self._clean_gpu()
+                     raise optuna.TrialPruned() from exc
+                 raise
+             finally:
+                 self._clean_gpu()
+                 if should_log:
+                     progress_counter["count"] = progress_counter["count"] + 1
+                     trial_state = getattr(trial, "state", None)
+                     state_repr = getattr(trial_state, "name", "OK")
+                     print(
+                         f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                         f"(status={state_repr})."
+                     )
+             return result
+
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.ctx.rand_seed)
+         )
+         self.study_name = getattr(study, "study_name", None)
+         study.optimize(objective_wrapper, n_trials=max_evals)
+         self.best_params = study.best_params
+         self.best_trial = study.best_trial
+
+         # Persist the best parameters as CSV for easy reproduction
+         params_path = self.output.result_path(
+             f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+         )
+         pd.DataFrame(self.best_params, index=[0]).to_csv(params_path)
+
+     def train(self) -> None:
+         raise NotImplementedError
+
+     def save(self) -> None:
+         if self.model is None:
+             print(f"[save] Warning: No model to save for {self.label}")
+             return
+
+         path = self.output.model_path(self._get_model_filename())
+         if self.label in ['Xgboost', 'GLM']:
+             joblib.dump(self.model, path)
+         else:
+             # PyTorch models can be saved as a bare state_dict or as the whole
+             # pickled object.
+             # Match historical behavior: ResNetTrainer saves the state_dict,
+             # FTTrainer saves the full object.
+             if hasattr(self.model, 'resnet'):  # ResNetSklearn model
+                 torch.save(self.model.resnet.state_dict(), path)
+             else:  # FTTransformerSklearn or other PyTorch models
+                 torch.save(self.model, path)
+
+     def load(self) -> None:
+         path = self.output.model_path(self._get_model_filename())
+         if not os.path.exists(path):
+             print(f"[load] Warning: Model file not found: {path}")
+             return
+
+         if self.label in ['Xgboost', 'GLM']:
+             self.model = joblib.load(path)
+         else:
+             # Loading PyTorch models depends on how they were serialized
+             if self.label == 'ResNet' or self.label == 'ResNetClassifier':
+                 # ResNet rebuilds its backbone from structural params held on
+                 # ctx, so the subclass handles loading
+                 pass
+             else:
+                 # FT-Transformer was serialized as a whole object; load it and
+                 # move it to the target device
+                 loaded = torch.load(path, map_location='cpu')
+                 self._move_to_device(loaded)
+                 self.model = loaded
+
+     def _move_to_device(self, model_obj):
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         if hasattr(model_obj, 'device'):
+             model_obj.device = device
+         if hasattr(model_obj, 'to'):
+             model_obj.to(device)
+         # If the object holds ft/resnet/gnn submodules, move them as well
+         if hasattr(model_obj, 'ft'):
+             model_obj.ft.to(device)
+         if hasattr(model_obj, 'resnet'):
+             model_obj.resnet.to(device)
+         if hasattr(model_obj, 'gnn'):
+             model_obj.gnn.to(device)
+
+     def _should_use_distributed_optuna(self) -> bool:
+         if not self.enable_distributed_optuna:
+             return False
+         try:
+             world_env = int(os.environ.get("WORLD_SIZE", "1"))
+         except Exception:
+             world_env = 1
+         return world_env > 1
+
+     def _distributed_is_main(self) -> bool:
+         return DistributedUtils.is_main_process()
+
+     def _distributed_send_command(self, payload: Dict[str, Any]) -> None:
+         if not self._should_use_distributed_optuna() or not self._distributed_is_main():
+             return
+         DistributedUtils.setup_ddp()
+         message = [payload]
+         dist.broadcast_object_list(message, src=0)
+
+     def _distributed_prepare_trial(self, params: Dict[str, Any]) -> None:
+         if not self._should_use_distributed_optuna():
+             return
+         if not self._distributed_is_main():
+             return
+         self._distributed_send_command({"type": "RUN", "params": params})
+         dist.barrier()
+
+     def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
+         DistributedUtils.setup_ddp()
+         while True:
+             message = [None]
+             dist.broadcast_object_list(message, src=0)
+             payload = message[0]
+             if not isinstance(payload, dict):
+                 continue
+             cmd = payload.get("type")
+             if cmd == "STOP":
+                 best_params = payload.get("best_params")
+                 if best_params is not None:
+                     self.best_params = best_params
+                 break
+             if cmd == "RUN":
+                 params = payload.get("params") or {}
+                 self._distributed_forced_params = params
+                 dist.barrier()
+                 try:
+                     objective_fn(None)
+                 except optuna.TrialPruned:
+                     pass
+                 except Exception as exc:
+                     print(
+                         f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
+                 finally:
+                     self._clean_gpu()
+                     dist.barrier()
+
+     def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
+         DistributedUtils.setup_ddp()
+         if not self._distributed_is_main():
+             self._distributed_worker_loop(objective_fn)
+             return
+
+         total_trials = max(1, int(max_evals))
+         progress_counter = {"count": 0}
+
+         def objective_wrapper(trial: optuna.trial.Trial) -> float:
+             should_log = True
+             if should_log:
+                 current_idx = progress_counter["count"] + 1
+                 print(
+                     f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
+                     f"(trial_id={trial.number})."
+                 )
+             try:
+                 result = objective_fn(trial)
+             except RuntimeError as exc:
+                 if "out of memory" in str(exc).lower():
+                     print(
+                         f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                     )
+                     self._clean_gpu()
+                     raise optuna.TrialPruned() from exc
+                 raise
+             finally:
+                 self._clean_gpu()
+                 if should_log:
+                     progress_counter["count"] = progress_counter["count"] + 1
+                     trial_state = getattr(trial, "state", None)
+                     state_repr = getattr(trial_state, "name", "OK")
+                     print(
+                         f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
+                         f"(status={state_repr})."
+                     )
+                 dist.barrier()
+             return result
+
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.ctx.rand_seed)
+         )
+         self.study_name = getattr(study, "study_name", None)
+         try:
+             study.optimize(objective_wrapper, n_trials=max_evals)
+             self.best_params = study.best_params
+             self.best_trial = study.best_trial
+             params_path = self.output.result_path(
+                 f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
+             )
+             pd.DataFrame(self.best_params, index=[0]).to_csv(params_path)
+         finally:
+             self._distributed_send_command(
+                 {"type": "STOP", "best_params": self.best_params})
+
+     def _clean_gpu(self):
+         gc.collect()
+         if torch.cuda.is_available():
+             device = None
+             try:
+                 device = getattr(self, "device", None)
+             except Exception:
+                 device = None
+             if isinstance(device, torch.device):
+                 try:
+                     torch.cuda.set_device(device)
+                 except Exception:
+                     pass
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+             torch.cuda.synchronize()
+
+     def _standardize_fold(self,
+                           X_train: pd.DataFrame,
+                           X_val: pd.DataFrame,
+                           columns: Optional[List[str]] = None
+                           ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
+         """Fit a StandardScaler on the training fold and transform both splits.
+
+         Args:
+             X_train: training-fold features.
+             X_val: validation-fold features.
+             columns: columns to scale; defaults to all columns.
+
+         Returns:
+             The scaled train/val features and the fitted scaler.
+         """
+         scaler = StandardScaler()
+         cols = list(columns) if columns else list(X_train.columns)
+         X_train_scaled = X_train.copy(deep=True)
+         X_val_scaled = X_val.copy(deep=True)
+         if cols:
+             scaler.fit(X_train_scaled[cols])
+             X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
+             X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
+         return X_train_scaled, X_val_scaled, scaler
+
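+     # Fold-local scaling sketch (frames are hypothetical): fitting on the
+     # training fold alone keeps validation statistics out of the scaler.
+     # >>> Xtr_s, Xval_s, sc = self._standardize_fold(Xtr, Xval, columns=['age'])
+     # >>> sc.mean_   # computed from Xtr['age'] only
+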
+     def cross_val_generic(
+             self,
+             trial: optuna.trial.Trial,
+             hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
+             data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
+             model_builder: Callable[[Dict[str, Any]], Any],
+             metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
+             sample_limit: Optional[int] = None,
+             preprocess_fn: Optional[Callable[[
+                 pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
+             fit_predict_fn: Optional[
+                 Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
+                           pd.DataFrame, pd.Series, Optional[pd.Series],
+                           optuna.trial.Trial], np.ndarray]
+             ] = None,
+             cleanup_fn: Optional[Callable[[Any], None]] = None,
+             splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
+         """Generic holdout/cross-validation helper shared by the tuners.
+
+         Args:
+             trial: current Optuna trial.
+             hyperparameter_space: dict of parameter samplers keyed by name.
+             data_provider: callback returning (X, y, sample_weight).
+             model_builder: callback that builds a fresh model per fold.
+             metric_fn: loss/score function taking y_true, y_pred, weight.
+             sample_limit: optional row cap; rows are subsampled beyond it.
+             preprocess_fn: optional per-fold preprocessing on (X_train, X_val).
+             fit_predict_fn: optional custom fit/predict logic returning
+                 validation predictions.
+             cleanup_fn: optional cleanup invoked after each fold.
+             splitter: optional iterable of (train_idx, val_idx); defaults to a
+                 single ShuffleSplit.
+
+         Returns:
+             Mean validation metric across folds.
+         """
+         params: Optional[Dict[str, Any]] = None
+         if self._distributed_forced_params is not None:
+             params = self._distributed_forced_params
+             self._distributed_forced_params = None
+         else:
+             if trial is None:
+                 raise RuntimeError(
+                     "Missing Optuna trial for parameter sampling.")
+             params = {name: sampler(trial)
+                       for name, sampler in hyperparameter_space.items()}
+             if self._should_use_distributed_optuna():
+                 self._distributed_prepare_trial(params)
+         X_all, y_all, w_all = data_provider()
+         if sample_limit is not None and len(X_all) > sample_limit:
+             sampled_idx = X_all.sample(
+                 n=sample_limit,
+                 random_state=self.ctx.rand_seed
+             ).index
+             X_all = X_all.loc[sampled_idx]
+             y_all = y_all.loc[sampled_idx]
+             w_all = w_all.loc[sampled_idx] if w_all is not None else None
+
+         splits = splitter or [next(ShuffleSplit(
+             n_splits=int(1 / self.ctx.prop_test),
+             test_size=self.ctx.prop_test,
+             random_state=self.ctx.rand_seed
+         ).split(X_all))]
+
+         losses: List[float] = []
+         for train_idx, val_idx in splits:
+             X_train = X_all.iloc[train_idx]
+             y_train = y_all.iloc[train_idx]
+             X_val = X_all.iloc[val_idx]
+             y_val = y_all.iloc[val_idx]
+             w_train = w_all.iloc[train_idx] if w_all is not None else None
+             w_val = w_all.iloc[val_idx] if w_all is not None else None
+
+             if preprocess_fn:
+                 X_train, X_val = preprocess_fn(X_train, X_val)
+
+             model = model_builder(params)
+             try:
+                 if fit_predict_fn:
+                     y_pred = fit_predict_fn(
+                         model, X_train, y_train, w_train,
+                         X_val, y_val, w_val, trial
+                     )
+                 else:
+                     fit_kwargs = {}
+                     if w_train is not None:
+                         fit_kwargs["sample_weight"] = w_train
+                     model.fit(X_train, y_train, **fit_kwargs)
+                     y_pred = model.predict(X_val)
+                 losses.append(metric_fn(y_val, y_pred, w_val))
+             finally:
+                 if cleanup_fn:
+                     cleanup_fn(model)
+                 self._clean_gpu()
+
+         return float(np.mean(losses))
+
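+     # Sketch of how a trainer can wire this helper (the model and metric here
+     # are illustrative sklearn pieces, not from this package):
+     #
+     # >>> space = {'alpha': lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True)}
+     # >>> loss = self.cross_val_generic(
+     # ...     trial=trial,
+     # ...     hyperparameter_space=space,
+     # ...     data_provider=lambda: (X, y, w),
+     # ...     model_builder=lambda p: Ridge(alpha=p['alpha']),
+     # ...     metric_fn=lambda yt, yp, wt: mean_squared_error(yt, yp, sample_weight=wt),
+     # ... )
+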
+     # Prediction + caching logic
+     def _predict_and_cache(self,
+                            model,
+                            pred_prefix: str,
+                            use_oht: bool = False,
+                            design_fn=None,
+                            predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                            predict_kwargs_test: Optional[Dict[str, Any]] = None) -> None:
+         if design_fn:
+             X_train = design_fn(train=True)
+             X_test = design_fn(train=False)
+         elif use_oht:
+             X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
+             X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
+         else:
+             X_train = self.ctx.train_data[self.ctx.factor_nmes]
+             X_test = self.ctx.test_data[self.ctx.factor_nmes]
+
+         preds_train = model.predict(X_train, **(predict_kwargs_train or {}))
+         preds_test = model.predict(X_test, **(predict_kwargs_test or {}))
+
+         self.ctx.train_data[f'pred_{pred_prefix}'] = preds_train
+         self.ctx.test_data[f'pred_{pred_prefix}'] = preds_test
+         self.ctx.train_data[f'w_pred_{pred_prefix}'] = (
+             self.ctx.train_data[f'pred_{pred_prefix}'] *
+             self.ctx.train_data[self.ctx.weight_nme]
+         )
+         self.ctx.test_data[f'w_pred_{pred_prefix}'] = (
+             self.ctx.test_data[f'pred_{pred_prefix}'] *
+             self.ctx.test_data[self.ctx.weight_nme]
+         )
+
+     def _fit_predict_cache(self,
+                            model,
+                            X_train,
+                            y_train,
+                            sample_weight,
+                            pred_prefix: str,
+                            use_oht: bool = False,
+                            design_fn=None,
+                            fit_kwargs: Optional[Dict[str, Any]] = None,
+                            sample_weight_arg: Optional[str] = 'sample_weight',
+                            predict_kwargs_train: Optional[Dict[str, Any]] = None,
+                            predict_kwargs_test: Optional[Dict[str, Any]] = None) -> None:
+         fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
+         if sample_weight is not None and sample_weight_arg:
+             fit_kwargs.setdefault(sample_weight_arg, sample_weight)
+         model.fit(X_train, y_train, **fit_kwargs)
+         self.ctx.model_label.append(self.label)
+         self._predict_and_cache(
+             model,
+             pred_prefix,
+             use_oht=use_oht,
+             design_fn=design_fn,
+             predict_kwargs_train=predict_kwargs_train,
+             predict_kwargs_test=predict_kwargs_test)
+
+
+ class XGBTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'Xgboost', 'Xgboost')
+         self.model: Optional[xgb.XGBRegressor] = None
+
+     def _build_estimator(self) -> xgb.XGBRegressor:
+         params = dict(
+             objective=self.ctx.obj,
+             random_state=self.ctx.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if self.ctx.use_gpu else 'hist',
+             enable_categorical=True,
+             predictor='gpu_predictor' if self.ctx.use_gpu else 'cpu_predictor'
+         )
+         if self.ctx.use_gpu:
+             params['gpu_id'] = 0
+             print(">>> XGBoost using GPU ID: 0 (Single GPU Mode)")
+         return xgb.XGBRegressor(**params)
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-5, 1e-1, log=True)
+         gamma = trial.suggest_float('gamma', 0, 10000)
+         max_depth = trial.suggest_int('max_depth', 3, 25)
+         n_estimators = trial.suggest_int('n_estimators', 10, 500, step=10)
+         min_child_weight = trial.suggest_int(
+             'min_child_weight', 100, 10000, step=100)
+         reg_alpha = trial.suggest_float('reg_alpha', 1e-10, 1, log=True)
+         reg_lambda = trial.suggest_float('reg_lambda', 1e-10, 1, log=True)
+         if self.ctx.obj == 'reg:tweedie':
+             tweedie_variance_power = trial.suggest_float(
+                 'tweedie_variance_power', 1, 2)
+         elif self.ctx.obj == 'count:poisson':
+             tweedie_variance_power = 1
+         elif self.ctx.obj == 'reg:gamma':
+             tweedie_variance_power = 2
+         else:
+             tweedie_variance_power = 1.5
+         clf = self._build_estimator()
+         params = {
+             'learning_rate': learning_rate,
+             'gamma': gamma,
+             'max_depth': max_depth,
+             'n_estimators': n_estimators,
+             'min_child_weight': min_child_weight,
+             'reg_alpha': reg_alpha,
+             'reg_lambda': reg_lambda
+         }
+         if self.ctx.obj == 'reg:tweedie':
+             params['tweedie_variance_power'] = tweedie_variance_power
+         clf.set_params(**params)
+         n_jobs = 1 if self.ctx.use_gpu else int(1 / self.ctx.prop_test)
+         acc = cross_val_score(
+             clf,
+             self.ctx.train_data[self.ctx.factor_nmes],
+             self.ctx.train_data[self.ctx.resp_nme].values,
+             fit_params=self.ctx.fit_params,
+             cv=self.ctx.cv,
+             scoring=make_scorer(
+                 mean_tweedie_deviance,
+                 power=tweedie_variance_power,
+                 greater_is_better=False),
+             error_score='raise',
+             n_jobs=n_jobs
+         ).mean()
+         return -acc
+
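+     # make_scorer(..., greater_is_better=False) flips the sign of the
+     # deviance, so cross_val_score returns negative values; negating the mean
+     # restores a minimize-friendly loss for Optuna.
+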
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best XGB parameters.')
+         self.model = self._build_estimator()
+         self.model.set_params(**self.best_params)
+         self._fit_predict_cache(
+             self.model,
+             self.ctx.train_data[self.ctx.factor_nmes],
+             self.ctx.train_data[self.ctx.resp_nme].values,
+             sample_weight=None,
+             pred_prefix='xgb',
+             fit_kwargs=self.ctx.fit_params,
+             sample_weight_arg=None  # sample weights already flow in via fit_kwargs
+         )
+         self.ctx.xgb_best = self.model
+
+
+ class GLMTrainer(TrainerBase):
+     def __init__(self, context: "BayesOptModel") -> None:
+         super().__init__(context, 'GLM', 'GLM')
+         self.model = None
+
+     def _select_family(self, tweedie_power: Optional[float] = None):
+         if self.ctx.task_type == 'classification':
+             return sm.families.Binomial()
+         if self.ctx.obj == 'count:poisson':
+             return sm.families.Poisson()
+         if self.ctx.obj == 'reg:gamma':
+             return sm.families.Gamma()
+         power = tweedie_power if tweedie_power is not None else 1.5
+         return sm.families.Tweedie(var_power=power, link=sm.families.links.log())
+
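+     # Objective-to-family mapping above: classification -> Binomial,
+     # count:poisson -> Poisson, reg:gamma -> Gamma, anything else -> Tweedie
+     # with a log link and var_power defaulting to 1.5.
+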
+     def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
+         # Add an intercept column to the statsmodels design matrix
+         X = data[self.ctx.var_nmes]
+         return sm.add_constant(X, has_constant='add')
+
+     def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
+         if isinstance(family, sm.families.Poisson):
+             return 1.0
+         if isinstance(family, sm.families.Gamma):
+             return 2.0
+         if isinstance(family, sm.families.Tweedie):
+             return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
+         return 1.5
+
+     def cross_val(self, trial: optuna.trial.Trial) -> float:
+         param_space = {
+             "alpha": lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True),
+             "l1_ratio": lambda t: t.suggest_float('l1_ratio', 0.0, 1.0)
+         }
+         if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+             param_space["tweedie_power"] = lambda t: t.suggest_float(
+                 'tweedie_power', 1.0, 2.0)
+
+         def data_provider():
+             data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+             assert data is not None, "Preprocessed training data is missing."
+             return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+         def preprocess_fn(X_train, X_val):
+             X_train_s, X_val_s, _ = self._standardize_fold(
+                 X_train, X_val, self.ctx.num_features)
+             return self._prepare_design(X_train_s), self._prepare_design(X_val_s)
+
+         metric_ctx: Dict[str, Any] = {}
+
+         def model_builder(params):
+             family = self._select_family(params.get("tweedie_power"))
+             metric_ctx["family"] = family
+             metric_ctx["tweedie_power"] = params.get("tweedie_power")
+             return {
+                 "family": family,
+                 "alpha": params["alpha"],
+                 "l1_ratio": params["l1_ratio"],
+                 "tweedie_power": params.get("tweedie_power")
+             }
+
+         def fit_predict(model_cfg, X_train, y_train, w_train, X_val, y_val, w_val, _trial):
+             glm = sm.GLM(y_train, X_train,
+                          family=model_cfg["family"],
+                          freq_weights=w_train)
+             result = glm.fit_regularized(
+                 alpha=model_cfg["alpha"],
+                 L1_wt=model_cfg["l1_ratio"],
+                 maxiter=200
+             )
+             return result.predict(X_val)
+
+         def metric_fn(y_true, y_pred, weight):
+             if self.ctx.task_type == 'classification':
+                 y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
+                 return log_loss(y_true, y_pred_clipped, sample_weight=weight)
+             y_pred_safe = np.maximum(y_pred, EPS)
+             return mean_tweedie_deviance(
+                 y_true,
+                 y_pred_safe,
+                 sample_weight=weight,
+                 power=self._metric_power(
+                     metric_ctx.get("family"), metric_ctx.get("tweedie_power"))
+             )
+
+         return self.cross_val_generic(
+             trial=trial,
+             hyperparameter_space=param_space,
+             data_provider=data_provider,
+             model_builder=model_builder,
+             metric_fn=metric_fn,
+             preprocess_fn=preprocess_fn,
+             fit_predict_fn=fit_predict,
+             splitter=self.ctx.cv.split(self.ctx.train_oht_data[self.ctx.var_nmes]
+                                        if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data[self.ctx.var_nmes])
+         )
+
+     def train(self) -> None:
+         if not self.best_params:
+             raise RuntimeError('Run tune() first to obtain the best GLM parameters.')
+         tweedie_power = self.best_params.get('tweedie_power')
+         family = self._select_family(tweedie_power)
+
+         X_train = self._prepare_design(self.ctx.train_oht_scl_data)
+         y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
+         w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]
+
+         glm = sm.GLM(y_train, X_train, family=family,
+                      freq_weights=w_train)
+         self.model = glm.fit_regularized(
+             alpha=self.best_params['alpha'],
+             L1_wt=self.best_params['l1_ratio'],
+             maxiter=300
+         )
+
+         self.ctx.glm_best = self.model
+         self.ctx.model_label += [self.label]
+         self._predict_and_cache(
+             self.model,
+             'glm',
+             design_fn=lambda train: self._prepare_design(
+                 self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
+             )
+         )
+
+
+class ResNetTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'ResNetClassifier', 'ResNet')
+        else:
+            super().__init__(context, 'ResNet', 'ResNet')
+        self.model: Optional[ResNetSklearn] = None
+        self.enable_distributed_optuna = bool(context.config.use_resn_ddp)
+
+    # ========= Cross-validation (used by BayesOpt) =========
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        # Cross-validation flow for ResNet, focused on keeping GPU memory in check:
+        # - create a fresh ResNetSklearn per fold and release it as soon as the fold ends;
+        # - after each fold, move the model to CPU, delete it and call gc/empty_cache;
+        # - optionally subsample the training set during BayesOpt to reduce memory pressure.
+
+        base_tw_power = None
+        if self.ctx.task_type == 'regression':
+            if self.ctx.obj == 'count:poisson':
+                base_tw_power = 1.0
+            elif self.ctx.obj == 'reg:gamma':
+                base_tw_power = 2.0
+            else:
+                base_tw_power = 1.5
+
+        def data_provider():
+            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+            assert data is not None, "Preprocessed training data is missing."
+            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        metric_ctx: Dict[str, Any] = {}
+
+        def model_builder(params):
+            power = params.get("tw_power", base_tw_power)
+            metric_ctx["tw_power"] = power
+            return ResNetSklearn(
+                model_nme=self.ctx.model_nme,
+                input_dim=len(self.ctx.var_nmes),
+                hidden_dim=params["hidden_dim"],
+                block_num=params["block_num"],
+                task_type=self.ctx.task_type,
+                epochs=self.ctx.epochs,
+                tweedie_power=power,
+                learning_rate=params["learning_rate"],
+                patience=5,
+                use_layernorm=True,
+                dropout=0.1,
+                residual_scale=0.1,
+                use_data_parallel=self.ctx.config.use_resn_data_parallel,
+                use_ddp=self.ctx.config.use_resn_ddp
+            )
+
+        def preprocess_fn(X_train, X_val):
+            X_train_s, X_val_s, _ = self._standardize_fold(
+                X_train, X_val, self.ctx.num_features)
+            return X_train_s, X_val_s
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            model.fit(
+                X_train, y_train, w_train,
+                X_val, y_val, w_val,
+                trial=trial_obj
+            )
+            return model.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'regression':
+                return mean_tweedie_deviance(
+                    y_true,
+                    y_pred,
+                    sample_weight=weight,
+                    power=metric_ctx.get("tw_power", base_tw_power)
+                )
+            return log_loss(y_true, y_pred, sample_weight=weight)
+
+        sample_cap = data_provider()[0]
+        max_rows_for_resnet_bo = min(100000, int(len(sample_cap) / 5))
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space={
+                "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
+                "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
+                "block_num": lambda t: t.suggest_int('block_num', 2, 10),
+                **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)}
+                   if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie' else {})
+            },
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=max_rows_for_resnet_bo if len(
+                sample_cap) > max_rows_for_resnet_bo > 0 else None,
+            preprocess_fn=preprocess_fn,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
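+    # A hedged sketch of the per-fold cleanup pattern described above
+    # (illustrative only; `model` is a hypothetical fold-local instance):
+    #
+    #   import gc, torch
+    #   model.resnet.to("cpu")        # move the weights off the GPU
+    #   del model                     # drop the last reference
+    #   gc.collect()                  # reclaim Python-side memory
+    #   if torch.cuda.is_available():
+    #       torch.cuda.empty_cache()  # hand cached blocks back to the driver
+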
+    # ========= Train the final ResNet with the best hyperparameters =========
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError('Run tune() first to obtain the best ResNet parameters.')
+
+        self.model = ResNetSklearn(
+            model_nme=self.ctx.model_nme,
+            input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+            task_type=self.ctx.task_type,
+            use_data_parallel=self.ctx.config.use_resn_data_parallel,
+            use_ddp=self.ctx.config.use_resn_ddp
+        )
+        self.model.set_params(self.best_params)
+        loss_plot_path = self.output.plot_path(
+            f'loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
+        self.model.loss_curve_path = loss_plot_path
+
+        self._fit_predict_cache(
+            self.model,
+            self.ctx.train_oht_scl_data[self.ctx.var_nmes],
+            self.ctx.train_oht_scl_data[self.ctx.resp_nme],
+            sample_weight=self.ctx.train_oht_scl_data[self.ctx.weight_nme],
+            pred_prefix='resn',
+            use_oht=True,
+            sample_weight_arg='w_train'
+        )
+
+        # Expose for external callers
+        self.ctx.resn_best = self.model
+
+    # ========= Save / load =========
+    # ResNet is saved as a state_dict and needs a dedicated loading flow,
+    # so a custom load method is kept here.
+    # Saving is already implemented in TrainerBase (it checks the .resnet
+    # attribute automatically).
+
+    def load(self) -> None:
+        # Load the ResNet weights from disk onto the current device, staying
+        # consistent with the context.
+        path = self.output.model_path(self._get_model_filename())
+        if os.path.exists(path):
+            resn_loaded = ResNetSklearn(
+                model_nme=self.ctx.model_nme,
+                input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+                task_type=self.ctx.task_type,
+                use_data_parallel=self.ctx.config.use_resn_data_parallel,
+                use_ddp=self.ctx.config.use_resn_ddp
+            )
+            state_dict = torch.load(path, map_location='cpu')
+            resn_loaded.resnet.load_state_dict(state_dict)
+
+            self._move_to_device(resn_loaded)
+            self.model = resn_loaded
+            self.ctx.resn_best = self.model
+        else:
+            print(f"[ResNetTrainer.load] Model file not found: {path}")
+
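+    # Save/load round trip for the ResNet weights, as a hedged sketch
+    # (illustrative only; `path` and `model` are hypothetical placeholders):
+    #
+    #   torch.save(model.resnet.state_dict(), path)    # persist weights only
+    #   state = torch.load(path, map_location="cpu")   # always load via CPU
+    #   model.resnet.load_state_dict(state)            # restore parameters
+    #   model.resnet.to("cuda" if torch.cuda.is_available() else "cpu")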
+
+class FTTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
+        else:
+            super().__init__(context, 'FTTransformer', 'FTTransformer')
+        self.model: Optional[FTTransformerSklearn] = None
+        self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
+
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        # Cross-validation for the FT-Transformer, again focused on GPU memory:
+        # - shrink the hyperparameter search space to avoid needlessly huge models;
+        # - free GPU memory right after each fold so the next trial runs smoothly.
+        # The search space is kept deliberately modest to avoid very large models.
+        param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+            "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
+            # "d_model": lambda t: t.suggest_int('d_model', 8, 64, step=8),
+            "d_model": lambda t: t.suggest_int('d_model', 16, 128, step=16),
+            # "n_heads": lambda t: t.suggest_categorical('n_heads', [2, 4, 8]),
+            "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
+            "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.2)
+        }
+        if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+            param_space["tw_power"] = lambda t: t.suggest_float(
+                'tw_power', 1.0, 2.0)
+        geo_enabled = bool(
+            self.ctx.geo_token_cols or self.ctx.config.geo_feature_nmes)
+        if geo_enabled:
+            # Tune the GNN-related hyperparameters only when geo tokens are enabled
+            param_space.update({
+                "geo_token_hidden_dim": lambda t: t.suggest_int('geo_token_hidden_dim', 16, 128, step=16),
+                "geo_token_layers": lambda t: t.suggest_int('geo_token_layers', 1, 4),
+                "geo_token_k_neighbors": lambda t: t.suggest_int('geo_token_k_neighbors', 5, 20),
+                "geo_token_dropout": lambda t: t.suggest_float('geo_token_dropout', 0.0, 0.3),
+                "geo_token_learning_rate": lambda t: t.suggest_float('geo_token_learning_rate', 1e-4, 5e-3, log=True),
+            })
+
+        metric_ctx: Dict[str, Any] = {}
+        geo_tokens_cache = self.ctx.train_geo_tokens
+
+        def data_provider():
+            data = self.ctx.train_data
+            return data[self.ctx.factor_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        def model_builder(params):
+            d_model = params["d_model"]
+            n_layers = params["n_layers"]
+            token_count = len(self.ctx.factor_nmes) + \
+                (1 if self.ctx.geo_token_cols else 0)
+            approx_units = d_model * n_layers * max(1, token_count)
+            if approx_units > 12_000_000:
+                print(
+                    f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
+                raise optuna.TrialPruned(
+                    "config exceeds safe memory budget; prune before training")
+            geo_params_local = {k: v for k, v in params.items()
+                                if k.startswith("geo_token_")}
+
+            tw_power = params.get("tw_power")
+            if self.ctx.task_type == 'regression':
+                if self.ctx.obj == 'count:poisson':
+                    tw_power = 1.0
+                elif self.ctx.obj == 'reg:gamma':
+                    tw_power = 2.0
+                elif tw_power is None:
+                    tw_power = 1.5
+            metric_ctx["tw_power"] = tw_power
+
+            return FTTransformerSklearn(
+                model_nme=self.ctx.model_nme,
+                num_cols=self.ctx.num_features,
+                cat_cols=self.ctx.cate_list,
+                d_model=d_model,
+                # n_heads=params["n_heads"],
+                n_heads=max(2, d_model // 16),
+                n_layers=n_layers,
+                dropout=params["dropout"],
+                task_type=self.ctx.task_type,
+                epochs=self.ctx.epochs,
+                tweedie_power=tw_power,
+                learning_rate=params["learning_rate"],
+                patience=5,
+                use_data_parallel=self.ctx.config.use_ft_data_parallel,
+                use_ddp=self.ctx.config.use_ft_ddp
+            ).set_params({"_geo_params": geo_params_local} if geo_enabled else {})
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            geo_train = geo_val = None
+            if geo_enabled:
+                geo_params = getattr(model, "_geo_params", {})
+                built = self.ctx._build_geo_tokens(geo_params)
+                if built is not None:
+                    geo_train_full, _, _, _ = built
+                    geo_train = geo_train_full.loc[X_train.index]
+                    geo_val = geo_train_full.loc[X_val.index]
+                elif geo_tokens_cache is not None:
+                    # fall back to the cached tokens if available
+                    geo_train = geo_tokens_cache.loc[X_train.index]
+                    geo_val = geo_tokens_cache.loc[X_val.index]
+            model.fit(
+                X_train, y_train, w_train,
+                X_val, y_val, w_val,
+                trial=trial_obj,
+                geo_train=geo_train,
+                geo_val=geo_val
+            )
+            return model.predict(X_val, geo_tokens=geo_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'regression':
+                return mean_tweedie_deviance(
+                    y_true,
+                    y_pred,
+                    sample_weight=weight,
+                    power=metric_ctx.get("tw_power", 1.5)
+                )
+            return log_loss(y_true, y_pred, sample_weight=weight)
+
+        data_for_cap = data_provider()[0]
+        max_rows_for_ft_bo = min(1000000, int(len(data_for_cap) / 2))
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space=param_space,
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=max_rows_for_ft_bo if len(
+                data_for_cap) > max_rows_for_ft_bo > 0 else None,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "ft", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
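+    # The memory-budget guard in `model_builder` above uses the standard
+    # Optuna idiom of pruning a trial before any training cost is paid.
+    # A minimal, hedged sketch (the names here are hypothetical):
+    #
+    #   def objective(trial):
+    #       d_model = trial.suggest_int("d_model", 16, 128, step=16)
+    #       if d_model * n_layers * tokens > budget:   # cheap feasibility check
+    #           raise optuna.TrialPruned("over memory budget")
+    #       return train_and_score(d_model)            # only feasible configs pay
+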
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError('Run tune() first to obtain the best FT-Transformer parameters.')
+        self.model = FTTransformerSklearn(
+            model_nme=self.ctx.model_nme,
+            num_cols=self.ctx.num_features,
+            cat_cols=self.ctx.cate_list,
+            task_type=self.ctx.task_type,
+            use_data_parallel=self.ctx.config.use_ft_data_parallel,
+            use_ddp=self.ctx.config.use_ft_ddp
+        )
+        self.model.set_params(self.best_params)
+        loss_plot_path = self.output.plot_path(
+            f'loss_{self.ctx.model_nme}_{self.model_name_prefix}.png')
+        self.model.loss_curve_path = loss_plot_path
+        geo_train = self.ctx.train_geo_tokens
+        geo_test = self.ctx.test_geo_tokens
+        fit_kwargs = {}
+        predict_kwargs_train = None
+        predict_kwargs_test = None
+        if geo_train is not None and geo_test is not None:
+            fit_kwargs["geo_train"] = geo_train
+            predict_kwargs_train = {"geo_tokens": geo_train}
+            predict_kwargs_test = {"geo_tokens": geo_test}
+        self._fit_predict_cache(
+            self.model,
+            self.ctx.train_data[self.ctx.factor_nmes],
+            self.ctx.train_data[self.ctx.resp_nme],
+            sample_weight=self.ctx.train_data[self.ctx.weight_nme],
+            pred_prefix='ft',
+            sample_weight_arg='w_train',
+            fit_kwargs=fit_kwargs,
+            predict_kwargs_train=predict_kwargs_train,
+            predict_kwargs_test=predict_kwargs_test
+        )
+        self.ctx.ft_best = self.model
+
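+    # Geo tokens ride along as keyword arguments, as a hedged usage sketch
+    # (illustrative only; `bo` is a hypothetical BayesOptModel with tokens
+    # already prepared):
+    #
+    #   ft = bo.ft_best
+    #   ft.fit(X_tr, y_tr, w_tr, X_val, y_val, w_val,
+    #          geo_train=bo.train_geo_tokens.loc[X_tr.index],
+    #          geo_val=bo.train_geo_tokens.loc[X_val.index])
+    #   preds = ft.predict(X_te, geo_tokens=bo.test_geo_tokens)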
+
+# =============================================================================
+# BayesOpt orchestration and SHAP utilities
+# =============================================================================
+class BayesOptModel:
+    def __init__(self, train_data, test_data,
+                 model_nme, resp_nme, weight_nme, factor_nmes, task_type='regression',
+                 binary_resp_nme=None,
+                 cate_list=None, prop_test=0.25, rand_seed=None,
+                 epochs=100, use_gpu=True,
+                 use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
+                 use_gnn_data_parallel: bool = False,
+                 use_resn_ddp: bool = False, use_ft_ddp: bool = False,
+                 use_gnn_ddp: bool = False,
+                 output_dir: Optional[str] = None,
+                 gnn_use_approx_knn: bool = True,
+                 gnn_approx_knn_threshold: int = 50000,
+                 gnn_graph_cache: Optional[str] = None,
+                 gnn_max_gpu_knn_nodes: Optional[int] = 200000,
+                 gnn_knn_gpu_mem_ratio: float = 0.9,
+                 gnn_knn_gpu_mem_overhead: float = 2.0):
+        """BayesOpt orchestration entry point wrapping the individual trainers.
+
+        Args:
+            train_data: Training DataFrame.
+            test_data: Test DataFrame.
+            model_nme: Model name prefix used for output files.
+            resp_nme: Target column name.
+            weight_nme: Sample-weight column name.
+            factor_nmes: List of feature column names.
+            task_type: 'regression' or 'classification'.
+            binary_resp_nme: Optional binary target used for conversion-rate curves.
+            cate_list: List of categorical features.
+            prop_test: Validation share within cross-validation.
+            rand_seed: Random seed.
+            epochs: Number of training epochs for the neural networks.
+            use_gpu: Whether to prefer the GPU.
+            use_resn_data_parallel: Whether ResNet uses DataParallel.
+            use_ft_data_parallel: Whether FTTransformer uses DataParallel.
+            use_gnn_data_parallel: Whether the GNN uses DataParallel.
+            use_resn_ddp: Whether ResNet uses DDP.
+            use_ft_ddp: Whether FTTransformer uses DDP.
+            use_gnn_ddp: Whether the GNN uses DDP.
+            output_dir: Root output directory for models, results and plots.
+            gnn_use_approx_knn: Whether to use approximate kNN when available.
+            gnn_approx_knn_threshold: Row-count threshold that triggers approximate kNN.
+            gnn_graph_cache: Optional cache path for the adjacency matrix.
+            gnn_max_gpu_knn_nodes: Above this node count, force CPU kNN to avoid GPU OOM.
+            gnn_knn_gpu_mem_ratio: Fraction of available GPU memory the GPU kNN may use.
+            gnn_knn_gpu_mem_overhead: Multiplier on the GPU kNN's estimated scratch memory.
+        """
+        cfg = BayesOptConfig(
+            model_nme=model_nme,
+            task_type=task_type,
+            resp_nme=resp_nme,
+            weight_nme=weight_nme,
+            factor_nmes=list(factor_nmes),
+            binary_resp_nme=binary_resp_nme,
+            cate_list=list(cate_list) if cate_list else None,
+            prop_test=prop_test,
+            rand_seed=rand_seed,
+            epochs=epochs,
+            use_gpu=use_gpu,
+            use_resn_data_parallel=use_resn_data_parallel,
+            use_ft_data_parallel=use_ft_data_parallel,
+            use_resn_ddp=use_resn_ddp,
+            use_gnn_data_parallel=use_gnn_data_parallel,
+            use_ft_ddp=use_ft_ddp,
+            use_gnn_ddp=use_gnn_ddp,
+            gnn_use_approx_knn=gnn_use_approx_knn,
+            gnn_approx_knn_threshold=gnn_approx_knn_threshold,
+            gnn_graph_cache=gnn_graph_cache,
+            gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
+            gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
+            gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
+            output_dir=output_dir
+        )
+        self.config = cfg
+        self.model_nme = cfg.model_nme
+        self.task_type = cfg.task_type
+        self.resp_nme = cfg.resp_nme
+        self.weight_nme = cfg.weight_nme
+        self.factor_nmes = cfg.factor_nmes
+        self.binary_resp_nme = cfg.binary_resp_nme
+        self.cate_list = list(cfg.cate_list or [])
+        self.prop_test = cfg.prop_test
+        self.epochs = cfg.epochs
+        self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
+            1, 10000)
+        self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
+        self.output_manager = OutputManager(
+            cfg.output_dir or os.getcwd(), self.model_nme)
+
+        preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
+        self.train_data = preprocessor.train_data
+        self.test_data = preprocessor.test_data
+        self.train_oht_data = preprocessor.train_oht_data
+        self.test_oht_data = preprocessor.test_oht_data
+        self.train_oht_scl_data = preprocessor.train_oht_scl_data
+        self.test_oht_scl_data = preprocessor.test_oht_scl_data
+        self.var_nmes = preprocessor.var_nmes
+        self.num_features = preprocessor.num_features
+        self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
+        self.geo_token_cols: List[str] = []
+        self.train_geo_tokens: Optional[pd.DataFrame] = None
+        self.test_geo_tokens: Optional[pd.DataFrame] = None
+        self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
+        self._add_region_effect()
+        self._prepare_geo_tokens()
+
+        self.cv = ShuffleSplit(n_splits=int(1 / self.prop_test),
+                               test_size=self.prop_test,
+                               random_state=self.rand_seed)
+        if self.task_type == 'classification':
+            self.obj = 'binary:logistic'
+        else:  # regression task
+            if 'f' in self.model_nme:
+                self.obj = 'count:poisson'
+            elif 's' in self.model_nme:
+                self.obj = 'reg:gamma'
+            elif 'bc' in self.model_nme:
+                self.obj = 'reg:tweedie'
+            else:
+                self.obj = 'reg:tweedie'
+        self.fit_params = {
+            'sample_weight': self.train_data[self.weight_nme].values
+        }
+        self.model_label: List[str] = []
+        self.optuna_storage = cfg.optuna_storage
+        self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
+
+        # Register the per-model trainers; downstream access goes through these
+        # labels, which keeps adding new models straightforward.
+        self.trainers: Dict[str, TrainerBase] = {
+            'glm': GLMTrainer(self),
+            'xgb': XGBTrainer(self),
+            'resn': ResNetTrainer(self),
+            'ft': FTTrainer(self)
+        }
+        self.xgb_best = None
+        self.resn_best = None
+        self.glm_best = None
+        self.ft_best = None
+        self.best_xgb_params = None
+        self.best_resn_params = None
+        self.best_ft_params = None
+        self.best_xgb_trial = None
+        self.best_resn_trial = None
+        self.best_ft_trial = None
+        self.best_glm_params = None
+        self.best_glm_trial = None
+        self.xgb_load = None
+        self.resn_load = None
+        self.ft_load = None
+        self.version_manager = VersionManager(self.output_manager)
+
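+    # End-to-end usage, as a hedged sketch (illustrative only; the column
+    # names below are hypothetical):
+    #
+    #   bo = BayesOptModel(train_df, test_df,
+    #                      model_nme='bc_demo', resp_nme='loss',
+    #                      weight_nme='exposure',
+    #                      factor_nmes=['age', 'region'],
+    #                      cate_list=['region'])
+    #   bo.bayesopt_xgb(max_evals=50)        # tune, train and snapshot
+    #   bo.plot_lift('Xgboost', 'pred_xgb')  # lift charts on train/test
+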
+    def _build_geo_tokens(self, params_override: Optional[Dict[str, Any]] = None):
+        """Internal builder; trial params may override the defaults. Returns None on failure."""
+        geo_cols = list(self.config.geo_feature_nmes or [])
+        if not geo_cols:
+            return None
+
+        available = [c for c in geo_cols if c in self.train_data.columns]
+        if not available:
+            return None
+
+        # Preprocess text/numeric columns: impute numerics with the median,
+        # label-encode text, and map unseen categories to an extra index.
+        proc_train = {}
+        proc_test = {}
+        for col in available:
+            s_train = self.train_data[col]
+            s_test = self.test_data[col]
+            if pd.api.types.is_numeric_dtype(s_train):
+                tr = pd.to_numeric(s_train, errors="coerce")
+                te = pd.to_numeric(s_test, errors="coerce")
+                med = np.nanmedian(tr)
+                proc_train[col] = np.nan_to_num(tr, nan=med).astype(np.float32)
+                proc_test[col] = np.nan_to_num(te, nan=med).astype(np.float32)
+            else:
+                cats = pd.Categorical(s_train.astype(str))
+                tr_codes = cats.codes.astype(np.float32, copy=True)
+                tr_codes[tr_codes < 0] = len(cats.categories)
+                te_cats = pd.Categorical(
+                    s_test.astype(str), categories=cats.categories)
+                te_codes = te_cats.codes.astype(np.float32, copy=True)
+                te_codes[te_codes < 0] = len(cats.categories)
+                proc_train[col] = tr_codes
+                proc_test[col] = te_codes
+
+        train_geo_raw = pd.DataFrame(proc_train, index=self.train_data.index)
+        test_geo_raw = pd.DataFrame(proc_test, index=self.test_data.index)
+
+        scaler = StandardScaler()
+        train_geo = pd.DataFrame(
+            scaler.fit_transform(train_geo_raw),
+            columns=available,
+            index=self.train_data.index
+        )
+        test_geo = pd.DataFrame(
+            scaler.transform(test_geo_raw),
+            columns=available,
+            index=self.test_data.index
+        )
+
+        tw_power = None
+        if self.task_type == 'regression':
+            if 'f' in self.model_nme:
+                tw_power = 1.0
+            elif 's' in self.model_nme:
+                tw_power = 2.0
+            else:
+                tw_power = 1.5
+
+        cfg = params_override or {}
+        try:
+            geo_gnn = GraphNeuralNetSklearn(
+                model_nme=f"{self.model_nme}_geo",
+                input_dim=len(available),
+                hidden_dim=cfg.get("geo_token_hidden_dim",
+                                   self.config.geo_token_hidden_dim),
+                num_layers=cfg.get("geo_token_layers",
+                                   self.config.geo_token_layers),
+                k_neighbors=cfg.get("geo_token_k_neighbors",
+                                    self.config.geo_token_k_neighbors),
+                dropout=cfg.get("geo_token_dropout",
+                                self.config.geo_token_dropout),
+                learning_rate=cfg.get(
+                    "geo_token_learning_rate", self.config.geo_token_learning_rate),
+                epochs=int(cfg.get("geo_token_epochs",
+                                   self.config.geo_token_epochs)),
+                patience=5,
+                task_type=self.task_type,
+                tweedie_power=tw_power,
+                use_data_parallel=False,
+                use_ddp=False,
+                use_approx_knn=self.config.gnn_use_approx_knn,
+                approx_knn_threshold=self.config.gnn_approx_knn_threshold,
+                graph_cache_path=None,
+                max_gpu_knn_nodes=self.config.gnn_max_gpu_knn_nodes,
+                knn_gpu_mem_ratio=self.config.gnn_knn_gpu_mem_ratio,
+                knn_gpu_mem_overhead=self.config.gnn_knn_gpu_mem_overhead
+            )
+            geo_gnn.fit(
+                train_geo,
+                self.train_data[self.resp_nme],
+                self.train_data[self.weight_nme]
+            )
+            train_embed = geo_gnn.encode(train_geo)
+            test_embed = geo_gnn.encode(test_geo)
+            cols = [f"geo_token_{i}" for i in range(train_embed.shape[1])]
+            train_tokens = pd.DataFrame(
+                train_embed, index=self.train_data.index, columns=cols)
+            test_tokens = pd.DataFrame(
+                test_embed, index=self.test_data.index, columns=cols)
+            return train_tokens, test_tokens, cols, geo_gnn
+        except Exception as exc:
+            print(f"[GeoToken] Generation failed: {exc}")
+            return None
+
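+    # The label encoding above sends unseen test categories to one extra
+    # index, len(categories). A hedged mini-example of that behavior:
+    #
+    #   cats = pd.Categorical(pd.Series(['a', 'b']))     # train codes: 0, 1
+    #   te = pd.Categorical(pd.Series(['b', 'z']),
+    #                       categories=cats.categories)  # test
+    #   codes = te.codes.astype(float)                   # [1, -1]
+    #   codes[codes < 0] = len(cats.categories)          # [1, 2]: 'z' -> 2
+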
+    def _prepare_geo_tokens(self) -> None:
+        """Build the geo tokens with the config defaults and keep them on the instance."""
+        result = self._build_geo_tokens()
+        if result is None:
+            return
+        train_tokens, test_tokens, cols, geo_gnn = result
+        self.train_geo_tokens = train_tokens
+        self.test_geo_tokens = test_tokens
+        self.geo_token_cols = cols
+        self.geo_gnn_model = geo_gnn
+        print(f"[GeoToken] Generated {len(cols)}-dim geo tokens; they will be injected into FT.")
+
+    def _add_region_effect(self) -> None:
+        """Partially pool the province/city hierarchy into a smoothed numeric region_effect feature."""
+        prov_col = self.config.region_province_col
+        city_col = self.config.region_city_col
+        if not prov_col or not city_col:
+            return
+        for col in [prov_col, city_col]:
+            if col not in self.train_data.columns:
+                print(f"[RegionEffect] Missing column {col}; skipped.")
+                return
+
+        def safe_mean(df: pd.DataFrame) -> float:
+            w = df[self.weight_nme]
+            y = df[self.resp_nme]
+            denom = max(float(w.sum()), EPS)
+            return float((y * w).sum() / denom)
+
+        global_mean = safe_mean(self.train_data)
+        alpha = max(float(self.config.region_effect_alpha), 0.0)
+
+        prov_stats = self.train_data.groupby(prov_col).apply(safe_mean)
+        prov_stats = prov_stats.to_dict()
+
+        city_group = self.train_data.groupby([prov_col, city_col])
+        city_sumw = city_group[self.weight_nme].sum()
+        city_sumyw = (city_group[self.resp_nme].apply(
+            lambda s: (s * self.train_data.loc[s.index, self.weight_nme]).sum()))
+
+        city_effect: Dict[tuple, float] = {}
+        for (p, c), sum_w in city_sumw.items():
+            sum_yw = city_sumyw[(p, c)]
+            prior = prov_stats.get(p, global_mean)
+            effect = (sum_yw + alpha * prior) / max(sum_w + alpha, EPS)
+            city_effect[(p, c)] = float(effect)
+
+        def lookup_effect(df: pd.DataFrame) -> pd.Series:
+            effects = []
+            for _, row in df[[prov_col, city_col]].iterrows():
+                p = row[prov_col]
+                c = row[city_col]
+                val = city_effect.get((p, c))
+                if val is None:
+                    val = prov_stats.get(p, global_mean)
+                if not np.isfinite(val):
+                    val = global_mean
+                effects.append(val)
+            return pd.Series(effects, index=df.index, dtype=np.float32)
+
+        re_train = lookup_effect(self.train_data)
+        re_test = lookup_effect(self.test_data)
+
+        col_name = "region_effect"
+        self.train_data[col_name] = re_train
+        self.test_data[col_name] = re_test
+
+        # Propagate to the one-hot copies
+        for df in [self.train_oht_data, self.test_oht_data]:
+            if df is not None:
+                df[col_name] = re_train if df is self.train_oht_data else re_test
+
+        # Standardize region_effect and propagate it as well
+        scaler = StandardScaler()
+        re_train_s = scaler.fit_transform(
+            re_train.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
+        re_test_s = scaler.transform(
+            re_test.values.reshape(-1, 1)).astype(np.float32).reshape(-1)
+        for df in [self.train_oht_scl_data, self.test_oht_scl_data]:
+            if df is not None:
+                df[col_name] = re_train_s if df is self.train_oht_scl_data else re_test_s
+
+        # Update the feature lists
+        if col_name not in self.factor_nmes:
+            self.factor_nmes.append(col_name)
+        if col_name not in self.num_features:
+            self.num_features.append(col_name)
+        if self.train_oht_scl_data is not None:
+            self.var_nmes = list(
+                set(self.train_oht_scl_data.columns) -
+                set([self.weight_nme, self.resp_nme])
+            )
+
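+    # The city effect above is credibility-style partial pooling:
+    #   effect = (sum_yw + alpha * prior) / (sum_w + alpha)
+    # Worked example (hedged, hypothetical numbers): with a province prior of
+    # 0.10, alpha = 50, and a city with sum_w = 10, sum_yw = 3 (raw mean 0.30):
+    #   effect = (3 + 50 * 0.10) / (10 + 50) = 8 / 60 ≈ 0.133
+    # so a thinly observed city is shrunk most of the way back to its province.
+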
+    # One-way plotting helper
+    def plot_oneway(self, n_bins=10):
+        for c in self.factor_nmes:
+            fig = plt.figure(figsize=(7, 5))
+            if c in self.cate_list:
+                group_col = c
+                plot_source = self.train_data
+            else:
+                group_col = f'{c}_bins'
+                bins = pd.qcut(
+                    self.train_data[c],
+                    n_bins,
+                    duplicates='drop'  # NOTE: duplicate quantiles drop bins rather than raising
+                )
+                plot_source = self.train_data.assign(**{group_col: bins})
+            plot_data = plot_source.groupby(
+                [group_col], observed=True).sum(numeric_only=True)
+            plot_data.reset_index(inplace=True)
+            plot_data['act_v'] = plot_data['w_act'] / \
+                plot_data[self.weight_nme]
+            ax = fig.add_subplot(111)
+            ax.plot(plot_data.index, plot_data['act_v'],
+                    label='Actual', color='red')
+            ax.set_title(
+                'Analysis of %s : Train Data' % group_col,
+                fontsize=8)
+            plt.xticks(plot_data.index,
+                       list(plot_data[group_col].astype(str)),
+                       rotation=90)
+            if len(list(plot_data[group_col].astype(str))) > 50:
+                plt.xticks(fontsize=3)
+            else:
+                plt.xticks(fontsize=6)
+            plt.yticks(fontsize=6)
+            ax2 = ax.twinx()
+            ax2.bar(plot_data.index,
+                    plot_data[self.weight_nme],
+                    alpha=0.5, color='seagreen')
+            plt.yticks(fontsize=6)
+            plt.margins(0.05)
+            plt.subplots_adjust(wspace=0.3)
+            save_path = self.output_manager.plot_path(
+                f'00_{self.model_nme}_{group_col}_oneway.png')
+            plt.savefig(save_path, dpi=300)
+            plt.close(fig)
+
+    # Generic optimization entry point
+    def optimize_model(self, model_key: str, max_evals: int = 100):
+        if model_key not in self.trainers:
+            print(f"Warning: Unknown model key: {model_key}")
+            return
+
+        trainer = self.trainers[model_key]
+        trainer.tune(max_evals)
+        trainer.train()
+
+        # Update context fields to stay compatible with the legacy interface
+        setattr(self, f"{model_key}_best", trainer.model)
+        setattr(self, f"best_{model_key}_params", trainer.best_params)
+        setattr(self, f"best_{model_key}_trial", trainer.best_trial)
+        # Save a version snapshot for later traceability
+        study_name = getattr(trainer, "study_name", None)
+        if study_name is None and trainer.best_trial is not None:
+            study_obj = getattr(trainer.best_trial, "study", None)
+            study_name = getattr(study_obj, "study_name", None)
+        snapshot = {
+            "model_key": model_key,
+            "timestamp": datetime.now().isoformat(),
+            "best_params": trainer.best_params,
+            "study_name": study_name,
+            "config": asdict(self.config),
+        }
+        self.version_manager.save(f"{model_key}_best", snapshot)
+
+    # GLM Bayesian optimization
+    def bayesopt_glm(self, max_evals=50):
+        self.optimize_model('glm', max_evals)
+
+    # XGBoost Bayesian optimization
+    def bayesopt_xgb(self, max_evals=100):
+        self.optimize_model('xgb', max_evals)
+
+    # ResNet Bayesian optimization
+    def bayesopt_resnet(self, max_evals=100):
+        self.optimize_model('resn', max_evals)
+
+    # GNN Bayesian optimization
+    def bayesopt_gnn(self, max_evals=50):
+        print("[bayesopt_gnn] The standalone GNN is retired; the geo GNN is "
+              "only used to generate geo tokens for FT.")
+
+    # FT-Transformer Bayesian optimization
+    def bayesopt_ft(self, max_evals=50):
+        self.optimize_model('ft', max_evals)
+
+    # Plot the lift curve
+    def plot_lift(self, model_label, pred_nme, n_bins=10):
+        model_map = {
+            'Xgboost': 'pred_xgb',
+            'ResNet': 'pred_resn',
+            'ResNetClassifier': 'pred_resn',
+            'FTTransformer': 'pred_ft',
+            'FTTransformerClassifier': 'pred_ft',
+            'GLM': 'pred_glm'
+        }
+        for k, v in model_map.items():
+            if model_label.startswith(k):
+                pred_nme = v
+                break
+
+        fig = plt.figure(figsize=(11, 5))
+        for pos, (title, data) in zip([121, 122],
+                                      [('Lift Chart on Train Data', self.train_data),
+                                       ('Lift Chart on Test Data', self.test_data)]):
+            lift_df = pd.DataFrame({
+                'pred': data[pred_nme].values,
+                'w_pred': data[f'w_{pred_nme}'].values,
+                'act': data['w_act'].values,
+                'weight': data[self.weight_nme].values
+            })
+            plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
+            denom = np.maximum(plot_data['weight'], EPS)
+            plot_data['exp_v'] = plot_data['w_pred'] / denom
+            plot_data['act_v'] = plot_data['act'] / denom
+            plot_data = plot_data.reset_index()
+
+            ax = fig.add_subplot(pos)
+            PlotUtils.plot_lift_ax(ax, plot_data, title)
+
+        plt.subplots_adjust(wspace=0.3)
+        save_path = self.output_manager.plot_path(
+            f'01_{self.model_nme}_{model_label}_lift.png')
+        plt.savefig(save_path, dpi=300)
+        plt.show()
+        plt.close(fig)
+
+    # Plot the double-lift curve
+    def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
+        # Draw a double-lift chart comparing two models across bins.
+        # Args:
+        #     model_comp: short names of the two models to compare
+        #         (e.g. ['xgb', 'resn']; 'xgb'/'resn'/'ft'/'glm' are supported).
+        #     n_bins: number of bins controlling the lift-curve granularity.
+        if len(model_comp) != 2:
+            raise ValueError("`model_comp` must contain exactly two models to compare.")
+
+        model_name_map = {
+            'xgb': 'Xgboost',
+            'resn': 'ResNet',
+            'ft': 'FTTransformer',
+            'glm': 'GLM'
+        }
+
+        name1, name2 = model_comp
+        if name1 not in model_name_map or name2 not in model_name_map:
+            raise ValueError(
+                f"Unsupported model short name. Choose from {list(model_name_map.keys())}.")
+
+        fig, axes = plt.subplots(1, 2, figsize=(11, 5))
+        datasets = {
+            'Train Data': self.train_data,
+            'Test Data': self.test_data
+        }
+
+        for ax, (data_name, data) in zip(axes, datasets.items()):
+            pred1_col = f'w_pred_{name1}'
+            pred2_col = f'w_pred_{name2}'
+
+            if pred1_col not in data.columns or pred2_col not in data.columns:
+                print(
+                    f"Warning: prediction columns {pred1_col} or {pred2_col} "
+                    f"not found in {data_name}. Skipping plot.")
+                continue
+
+            lift_data = pd.DataFrame({
+                'pred1': data[pred1_col].values,
+                'pred2': data[pred2_col].values,
+                'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
+                'act': data['w_act'].values,
+                'weight': data[self.weight_nme].values
+            })
+            plot_data = PlotUtils.split_data(
+                lift_data, 'diff_ly', 'weight', n_bins)
+            denom = np.maximum(plot_data['act'], EPS)
+            plot_data['exp_v1'] = plot_data['pred1'] / denom
+            plot_data['exp_v2'] = plot_data['pred2'] / denom
+            plot_data['act_v'] = plot_data['act'] / denom
+            plot_data.reset_index(inplace=True)
+
+            label1 = model_name_map[name1]
+            label2 = model_name_map[name2]
+
+            PlotUtils.plot_dlift_ax(
+                ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)
+
+        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
+        save_path = self.output_manager.plot_path(
+            f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
+        plt.savefig(save_path, dpi=300)
+        plt.show()
+        plt.close(fig)
+
+    # Plot the conversion-rate lift curve
+    def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
+        if not self.binary_resp_nme:
+            print("Error: `binary_resp_nme` was not provided when BayesOptModel "
+                  "was initialized; cannot plot the conversion-rate curve.")
+            return
+
+        fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
+        datasets = {
+            'Train Data': self.train_data,
+            'Test Data': self.test_data
+        }
+
+        for ax, (data_name, data) in zip(axes, datasets.items()):
+            if model_pred_col not in data.columns:
+                print(f"Warning: prediction column '{model_pred_col}' not found "
+                      f"in {data_name}. Skipping plot.")
+                continue
+
+            # Sort by model score and compute weight-based bins
+            plot_data = data.sort_values(by=model_pred_col).copy()
+            plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
+            total_weight = plot_data[self.weight_nme].sum()
+
+            if total_weight > EPS:
+                plot_data['bin'] = pd.cut(
+                    plot_data['cum_weight'],
+                    bins=n_bins,
+                    labels=False,
+                    right=False
+                )
+            else:
+                plot_data['bin'] = 0
+
+            # Aggregate by bin
+            lift_agg = plot_data.groupby('bin').agg(
+                total_weight=(self.weight_nme, 'sum'),
+                actual_conversions=(self.binary_resp_nme, 'sum'),
+                weighted_conversions=('w_binary_act', 'sum'),
+                avg_pred=(model_pred_col, 'mean')
+            ).reset_index()
+
+            # Conversion rate per bin
+            lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
+                lift_agg['total_weight']
+
+            # Overall average conversion rate
+            overall_conversion_rate = data['w_binary_act'].sum(
+            ) / data[self.weight_nme].sum()
+            ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
+                       label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')
+
+            ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
+                    marker='o', linestyle='-', label='Actual Conversion Rate')
+            ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
+            ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
+            ax.set_ylabel('Conversion Rate')
+            ax.grid(True, linestyle='--', alpha=0.6)
+            ax.legend()
+
+        plt.tight_layout()
+        plt.show()
+
+    # Save models
+    def save_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].save()
+            else:
+                if model_name:  # warn only when the caller named a specific model
+                    print(f"[save_model] Warning: Unknown model key {key}")
+
+    def load_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].load()
+                # Sync the context fields
+                trainer = self.trainers[key]
+                if trainer.model is not None:
+                    setattr(self, f"{key}_best", trainer.model)
+                    # For legacy-field compatibility, also refresh the xxx_load aliases.
+                    # The legacy interface only tracked xgb_load/resn_load/ft_load, not glm_load.
+                    if key in ['xgb', 'resn', 'ft']:
+                        setattr(self, f"{key}_load", trainer.model)
+            else:
+                if model_name:
+                    print(f"[load_model] Warning: Unknown model key {key}")
+
+    def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
+        if len(data) == 0:
+            return data
+        return data.sample(min(len(data), n), random_state=self.rand_seed)
+
+    @staticmethod
+    def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
+        min_needed = arr.shape[1] + 2
+        return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
+
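+    # KernelExplainer needs at least F + 2 evaluations to fit its local
+    # linear model over F features. Worked example of the heuristic above
+    # (hedged, with a hypothetical shape): for arr of shape (200, 30),
+    #   min_needed = 30 + 2 = 32
+    #   nsamples   = max(32, min(300, 200 * 30)) = 300
+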
+    def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
+        matrices = []
+        for col in self.factor_nmes:
+            s = data[col]
+            if col in self.cate_list:
+                cats = pd.Categorical(
+                    s,
+                    categories=self.cat_categories_for_shap[col]
+                )
+                codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
+                matrices.append(codes)
+            else:
+                vals = pd.to_numeric(s, errors="coerce")
+                arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
+                matrices.append(arr)
+        X_mat = np.concatenate(matrices, axis=1)  # resulting shape is (N, F)
+        return X_mat
+
+    def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
+        data_dict = {}
+        for j, col in enumerate(self.factor_nmes):
+            col_vals = X_mat[:, j]
+            if col in self.cate_list:
+                cats = self.cat_categories_for_shap[col]
+                codes = np.round(col_vals).astype(int)
+                codes = np.clip(codes, -1, len(cats) - 1)
+                cat_series = pd.Categorical.from_codes(
+                    codes,
+                    categories=cats
+                )
+                data_dict[col] = cat_series
+            else:
+                data_dict[col] = col_vals.astype(float)
+
+        df = pd.DataFrame(data_dict, columns=self.factor_nmes)
+        for col in self.cate_list:
+            if col in df.columns:
+                df[col] = df[col].astype("category")
+        return df
+
+    def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
+        X = data[self.var_nmes]
+        return sm.add_constant(X, has_constant='add')
+
+    def _compute_shap_core(self,
+                           model_key: str,
+                           n_background: int,
+                           n_samples: int,
+                           on_train: bool,
+                           X_df: pd.DataFrame,
+                           prep_fn,
+                           predict_fn,
+                           cleanup_fn=None):
+        if model_key not in self.trainers or self.trainers[model_key].model is None:
+            raise RuntimeError(f"Model {model_key} not trained.")
+        if cleanup_fn:
+            cleanup_fn()
+        bg_df = self._sample_rows(X_df, n_background)
+        bg_mat = prep_fn(bg_df)
+        explainer = shap.KernelExplainer(predict_fn, bg_mat)
+        ex_df = self._sample_rows(X_df, n_samples)
+        ex_mat = prep_fn(ex_df)
+        nsample_eff = self._shap_nsamples(ex_mat)
+        shap_values = explainer.shap_values(ex_mat, nsamples=nsample_eff)
+        bg_pred = predict_fn(bg_mat)
+        base_value = float(np.asarray(bg_pred).mean())
+
+        return {
+            "explainer": explainer,
+            "X_explain": ex_df,
+            "shap_values": shap_values,
+            "base_value": base_value
+        }
+
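+    # `_compute_shap_core` wraps the standard shap.KernelExplainer flow. A
+    # minimal, hedged sketch of the same steps (names are hypothetical):
+    #
+    #   import shap
+    #   explainer = shap.KernelExplainer(predict_fn, background_matrix)
+    #   values = explainer.shap_values(explain_matrix, nsamples=300)
+    #   base = predict_fn(background_matrix).mean()   # expected-value proxy
+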
+    # ========= SHAP explanations for the GLM =========
+    def compute_shap_glm(self, n_background: int = 500,
+                         n_samples: int = 200,
+                         on_train: bool = True):
+        data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+        design_all = self._build_glm_design(data)
+        design_cols = list(design_all.columns)
+
+        def predict_wrapper(x_np):
+            x_df = pd.DataFrame(x_np, columns=design_cols)
+            y_pred = self.glm_best.predict(x_df)
+            return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+        self.shap_glm = self._compute_shap_core(
+            'glm', n_background, n_samples, on_train,
+            X_df=design_all,
+            prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+            predict_fn=predict_wrapper
+        )
+        return self.shap_glm
+
+    # ========= SHAP explanations for XGBoost =========
+    def compute_shap_xgb(self, n_background: int = 500,
+                         n_samples: int = 200,
+                         on_train: bool = True):
+        data = self.train_data if on_train else self.test_data
+        X_raw = data[self.factor_nmes]
+
+        def predict_wrapper(x_mat):
+            df_input = self._decode_ft_shap_matrix_to_df(x_mat)
+            return self.xgb_best.predict(df_input)
+
+        self.shap_xgb = self._compute_shap_core(
+            'xgb', n_background, n_samples, on_train,
+            X_df=X_raw,
+            prep_fn=lambda df: self._build_ft_shap_matrix(
+                df).astype(np.float64),
+            predict_fn=predict_wrapper
+        )
+        return self.shap_xgb
+
+    # ========= SHAP explanations for the ResNet =========
+    def _resn_predict_wrapper(self, X_np):
+        model = self.resn_best.resnet.to("cpu")
+        with torch.no_grad():
+            X_tensor = torch.tensor(X_np, dtype=torch.float32)
+            y_pred = model(X_tensor).cpu().numpy()
+        y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.reshape(-1)
+
+    def compute_shap_resn(self, n_background: int = 500,
+                          n_samples: int = 200,
+                          on_train: bool = True):
+        data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+        X = data[self.var_nmes]
+
+        def cleanup():
+            self.resn_best.device = torch.device("cpu")
+            self.resn_best.resnet.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        self.shap_resn = self._compute_shap_core(
+            'resn', n_background, n_samples, on_train,
+            X_df=X,
+            prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+            predict_fn=lambda x: self._resn_predict_wrapper(x),
+            cleanup_fn=cleanup
+        )
+        return self.shap_resn
+
+    # ========= SHAP explanations for the FT-Transformer =========
+    def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
+        df_input = self._decode_ft_shap_matrix_to_df(X_mat)
+        y_pred = self.ft_best.predict(df_input)
+        return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+    def compute_shap_ft(self, n_background: int = 500,
+                        n_samples: int = 200,
+                        on_train: bool = True):
+        data = self.train_data if on_train else self.test_data
+        X_raw = data[self.factor_nmes]
+
+        def cleanup():
+            self.ft_best.device = torch.device("cpu")
+            self.ft_best.ft.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        self.shap_ft = self._compute_shap_core(
+            'ft', n_background, n_samples, on_train,
+            X_df=X_raw,
+            prep_fn=lambda df: self._build_ft_shap_matrix(
+                df).astype(np.float64),
+            predict_fn=self._ft_shap_predict_wrapper,
+            cleanup_fn=cleanup
+        )
+        return self.shap_ft
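+
+    # Usage, as a hedged sketch (illustrative only; `bo` is a hypothetical,
+    # already-trained BayesOptModel):
+    #
+    #   res = bo.compute_shap_ft(n_background=200, n_samples=100)
+    #   shap.summary_plot(res["shap_values"], res["X_explain"])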