ins-pricing 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +60 -0
- ins_pricing/__init__.py +102 -0
- ins_pricing/governance/README.md +18 -0
- ins_pricing/governance/__init__.py +20 -0
- ins_pricing/governance/approval.py +93 -0
- ins_pricing/governance/audit.py +37 -0
- ins_pricing/governance/registry.py +99 -0
- ins_pricing/governance/release.py +159 -0
- ins_pricing/modelling/BayesOpt.py +146 -0
- ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
- ins_pricing/modelling/BayesOpt_entry.py +575 -0
- ins_pricing/modelling/BayesOpt_incremental.py +731 -0
- ins_pricing/modelling/Explain_Run.py +36 -0
- ins_pricing/modelling/Explain_entry.py +539 -0
- ins_pricing/modelling/Pricing_Run.py +36 -0
- ins_pricing/modelling/README.md +33 -0
- ins_pricing/modelling/__init__.py +44 -0
- ins_pricing/modelling/bayesopt/__init__.py +98 -0
- ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
- ins_pricing/modelling/bayesopt/core.py +1476 -0
- ins_pricing/modelling/bayesopt/models.py +2196 -0
- ins_pricing/modelling/bayesopt/trainers.py +2446 -0
- ins_pricing/modelling/bayesopt/utils.py +1021 -0
- ins_pricing/modelling/cli_common.py +136 -0
- ins_pricing/modelling/explain/__init__.py +55 -0
- ins_pricing/modelling/explain/gradients.py +334 -0
- ins_pricing/modelling/explain/metrics.py +176 -0
- ins_pricing/modelling/explain/permutation.py +155 -0
- ins_pricing/modelling/explain/shap_utils.py +146 -0
- ins_pricing/modelling/notebook_utils.py +284 -0
- ins_pricing/modelling/plotting/__init__.py +45 -0
- ins_pricing/modelling/plotting/common.py +63 -0
- ins_pricing/modelling/plotting/curves.py +572 -0
- ins_pricing/modelling/plotting/diagnostics.py +139 -0
- ins_pricing/modelling/plotting/geo.py +362 -0
- ins_pricing/modelling/plotting/importance.py +121 -0
- ins_pricing/modelling/run_logging.py +133 -0
- ins_pricing/modelling/tests/conftest.py +8 -0
- ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing/modelling/tests/test_explain.py +56 -0
- ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing/modelling/tests/test_plotting.py +63 -0
- ins_pricing/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing/modelling/watchdog_run.py +211 -0
- ins_pricing/pricing/README.md +44 -0
- ins_pricing/pricing/__init__.py +27 -0
- ins_pricing/pricing/calibration.py +39 -0
- ins_pricing/pricing/data_quality.py +117 -0
- ins_pricing/pricing/exposure.py +85 -0
- ins_pricing/pricing/factors.py +91 -0
- ins_pricing/pricing/monitoring.py +99 -0
- ins_pricing/pricing/rate_table.py +78 -0
- ins_pricing/production/__init__.py +21 -0
- ins_pricing/production/drift.py +30 -0
- ins_pricing/production/monitoring.py +143 -0
- ins_pricing/production/scoring.py +40 -0
- ins_pricing/reporting/README.md +20 -0
- ins_pricing/reporting/__init__.py +11 -0
- ins_pricing/reporting/report_builder.py +72 -0
- ins_pricing/reporting/scheduler.py +45 -0
- ins_pricing/setup.py +41 -0
- ins_pricing v2/__init__.py +23 -0
- ins_pricing v2/governance/__init__.py +20 -0
- ins_pricing v2/governance/approval.py +93 -0
- ins_pricing v2/governance/audit.py +37 -0
- ins_pricing v2/governance/registry.py +99 -0
- ins_pricing v2/governance/release.py +159 -0
- ins_pricing v2/modelling/Explain_Run.py +36 -0
- ins_pricing v2/modelling/Pricing_Run.py +36 -0
- ins_pricing v2/modelling/__init__.py +151 -0
- ins_pricing v2/modelling/cli_common.py +141 -0
- ins_pricing v2/modelling/config.py +249 -0
- ins_pricing v2/modelling/config_preprocess.py +254 -0
- ins_pricing v2/modelling/core.py +741 -0
- ins_pricing v2/modelling/data_container.py +42 -0
- ins_pricing v2/modelling/explain/__init__.py +55 -0
- ins_pricing v2/modelling/explain/gradients.py +334 -0
- ins_pricing v2/modelling/explain/metrics.py +176 -0
- ins_pricing v2/modelling/explain/permutation.py +155 -0
- ins_pricing v2/modelling/explain/shap_utils.py +146 -0
- ins_pricing v2/modelling/features.py +215 -0
- ins_pricing v2/modelling/model_manager.py +148 -0
- ins_pricing v2/modelling/model_plotting.py +463 -0
- ins_pricing v2/modelling/models.py +2203 -0
- ins_pricing v2/modelling/notebook_utils.py +294 -0
- ins_pricing v2/modelling/plotting/__init__.py +45 -0
- ins_pricing v2/modelling/plotting/common.py +63 -0
- ins_pricing v2/modelling/plotting/curves.py +572 -0
- ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
- ins_pricing v2/modelling/plotting/geo.py +362 -0
- ins_pricing v2/modelling/plotting/importance.py +121 -0
- ins_pricing v2/modelling/run_logging.py +133 -0
- ins_pricing v2/modelling/tests/conftest.py +8 -0
- ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing v2/modelling/tests/test_explain.py +56 -0
- ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing v2/modelling/tests/test_plotting.py +63 -0
- ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing v2/modelling/trainers.py +2447 -0
- ins_pricing v2/modelling/utils.py +1020 -0
- ins_pricing v2/modelling/watchdog_run.py +211 -0
- ins_pricing v2/pricing/__init__.py +27 -0
- ins_pricing v2/pricing/calibration.py +39 -0
- ins_pricing v2/pricing/data_quality.py +117 -0
- ins_pricing v2/pricing/exposure.py +85 -0
- ins_pricing v2/pricing/factors.py +91 -0
- ins_pricing v2/pricing/monitoring.py +99 -0
- ins_pricing v2/pricing/rate_table.py +78 -0
- ins_pricing v2/production/__init__.py +21 -0
- ins_pricing v2/production/drift.py +30 -0
- ins_pricing v2/production/monitoring.py +143 -0
- ins_pricing v2/production/scoring.py +40 -0
- ins_pricing v2/reporting/__init__.py +11 -0
- ins_pricing v2/reporting/report_builder.py +72 -0
- ins_pricing v2/reporting/scheduler.py +45 -0
- ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
- ins_pricing v2/scripts/Explain_entry.py +545 -0
- ins_pricing v2/scripts/__init__.py +1 -0
- ins_pricing v2/scripts/train.py +568 -0
- ins_pricing v2/setup.py +55 -0
- ins_pricing v2/smoke_test.py +28 -0
- ins_pricing-0.1.6.dist-info/METADATA +78 -0
- ins_pricing-0.1.6.dist-info/RECORD +169 -0
- ins_pricing-0.1.6.dist-info/WHEEL +5 -0
- ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
- user_packages/__init__.py +105 -0
- user_packages legacy/BayesOpt.py +5659 -0
- user_packages legacy/BayesOpt_entry.py +513 -0
- user_packages legacy/BayesOpt_incremental.py +685 -0
- user_packages legacy/Pricing_Run.py +36 -0
- user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
- user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
- user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
- user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
- user_packages legacy/Try/BayesOpt legacy.py +3280 -0
- user_packages legacy/Try/BayesOpt.py +838 -0
- user_packages legacy/Try/BayesOptAll.py +1569 -0
- user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
- user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
- user_packages legacy/Try/BayesOptSearch.py +830 -0
- user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
- user_packages legacy/Try/BayesOptV1.py +1911 -0
- user_packages legacy/Try/BayesOptV10.py +2973 -0
- user_packages legacy/Try/BayesOptV11.py +3001 -0
- user_packages legacy/Try/BayesOptV12.py +3001 -0
- user_packages legacy/Try/BayesOptV2.py +2065 -0
- user_packages legacy/Try/BayesOptV3.py +2209 -0
- user_packages legacy/Try/BayesOptV4.py +2342 -0
- user_packages legacy/Try/BayesOptV5.py +2372 -0
- user_packages legacy/Try/BayesOptV6.py +2759 -0
- user_packages legacy/Try/BayesOptV7.py +2832 -0
- user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
- user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
- user_packages legacy/Try/BayesOptV9.py +2927 -0
- user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
- user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
- user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
- user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
- user_packages legacy/Try/xgbbayesopt.py +523 -0
- user_packages legacy/__init__.py +19 -0
- user_packages legacy/cli_common.py +124 -0
- user_packages legacy/notebook_utils.py +228 -0
- user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,3506 @@
from sklearn.metrics import log_loss, make_scorer, mean_tweedie_deviance
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import ShuffleSplit, cross_val_score  # 1.2.2
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, TensorDataset, DistributedSampler
import xgboost as xgb  # 1.7.0
import torch.nn.functional as F
import torch.nn as nn
import torch  # version: 1.10.1+cu111
import statsmodels.api as sm
import shap
import pandas as pd  # 2.2.3
import optuna  # 4.3.0
import numpy as np  # 1.26.2
try:
    from torch_geometric.nn import knn_graph
    from torch_geometric.utils import add_self_loops, to_undirected
    _PYG_AVAILABLE = True
except Exception:
    knn_graph = None  # type: ignore
    add_self_loops = None  # type: ignore
    to_undirected = None  # type: ignore
    _PYG_AVAILABLE = False
try:
    import pynndescent
    _PYNN_AVAILABLE = True
except Exception:
    pynndescent = None  # type: ignore
    _PYNN_AVAILABLE = False
import matplotlib.pyplot as plt
import joblib
import csv
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from pathlib import Path
from dataclasses import dataclass
from contextlib import nullcontext
import os
import math
import gc
import copy


# Constants and utilities
# =============================================================================
torch.backends.cudnn.benchmark = True
EPS = 1e-8


class IOUtils:
    # Small collection of file and path helpers.

    @staticmethod
    def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
        with open(file_path, mode='r', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            return [
                dict(filter(lambda item: item[0] != '', row.items()))
                for row in reader
            ]

    @staticmethod
    def ensure_parent_dir(file_path: str) -> None:
        # Create the target file's parent directory if it does not exist
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)


class TrainingUtils:
    # Small helpers used throughout the training phase.

    @staticmethod
    def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
        estimated = int((learning_rate / 1e-4) ** 0.5 *
                        (data_size / max(batch_num, 1)))
        return max(1, min(data_size, max(minimum, estimated)))

    @staticmethod
    def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
        # Clamp predictions to be positive for numerical stability
        pred_clamped = torch.clamp(pred, min=eps)
        if p == 1:
            term1 = target * torch.log(target / pred_clamped + eps)  # Poisson
            term2 = target - pred_clamped  # sign fixed: Poisson deviance is 2*(t*log(t/p) - t + p)
            term3 = 0
        elif p == 0:
            term1 = 0.5 * torch.pow(target - pred_clamped, 2)  # Gaussian
            term2 = 0
            term3 = 0
        elif p == 2:
            term1 = torch.log(pred_clamped / target + eps)  # Gamma
            term2 = -target / pred_clamped + 1
            term3 = 0
        else:
            term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
            term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
            term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
        return torch.nan_to_num(  # Tweedie negative log-likelihood (constant terms dropped)
            2 * (term1 - term2 + term3),
            nan=eps,
            posinf=max_clip,
            neginf=-max_clip
        )

    @staticmethod
    def free_cuda() -> None:
        print(">>> Moving all models to CPU...")
        for obj in gc.get_objects():
            try:
                if hasattr(obj, "to") and callable(obj.to):
                    obj.to("cpu")
            except Exception:
                pass

        print(">>> Deleting tensors, optimizers, dataloaders...")
        gc.collect()

        print(">>> Emptying CUDA cache...")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        print(">>> CUDA memory freed.")
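

# --- Illustrative sketch (not part of the released file): minimal checks of
# TrainingUtils, assuming only the module-level imports above. For a
# non-special power the generic branch of tweedie_loss is the full unit
# deviance, so it should agree with sklearn's mean_tweedie_deviance.
def _demo_training_utils():
    # Heuristic batch size: grows with sqrt(learning_rate / 1e-4) and with
    # data_size / batch_num, floored at `minimum` and capped at data_size.
    bs = TrainingUtils.compute_batch_size(
        data_size=100_000, learning_rate=1e-3, batch_num=100, minimum=64)
    print(f"suggested batch size: {bs}")  # int(sqrt(10) * 1000) = 3162

    pred = torch.tensor([1.2, 0.8, 2.5], dtype=torch.float64)
    target = torch.tensor([1.0, 1.0, 2.0], dtype=torch.float64)
    ours = TrainingUtils.tweedie_loss(pred, target, p=1.5).mean().item()
    ref = mean_tweedie_deviance(target.numpy(), pred.numpy(), power=1.5)
    print(f"torch {ours:.6f} vs sklearn {ref:.6f}")  # should match closely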


class DistributedUtils:
    _cached_state: Optional[tuple] = None

    @staticmethod
    def setup_ddp():
        """Initialize DDP process group."""
        if dist.is_initialized():
            if DistributedUtils._cached_state is None:
                rank = dist.get_rank()
                world_size = dist.get_world_size()
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                DistributedUtils._cached_state = (
                    True,
                    local_rank,
                    rank,
                    world_size,
                )
            return DistributedUtils._cached_state

        if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            local_rank = int(os.environ["LOCAL_RANK"])

            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            dist.init_process_group(backend="nccl", init_method="env://")
            print(
                f">>> DDP Initialized: Rank {rank}/{world_size}, Local Rank {local_rank}")
            DistributedUtils._cached_state = (
                True,
                local_rank,
                rank,
                world_size,
            )
            return DistributedUtils._cached_state
        else:
            print(
                f">>> DDP Setup Failed: RANK or WORLD_SIZE not found in env. Keys found: {list(os.environ.keys())}")
            return False, 0, 0, 1

    @staticmethod
    def cleanup_ddp():
        """Destroy DDP process group."""
        if dist.is_initialized():
            dist.destroy_process_group()
        DistributedUtils._cached_state = None

    @staticmethod
    def is_main_process():
        return not dist.is_initialized() or dist.get_rank() == 0

    @staticmethod
    def world_size() -> int:
        return dist.get_world_size() if dist.is_initialized() else 1
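

# --- Illustrative sketch (not part of the released file): how the DDP helpers
# above are typically driven. setup_ddp() reads RANK/WORLD_SIZE/LOCAL_RANK,
# which a launcher populates, e.g.:
#     torchrun --nproc_per_node=4 train_script.py
# (train_script.py is a placeholder name, not a file in this package.)
def _demo_ddp_lifecycle():
    is_ddp, local_rank, rank, world_size = DistributedUtils.setup_ddp()
    try:
        if DistributedUtils.is_main_process():
            print(f"world size: {DistributedUtils.world_size()}")
        # ... build the model, wrap it in DDP when is_ddp is True, train ...
    finally:
        DistributedUtils.cleanup_ddp()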


class PlotUtils:
    # Plotting helpers shared by several model types.

    @staticmethod
    def split_data(data: pd.DataFrame, col_nme: str, wgt_nme: str, n_bins: int = 10) -> pd.DataFrame:
        data_sorted = data.sort_values(by=col_nme, ascending=True).copy()
        data_sorted['cum_weight'] = data_sorted[wgt_nme].cumsum()
        w_sum = data_sorted[wgt_nme].sum()
        if w_sum <= EPS:
            data_sorted.loc[:, 'bins'] = 0
        else:
            data_sorted.loc[:, 'bins'] = np.floor(
                data_sorted['cum_weight'] * float(n_bins) / w_sum
            )
            data_sorted.loc[(data_sorted['bins'] == n_bins),
                            'bins'] = n_bins - 1
        return data_sorted.groupby(['bins'], observed=True).sum(numeric_only=True)

    @staticmethod
    def plot_lift_ax(ax, plot_data, title, pred_label='Predicted', act_label='Actual', weight_label='Earned Exposure'):
        ax.plot(plot_data.index, plot_data['act_v'],
                label=act_label, color='red')
        ax.plot(plot_data.index, plot_data['exp_v'],
                label=pred_label, color='blue')
        ax.set_title(title, fontsize=8)
        ax.set_xticks(plot_data.index)
        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
        ax.tick_params(axis='y', labelsize=6)
        ax.legend(loc='upper left', fontsize=5, frameon=False)
        ax.margins(0.05)
        ax2 = ax.twinx()
        ax2.bar(plot_data.index, plot_data['weight'],
                alpha=0.5, color='seagreen',
                label=weight_label)
        ax2.tick_params(axis='y', labelsize=6)
        ax2.legend(loc='upper right', fontsize=5, frameon=False)

    @staticmethod
    def plot_dlift_ax(ax, plot_data, title, label1, label2, act_label='Actual', weight_label='Earned Exposure'):
        ax.plot(plot_data.index, plot_data['act_v'],
                label=act_label, color='red')
        ax.plot(plot_data.index, plot_data['exp_v1'],
                label=label1, color='blue')
        ax.plot(plot_data.index, plot_data['exp_v2'],
                label=label2, color='black')
        ax.set_title(title, fontsize=8)
        ax.set_xticks(plot_data.index)
        ax.set_xticklabels(plot_data.index, rotation=90, fontsize=6)
        ax.set_xlabel(f'{label1} / {label2}', fontsize=6)
        ax.tick_params(axis='y', labelsize=6)
        ax.legend(loc='upper left', fontsize=5, frameon=False)
        ax.margins(0.1)
        ax2 = ax.twinx()
        ax2.bar(plot_data.index, plot_data['weight'],
                alpha=0.5, color='seagreen',
                label=weight_label)
        ax2.tick_params(axis='y', labelsize=6)
        ax2.legend(loc='upper right', fontsize=5, frameon=False)

    @staticmethod
    def plot_lift_list(pred_model, w_pred_list, w_act_list,
                       weight_list, tgt_nme, n_bins: int = 10,
                       fig_nme: str = 'Lift Chart'):
        lift_data = pd.DataFrame()
        lift_data.loc[:, 'pred'] = pred_model
        lift_data.loc[:, 'w_pred'] = w_pred_list
        lift_data.loc[:, 'act'] = w_act_list
        lift_data.loc[:, 'weight'] = weight_list
        plot_data = PlotUtils.split_data(lift_data, 'pred', 'weight', n_bins)
        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
        plot_data.reset_index(inplace=True)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)
        PlotUtils.plot_lift_ax(ax, plot_data, f'Lift Chart of {tgt_nme}')
        plt.subplots_adjust(wspace=0.3)

        save_path = os.path.join(
            os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
        IOUtils.ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300)
        plt.close(fig)

    @staticmethod
    def plot_dlift_list(pred_model_1, pred_model_2,
                        model_nme_1, model_nme_2,
                        tgt_nme,
                        w_list, w_act_list, n_bins: int = 10,
                        fig_nme: str = 'Double Lift Chart'):
        lift_data = pd.DataFrame()
        lift_data.loc[:, 'pred1'] = pred_model_1
        lift_data.loc[:, 'pred2'] = pred_model_2
        lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
        lift_data.loc[:, 'act'] = w_act_list
        lift_data.loc[:, 'weight'] = w_list
        lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
        lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
        plot_data = PlotUtils.split_data(
            lift_data, 'diff_ly', 'weight', n_bins)
        plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
        plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
        plot_data['act_v'] = plot_data['act'] / plot_data['act']  # actuals normalize to 1.0
        plot_data.reset_index(inplace=True)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)
        PlotUtils.plot_dlift_ax(
            ax, plot_data, f'Double Lift Chart of {tgt_nme}', model_nme_1, model_nme_2)
        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)

        save_path = os.path.join(
            os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
        IOUtils.ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300)
        plt.close(fig)


# Backward-compatible functional wrappers
def csv_to_dict(file_path: str) -> List[Dict[str, Any]]:
    return IOUtils.csv_to_dict(file_path)


def ensure_parent_dir(file_path: str) -> None:
    IOUtils.ensure_parent_dir(file_path)


def compute_batch_size(data_size: int, learning_rate: float, batch_num: int, minimum: int) -> int:
    return TrainingUtils.compute_batch_size(data_size, learning_rate, batch_num, minimum)


# Tweedie deviance loss for the PyTorch environment
# Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
def tweedie_loss(pred, target, p=1.5, eps=1e-6, max_clip=1e6):
    return TrainingUtils.tweedie_loss(pred, target, p=p, eps=eps, max_clip=max_clip)


# Helper that frees CUDA memory
def free_cuda():
    TrainingUtils.free_cuda()
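

# --- Illustrative sketch (not part of the released file): exercising
# PlotUtils.split_data, the binning step that plot_lift_list and
# plot_dlift_list build on, with synthetic data. Column names are made up.
def _demo_lift_bins():
    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        'pred': rng.gamma(2.0, 100.0, size=1_000),
        'weight': rng.uniform(0.5, 1.0, size=1_000),
    })
    # Ten bins of (approximately) equal total weight, ordered by 'pred'
    binned = PlotUtils.split_data(df, col_nme='pred', wgt_nme='weight', n_bins=10)
    print(binned['weight'])  # each bin carries roughly 1/10 of the exposure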


class TorchTrainerMixin:
    # Shared utilities for Torch tabular trainers.

    def _device_type(self) -> str:
        return getattr(self, "device", torch.device("cpu")).type

    def _build_dataloader(self,
                          dataset,
                          N: int,
                          base_bs_gpu: tuple,
                          base_bs_cpu: tuple,
                          min_bs: int = 64,
                          target_effective_cuda: int = 8192,
                          target_effective_cpu: int = 4096,
                          large_threshold: int = 200_000,
                          mid_threshold: int = 50_000):
        batch_size = TrainingUtils.compute_batch_size(
            data_size=len(dataset),
            learning_rate=self.learning_rate,
            batch_num=self.batch_num,
            minimum=min_bs
        )
        gpu_large, gpu_mid, gpu_small = base_bs_gpu
        cpu_mid, cpu_small = base_bs_cpu

        if self._device_type() == 'cuda':
            device_count = torch.cuda.device_count()
            if getattr(self, "is_ddp_enabled", False):
                device_count = 1
            # With multiple GPUs, raise the minimum batch size so every card
            # receives enough data
            if device_count > 1:
                min_bs = min_bs * device_count
                print(
                    f">>> Multi-GPU detected: {device_count} devices. Adjusted min_bs to {min_bs}.")

            if N > large_threshold:
                base_bs = gpu_large * device_count
            elif N > mid_threshold:
                base_bs = gpu_mid * device_count
            else:
                base_bs = gpu_small * device_count
        else:
            base_bs = cpu_mid if N > mid_threshold else cpu_small

        # Recompute batch_size so it is no smaller than the adjusted min_bs
        batch_size = TrainingUtils.compute_batch_size(
            data_size=len(dataset),
            learning_rate=self.learning_rate,
            batch_num=self.batch_num,
            minimum=min_bs
        )
        batch_size = min(batch_size, base_bs, N)

        target_effective_bs = (target_effective_cuda
                               if self._device_type() == 'cuda'
                               else target_effective_cpu)
        if getattr(self, "is_ddp_enabled", False):
            world_size = max(1, DistributedUtils.world_size())
            target_effective_bs = max(1, target_effective_bs // world_size)
        accum_steps = max(1, target_effective_bs // batch_size)

        print(
            f">>> DataLoader config: Batch Size={batch_size}, Accum Steps={accum_steps}, Workers={min(8, os.cpu_count() or 1)}")

        # Linux (posix) forks workers cheaply; Windows (nt) spawns them, which
        # costs much more.
        if os.name == 'nt':
            workers = 0
        else:
            workers = min(8, os.cpu_count() or 1)
        if getattr(self, "is_ddp_enabled", False):
            workers = 0  # no extra workers under DDP, to avoid subprocess clashes

        sampler = None
        if dist.is_initialized():
            sampler = DistributedSampler(dataset, shuffle=True)
            shuffle = False  # Sampler handles shuffling
        else:
            shuffle = True

        persistent = workers > 0
        if getattr(self, "is_ddp_enabled", False):
            persistent = False  # avoid stale worker state if DDP/pruning exits early
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=workers,
            pin_memory=(self._device_type() == 'cuda'),
            persistent_workers=persistent,
        )
        return dataloader, accum_steps

    def _compute_weighted_loss(self, y_pred, y_true, weights, apply_softplus: bool = False):
        task = getattr(self, "task_type", "regression")
        if task == 'classification':
            loss_fn = nn.BCEWithLogitsLoss(reduction='none')
            losses = loss_fn(y_pred, y_true).view(-1)
        else:
            if apply_softplus:
                y_pred = F.softplus(y_pred)
            y_pred = torch.clamp(y_pred, min=1e-6)
            power = getattr(self, "tw_power", 1.5)
            losses = tweedie_loss(y_pred, y_true, p=power).view(-1)
        weighted_loss = (losses * weights.view(-1)).sum() / \
            torch.clamp(weights.sum(), min=EPS)
        return weighted_loss

    def _early_stop_update(self, val_loss, best_loss, best_state, patience_counter, model,
                           ignore_keys: Optional[List[str]] = None):
        if val_loss < best_loss:
            ignore_keys = ignore_keys or []
            state_dict = {
                k: (v.clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v))
                for k, v in model.state_dict().items()
                if not any(k.startswith(ignore_key) for ignore_key in ignore_keys)
            }
            return val_loss, state_dict, 0, False
        patience_counter += 1
        should_stop = best_state is not None and patience_counter >= getattr(
            self, "patience", 0)
        return best_loss, best_state, patience_counter, should_stop

    def _train_model(self,
                     model,
                     dataloader,
                     accum_steps,
                     optimizer,
                     scaler,
                     forward_fn,
                     val_forward_fn=None,
                     apply_softplus: bool = False,
                     clip_fn=None,
                     trial: Optional[optuna.trial.Trial] = None):
        device_type = self._device_type()
        best_loss = float('inf')
        best_state = None
        patience_counter = 0
        stop_training = False

        is_ddp_model = isinstance(model, DDP)

        for epoch in range(1, getattr(self, "epochs", 1) + 1):
            if hasattr(self, 'dataloader_sampler') and self.dataloader_sampler is not None:
                self.dataloader_sampler.set_epoch(epoch)

            model.train()
            optimizer.zero_grad()

            for step, batch in enumerate(dataloader):
                is_update_step = ((step + 1) % accum_steps == 0) or \
                    ((step + 1) == len(dataloader))
                sync_cm = model.no_sync if (
                    is_ddp_model and not is_update_step) else nullcontext

                with sync_cm():
                    with autocast(enabled=(device_type == 'cuda')):
                        y_pred, y_true, w = forward_fn(batch)
                        weighted_loss = self._compute_weighted_loss(
                            y_pred, y_true, w, apply_softplus=apply_softplus)
                        loss_for_backward = weighted_loss / accum_steps

                    scaler.scale(loss_for_backward).backward()

                if is_update_step:
                    if clip_fn is not None:
                        clip_fn()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

            if val_forward_fn is not None:
                should_compute_val = (not dist.is_initialized()
                                      or DistributedUtils.is_main_process())
                val_device = getattr(self, "device", torch.device("cpu"))
                if not isinstance(val_device, torch.device):
                    val_device = torch.device(val_device)
                loss_tensor_device = val_device if device_type == 'cuda' else torch.device(
                    "cpu")
                val_loss_tensor = torch.zeros(1, device=loss_tensor_device)

                if should_compute_val:
                    model.eval()
                    with torch.no_grad(), autocast(enabled=(device_type == 'cuda')):
                        val_result = val_forward_fn()
                        if isinstance(val_result, tuple) and len(val_result) == 3:
                            y_val_pred, y_val_true, w_val = val_result
                            val_weighted_loss = self._compute_weighted_loss(
                                y_val_pred, y_val_true, w_val, apply_softplus=apply_softplus)
                        else:
                            val_weighted_loss = val_result
                    val_loss_tensor[0] = float(val_weighted_loss)

                if dist.is_initialized():
                    dist.broadcast(val_loss_tensor, src=0)
                val_weighted_loss = float(val_loss_tensor.item())

                best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
                    val_weighted_loss, best_loss, best_state, patience_counter, model)

                # Report the validation loss to the Optuna trial so its pruner
                # can stop unpromising trials early
                if trial is not None:
                    trial.report(val_weighted_loss, epoch)
                    if trial.should_prune():
                        raise optuna.TrialPruned()

            if stop_training:
                break

        return best_state
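

# --- Illustrative sketch (not part of the released file): the batch-size /
# gradient-accumulation arithmetic used by _build_dataloader above. The
# effective batch is batch_size * accum_steps, aimed at target_effective_bs
# (split across ranks under DDP), so the optimizer steps roughly as if it had
# seen one large batch.
def _demo_accumulation_arithmetic(batch_size: int = 2048,
                                  target_effective_bs: int = 8192,
                                  world_size: int = 1):
    per_rank_target = max(1, target_effective_bs // world_size)
    accum_steps = max(1, per_rank_target // batch_size)
    print(f"accum_steps={accum_steps}, "
          f"effective batch per rank={batch_size * accum_steps}")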


# =============================================================================
# Plotting helpers
# =============================================================================

def split_data(data, col_nme, wgt_nme, n_bins=10):
    return PlotUtils.split_data(data, col_nme, wgt_nme, n_bins)

# Lift-chart plotting function


def plot_lift_list(pred_model, w_pred_list, w_act_list,
                   weight_list, tgt_nme, n_bins=10,
                   fig_nme='Lift Chart'):
    return PlotUtils.plot_lift_list(pred_model, w_pred_list, w_act_list,
                                    weight_list, tgt_nme, n_bins, fig_nme)

# Double-lift-chart plotting function


def plot_dlift_list(pred_model_1, pred_model_2,
                    model_nme_1, model_nme_2,
                    tgt_nme,
                    w_list, w_act_list, n_bins=10,
                    fig_nme='Double Lift Chart'):
    return PlotUtils.plot_dlift_list(pred_model_1, pred_model_2,
                                     model_nme_1, model_nme_2,
                                     tgt_nme, w_list, w_act_list,
                                     n_bins, fig_nme)


# =============================================================================
# ResNet model and sklearn-style wrapper
# =============================================================================

# ResNet model structure
# Residual block: two linear layers + ReLU + a residual connection
# ResBlock subclasses nn.Module
class ResBlock(nn.Module):
    def __init__(self, dim: int, dropout: float = 0.1,
                 use_layernorm: bool = False, residual_scale: float = 0.1
                 ):
        super().__init__()
        self.use_layernorm = use_layernorm

        if use_layernorm:
            Norm = nn.LayerNorm  # normalizes over the last dimension
        else:
            def Norm(d): return nn.BatchNorm1d(d)  # keep a switch for trying BN

        self.norm1 = Norm(dim)
        self.fc1 = nn.Linear(dim, dim, bias=True)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        # self.norm2 = Norm(dim)
        self.fc2 = nn.Linear(dim, dim, bias=True)

        # Residual scaling keeps the trunk from blowing up early in training
        self.res_scale = nn.Parameter(
            torch.tensor(residual_scale, dtype=torch.float32)
        )

    def forward(self, x):
        # Pre-activation layout
        out = self.norm1(x)
        out = self.fc1(out)
        out = self.act(out)
        out = self.dropout(out)
        # out = self.norm2(out)
        out = self.fc2(out)
        # Scale the residual branch before adding it back
        return x + self.res_scale * out
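

# --- Illustrative sketch (not part of the released file): ResBlock preserves
# the feature dimension, so blocks stack freely; res_scale starts small (0.1)
# so the identity path dominates early in training.
def _demo_resblock():
    block = ResBlock(dim=32, dropout=0.1, use_layernorm=True)
    x = torch.randn(8, 32)
    out = block(x)
    print(out.shape, float(block.res_scale))  # torch.Size([8, 32]) 0.1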


# ResNetSequential subclasses nn.Module and defines the full network


class ResNetSequential(nn.Module):
    # Input tensor shape: (batch, input_dim)
    # Structure: linear + norm + ReLU, then a stack of residual blocks, then a
    # Softplus output

    def __init__(self, input_dim: int, hidden_dim: int = 64, block_num: int = 2,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1, task_type: str = 'regression'):
        super(ResNetSequential, self).__init__()

        self.net = nn.Sequential()
        self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))

        # if use_layernorm:
        #     self.net.add_module('norm1', nn.LayerNorm(hidden_dim))
        # else:
        #     self.net.add_module('norm1', nn.BatchNorm1d(hidden_dim))

        # self.net.add_module('relu1', nn.ReLU(inplace=True))

        # Stack several residual blocks
        for i in range(block_num):
            self.net.add_module(
                f'ResBlk_{i+1}',
                ResBlock(
                    hidden_dim,
                    dropout=dropout,
                    use_layernorm=use_layernorm,
                    residual_scale=residual_scale)
            )

        self.net.add_module('fc_out', nn.Linear(hidden_dim, 1))

        if task_type == 'classification':
            self.net.add_module('softplus', nn.Identity())
        else:
            self.net.add_module('softplus', nn.Softplus())

    def forward(self, x):
        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> ResNetSequential executing on device: {x.device}")
            self._printed_device = True
        return self.net(x)


# Scikit-Learn-style interface for the ResNet model


class ResNetSklearn(TorchTrainerMixin, nn.Module):
    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 block_num: int = 2, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 0.01, patience: int = 10,
                 use_layernorm: bool = True, dropout: float = 0.1,
                 residual_scale: float = 0.1,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False):
        super(ResNetSklearn, self).__init__()

        self.use_ddp = use_ddp
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)

        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.block_num = block_num
        self.batch_num = batch_num
        self.epochs = epochs
        self.task_type = task_type
        self.model_nme = model_nme
        self.learning_rate = learning_rate
        self.patience = patience
        self.use_layernorm = use_layernorm
        self.dropout = dropout
        self.residual_scale = residual_scale

        # Device preference: cuda > mps > cpu
        if self.is_ddp_enabled:
            self.device = torch.device(f'cuda:{self.local_rank}')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        # Tweedie power setting (unused for classification)
        if self.task_type == 'classification':
            self.tw_power = None
        elif 'f' in self.model_nme:
            self.tw_power = 1
        elif 's' in self.model_nme:
            self.tw_power = 2
        else:
            self.tw_power = tweedie_power

        # Build the network (on CPU first)
        core = ResNetSequential(
            self.input_dim,
            self.hidden_dim,
            self.block_num,
            use_layernorm=self.use_layernorm,
            dropout=self.dropout,
            residual_scale=self.residual_scale,
            task_type=self.task_type
        )

        # ===== Multi-GPU support: DataParallel vs DistributedDataParallel =====
        if self.is_ddp_enabled:
            core = core.to(self.device)
            core = DDP(core, device_ids=[
                self.local_rank], output_device=self.local_rank)
        elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            # DataParallel scatters inputs across GPUs, but the primary device
            # is still cuda:0
            self.device = torch.device('cuda')

        self.resnet = core.to(self.device)

    # ================ Internal helpers ================
    def _build_train_val_tensors(self, X_train, y_train, w_train, X_val, y_val, w_val):
        X_tensor = torch.tensor(X_train.values, dtype=torch.float32)
        y_tensor = torch.tensor(
            y_train.values, dtype=torch.float32).view(-1, 1)
        w_tensor = torch.tensor(w_train.values, dtype=torch.float32).view(
            -1, 1) if w_train is not None else torch.ones_like(y_tensor)

        has_val = X_val is not None and y_val is not None
        if has_val:
            X_val_tensor = torch.tensor(X_val.values, dtype=torch.float32)
            y_val_tensor = torch.tensor(
                y_val.values, dtype=torch.float32).view(-1, 1)
            w_val_tensor = torch.tensor(w_val.values, dtype=torch.float32).view(
                -1, 1) if w_val is not None else torch.ones_like(y_val_tensor)
        else:
            X_val_tensor = y_val_tensor = w_val_tensor = None
        return X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val

    def forward(self, x):
        # Accept NumPy input (e.g. from SHAP)
        if isinstance(x, np.ndarray):
            x_tensor = torch.tensor(x, dtype=torch.float32)
        else:
            x_tensor = x

        x_tensor = x_tensor.to(self.device)
        y_pred = self.resnet(x_tensor)
        return y_pred

    # ---------------- Training ----------------

    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None, trial=None):

        X_tensor, y_tensor, w_tensor, X_val_tensor, y_val_tensor, w_val_tensor, has_val = \
            self._build_train_val_tensors(
                X_train, y_train, w_train, X_val, y_val, w_val)

        dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
        dataloader, accum_steps = self._build_dataloader(
            dataset,
            N=X_tensor.shape[0],
            base_bs_gpu=(16384, 8192, 4096),
            base_bs_cpu=(1024, 512),
            min_bs=64,
            target_effective_cuda=8192,
            target_effective_cpu=4096
        )

        # Record the sampler so each epoch can call set_epoch, keeping the
        # shuffle properly randomized
        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
            self.dataloader_sampler = dataloader.sampler
        else:
            self.dataloader_sampler = None

        # === 4. Optimizer and AMP ===
        self.optimizer = torch.optim.Adam(
            self.resnet.parameters(), lr=self.learning_rate)
        self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        X_val_dev = y_val_dev = w_val_dev = None
        val_dataloader = None
        if has_val:
            # Build the validation DataLoader
            val_dataset = TensorDataset(
                X_val_tensor, y_val_tensor, w_val_tensor)
            # Validation does no backprop, so a larger batch improves throughput
            val_bs = accum_steps * dataloader.batch_size

            # Worker allocation mirrors the training loader
            if os.name == 'nt':
                val_workers = 0
            else:
                val_workers = min(4, os.cpu_count() or 1)
            if getattr(self, "is_ddp_enabled", False):
                val_workers = 0  # no extra workers under DDP, to avoid subprocess clashes

            val_dataloader = DataLoader(
                val_dataset,
                batch_size=val_bs,
                shuffle=False,
                num_workers=val_workers,
                pin_memory=(self.device.type == 'cuda'),
                persistent_workers=val_workers > 0,
            )
            # The validation set usually needs no DDP sampler, since validation
            # runs on the main process (or aggregates there); for simplicity we
            # keep single-device / main-process validation

        is_data_parallel = isinstance(self.resnet, nn.DataParallel)

        def forward_fn(batch):
            X_batch, y_batch, w_batch = batch

            if not is_data_parallel:
                X_batch = X_batch.to(self.device, non_blocking=True)
            # Targets and weights always live on the primary device for the
            # loss computation
            y_batch = y_batch.to(self.device, non_blocking=True)
            w_batch = w_batch.to(self.device, non_blocking=True)

            y_pred = self.resnet(X_batch)
            return y_pred, y_batch, w_batch

        def val_forward_fn():
            total_loss = 0.0
            total_weight = 0.0
            for batch in val_dataloader:
                X_b, y_b, w_b = batch
                if not is_data_parallel:
                    X_b = X_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

                y_pred = self.resnet(X_b)

                # Accumulate the weighted loss per batch so totals are exact
                task = getattr(self, "task_type", "regression")
                if task == 'classification':
                    loss_fn = nn.BCEWithLogitsLoss(reduction='none')
                    losses = loss_fn(y_pred, y_b).view(-1)
                else:
                    # No extra softplus here: training uses apply_softplus=False
                    # because the model's forward output is already positive
                    y_pred_clamped = torch.clamp(y_pred, min=1e-6)
                    power = getattr(self, "tw_power", 1.5)
                    losses = tweedie_loss(
                        y_pred_clamped, y_b, p=power).view(-1)

                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

                total_loss += batch_weighted_loss_sum.item()
                total_weight += batch_weight_sum.item()

            return total_loss / max(total_weight, EPS)

        clip_fn = None
        if self.device.type == 'cuda':
            def clip_fn(): return (self.scaler.unscale_(self.optimizer),
                                   clip_grad_norm_(self.resnet.parameters(), max_norm=1.0))

        # Under DDP, only the main process logs and saves the model
        if self.is_ddp_enabled and not DistributedUtils.is_main_process():
            # Non-main processes skip printing in the validation callback
            # (controlled inside _train_model; nothing to do here)
            pass

        best_state = self._train_model(
            self.resnet,
            dataloader,
            accum_steps,
            self.optimizer,
            self.scaler,
            forward_fn,
            val_forward_fn if has_val else None,
            apply_softplus=False,
            clip_fn=clip_fn,
            trial=trial
        )

        if has_val and best_state is not None:
            self.resnet.load_state_dict(best_state)

    # ---------------- Prediction ----------------

    def predict(self, X_test):
        self.resnet.eval()
        if isinstance(X_test, pd.DataFrame):
            X_np = X_test.values.astype(np.float32)
        else:
            X_np = X_test

        with torch.no_grad():
            y_pred = self(X_np).cpu().numpy()

        if self.task_type == 'classification':
            y_pred = 1 / (1 + np.exp(-y_pred))  # sigmoid turns logits into probabilities
        else:
            y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.flatten()

    # ---------------- Parameter setting ----------------

    def set_params(self, params):
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Parameter {key} not found in model.")
        return self
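

# --- Illustrative sketch (not part of the released file): end-to-end use of
# ResNetSklearn on synthetic frequency-like data. 'f' in the model name
# selects tw_power=1 (Poisson) per the naming convention above; column names
# and the model name here are made up.
def _demo_resnet_sklearn():
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(500, 8)),
                     columns=[f'x{i}' for i in range(8)])
    y = pd.Series(rng.poisson(1.0, size=500).astype('float32'))
    model = ResNetSklearn(model_nme='demo_f', input_dim=8,
                          hidden_dim=32, block_num=2,
                          epochs=3, use_data_parallel=False)
    model.fit(X, y)
    preds = model.predict(X)  # strictly positive, thanks to the Softplus head
    print(preds[:5])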


# =============================================================================
# FT-Transformer model and sklearn-style wrapper
# =============================================================================
# FT-Transformer model structure


class FeatureTokenizer(nn.Module):
    # Maps numeric and categorical features to tokens of shape
    # (batch, token_num, d_model)
    # Conventions:
    # - X_num: numeric features, shape=(batch, num_numeric)
    # - X_cat: categorical features, shape=(batch, num_categorical); each
    #   column holds encoded integer labels in [0, card-1]

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int):
        super().__init__()

        self.num_numeric = num_numeric
        self.has_numeric = num_numeric > 0

        if self.has_numeric:
            self.num_linear = nn.Linear(num_numeric, d_model)

        self.embeddings = nn.ModuleList([
            nn.Embedding(card, d_model) for card in cat_cardinalities
        ])

    def forward(self, X_num, X_cat):
        tokens = []

        if self.has_numeric:
            # All numeric features map to a single token
            # shape = (batch, d_model)
            num_token = self.num_linear(X_num)
            tokens.append(num_token)

        # Each categorical feature contributes one embedding token
        for i, emb in enumerate(self.embeddings):
            # shape = (batch, d_model)
            tok = emb(X_cat[:, i])
            tokens.append(tok)

        # Stacking yields (batch, token_num, d_model)
        x = torch.stack(tokens, dim=1)
        return x
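

# --- Illustrative sketch (not part of the released file): the token layout
# FeatureTokenizer produces. All numeric columns share one token; each
# categorical column adds one embedding token, so token_num =
# 1 + len(cat_cardinalities) when numeric features are present.
def _demo_tokenizer():
    tok = FeatureTokenizer(num_numeric=5, cat_cardinalities=[10, 4], d_model=16)
    X_num = torch.randn(8, 5)
    X_cat = torch.stack([torch.randint(0, 10, (8,)),
                         torch.randint(0, 4, (8,))], dim=1)
    print(tok(X_num, X_cat).shape)  # torch.Size([8, 3, 16])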


# Encoder layer with residual scaling


class ScaledTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, residual_scale_attn: float = 1.0,
                 residual_scale_ffn: float = 1.0, norm_first: bool = True,
                 ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward sublayer
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()
        # self.activation = nn.ReLU()
        self.norm_first = norm_first

        # Residual scaling coefficients
        self.res_scale_attn = residual_scale_attn
        self.res_scale_ffn = residual_scale_ffn

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Input tensor shape: (batch, sequence length, d_model)
        x = src

        if self.norm_first:
            # Normalize before attention (pre-norm)
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm (normally disabled)
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Self-attention with residual scaling
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )
        return self.res_scale_attn * self.dropout1(attn_out)

    def _ff_block(self, x):
        # Feed-forward with residual scaling
        x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.res_scale_ffn * self.dropout2(x2)


# FT-Transformer core model


class FTTransformerCore(nn.Module):
    # A minimal working FT-Transformer with three parts:
    # 1) FeatureTokenizer: turns numeric/categorical features into tokens;
    # 2) TransformerEncoder: models interactions between features;
    # 3) pooling + MLP + Softplus: emits positive outputs, suitable for
    #    Tweedie/Gamma-type tasks.

    def __init__(self, num_numeric: int, cat_cardinalities, d_model: int = 64,
                 n_heads: int = 8, n_layers: int = 4, dropout: float = 0.1,
                 task_type: str = 'regression'
                 ):
        super().__init__()

        self.tokenizer = FeatureTokenizer(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=d_model
        )
        scale = 1.0 / math.sqrt(n_layers)  # a reasonable default
        encoder_layer = ScaledTransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            residual_scale_attn=scale,
            residual_scale_ffn=scale,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )
        self.n_layers = n_layers

        layers = [
            # nn.LayerNorm(d_model),
            # nn.Linear(d_model, d_model),
            # nn.GELU(),
            nn.Linear(d_model, 1),
        ]

        if task_type == 'classification':
            # Classification emits logits, matching BCEWithLogitsLoss
            layers.append(nn.Identity())
        else:
            # Regression must stay positive for Tweedie/Gamma losses
            layers.append(nn.Softplus())

        self.head = nn.Sequential(*layers)

    def forward(self, X_num, X_cat):

        # Inputs:
        # X_num -> float32 tensor of shape (batch, num numeric features)
        # X_cat -> long tensor of shape (batch, num categorical features)

        if self.training and not hasattr(self, '_printed_device'):
            print(f">>> FTTransformerCore executing on device: {X_num.device}")
            self._printed_device = True

        tokens = self.tokenizer(X_num, X_cat)  # => (batch, token_num, d_model)
        x = self.encoder(tokens)  # => (batch, token_num, d_model)

        # Mean-pool over tokens, then feed the regression head
        x = x.mean(dim=1)  # => (batch, d_model)

        out = self.head(x)  # => (batch, 1), kept positive by Softplus
        return out
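

# --- Illustrative sketch (not part of the released file): a forward pass
# through FTTransformerCore. The residual scale 1/sqrt(n_layers) keeps the
# pre-norm encoder stack stable as depth grows; the Softplus head makes the
# regression output strictly positive.
def _demo_ft_core():
    core = FTTransformerCore(num_numeric=5, cat_cardinalities=[10, 4],
                             d_model=16, n_heads=4, n_layers=2)
    X_num = torch.randn(8, 5)
    X_cat = torch.stack([torch.randint(0, 10, (8,)),
                         torch.randint(0, 4, (8,))], dim=1)
    out = core(X_num, X_cat)
    print(out.shape, bool((out > 0).all()))  # torch.Size([8, 1]) True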


# TabularDataset class


class TabularDataset(Dataset):
    def __init__(self, X_num, X_cat, y, w):

        # Input tensors:
        # X_num: torch.float32, shape=(N, num numeric features)
        # X_cat: torch.long,    shape=(N, num categorical features)
        # y: torch.float32, shape=(N, 1)
        # w: torch.float32, shape=(N, 1)

        self.X_num = X_num
        self.X_cat = X_cat
        self.y = y
        self.w = w

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.X_num[idx],
            self.X_cat[idx],
            self.y[idx],
            self.w[idx],
        )
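

# --- Illustrative sketch (not part of the released file): TabularDataset is a
# thin indexable wrapper, so a standard DataLoader yields (X_num, X_cat, y, w)
# batches directly.
def _demo_tabular_dataset():
    N = 100
    ds = TabularDataset(
        X_num=torch.randn(N, 5),
        X_cat=torch.randint(0, 4, (N, 2)),
        y=torch.rand(N, 1),
        w=torch.ones(N, 1),
    )
    loader = DataLoader(ds, batch_size=32, shuffle=True)
    X_num, X_cat, y, w = next(iter(loader))
    print(X_num.shape, X_cat.shape, y.shape, w.shape)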


# Scikit-learn style interface class for the FT-Transformer


class FTTransformerSklearn(TorchTrainerMixin, nn.Module):

    # sklearn-style wrapper:
    # - num_cols: list of numeric feature column names
    # - cat_cols: list of categorical feature column names (label-encoded in
    #   advance, with values in [0, n_classes - 1])

    def __init__(self, model_nme: str, num_cols, cat_cols, d_model: int = 64, n_heads: int = 8,
                 n_layers: int = 4, dropout: float = 0.1, batch_num: int = 100, epochs: int = 100,
                 task_type: str = 'regression',
                 tweedie_power: float = 1.5, learning_rate: float = 1e-3, patience: int = 10,
                 use_data_parallel: bool = True,
                 use_ddp: bool = False
                 ):
        super().__init__()

        self.use_ddp = use_ddp
        self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = (
            False, 0, 0, 1)
        if self.use_ddp:
            self.is_ddp_enabled, self.local_rank, self.rank, self.world_size = DistributedUtils.setup_ddp()

        self.model_nme = model_nme
        self.num_cols = list(num_cols)
        self.cat_cols = list(cat_cols)
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.batch_num = batch_num
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.task_type = task_type
        self.patience = patience
        if self.task_type == 'classification':
            self.tw_power = None  # the Tweedie power is unused for classification
        elif 'f' in self.model_nme:
            self.tw_power = 1.0
        elif 's' in self.model_nme:
            self.tw_power = 2.0
        else:
            self.tw_power = tweedie_power

        if self.is_ddp_enabled:
            self.device = torch.device(f"cuda:{self.local_rank}")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.cat_cardinalities = None
        self.cat_categories = {}
        self.ft = None
        self.use_data_parallel = torch.cuda.device_count() > 1 and use_data_parallel

    def _build_model(self, X_train):
        num_numeric = len(self.num_cols)
        cat_cardinalities = []

        for col in self.cat_cols:
            cats = X_train[col].astype('category')
            categories = cats.cat.categories
            self.cat_categories[col] = categories  # keep the full training-set category set

            card = len(categories) + 1  # reserve one extra class for "unknown/missing"
            cat_cardinalities.append(card)

        self.cat_cardinalities = cat_cardinalities

        core = FTTransformerCore(
            num_numeric=num_numeric,
            cat_cardinalities=cat_cardinalities,
            d_model=self.d_model,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            dropout=self.dropout,
            task_type=self.task_type
        )
        if self.is_ddp_enabled:
            core = core.to(self.device)
            core = DDP(core, device_ids=[self.local_rank],
                       output_device=self.local_rank)
        elif self.use_data_parallel:
            core = nn.DataParallel(core, device_ids=list(
                range(torch.cuda.device_count())))
            self.device = torch.device("cuda")
        self.ft = core.to(self.device)

    def _encode_cats(self, X):
        # The input DataFrame must contain every categorical feature column.
        # Returns an int64 array of shape (N, n_categorical_features).

        if not self.cat_cols:
            return np.zeros((len(X), 0), dtype='int64')

        X_cat_list = []
        for col in self.cat_cols:
            # Use the category set recorded during training
            categories = self.cat_categories[col]
            # Build a Categorical over that fixed category set
            cats = pd.Categorical(X[col], categories=categories)
            codes = cats.codes.astype('int64', copy=True)  # -1 marks unknown or missing
            # Map unknown/missing to the extra "unknown" index len(categories)
            codes[codes < 0] = len(categories)
            X_cat_list.append(codes)

        X_cat_np = np.stack(X_cat_list, axis=1)  # shape (N, n_categorical_features)
        return X_cat_np

    def _build_train_tensors(self, X_train, y_train, w_train):
        return self._tensorize_split(X_train, y_train, w_train)

    def _build_val_tensors(self, X_val, y_val, w_val):
        return self._tensorize_split(X_val, y_val, w_val, allow_none=True)

    def _tensorize_split(self, X, y, w, allow_none: bool = False):
        if X is None:
            if allow_none:
                return None, None, None, None, False
            raise ValueError("Input features X must not be None.")

        X_num = torch.tensor(
            X[self.num_cols].to_numpy(dtype=np.float32, copy=True),
            dtype=torch.float32
        )
        if self.cat_cols:
            X_cat = torch.tensor(self._encode_cats(X), dtype=torch.long)
        else:
            X_cat = torch.zeros((X_num.shape[0], 0), dtype=torch.long)

        y_tensor = torch.tensor(
            y.values, dtype=torch.float32).view(-1, 1) if y is not None else None
        if y_tensor is None:
            w_tensor = None
        elif w is not None:
            w_tensor = torch.tensor(
                w.values, dtype=torch.float32).view(-1, 1)
        else:
            w_tensor = torch.ones_like(y_tensor)
        return X_num, X_cat, y_tensor, w_tensor, y is not None

    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None, trial=None):

        # Build the underlying model structure on the first fit
        if self.ft is None:
            self._build_model(X_train)

        X_num_train, X_cat_train, y_tensor, w_tensor, _ = self._build_train_tensors(
            X_train, y_train, w_train)
        X_num_val, X_cat_val, y_val_tensor, w_val_tensor, has_val = self._build_val_tensors(
            X_val, y_val, w_val)

        # --- Build the DataLoader ---
        dataset = TabularDataset(
            X_num_train, X_cat_train, y_tensor, w_tensor
        )

        dataloader, accum_steps = self._build_dataloader(
            dataset,
            N=X_num_train.shape[0],
            base_bs_gpu=(16384, 8192, 4096),
            base_bs_cpu=(256, 128),
            min_bs=64,
            target_effective_cuda=4096,
            target_effective_cpu=2048
        )

        if self.is_ddp_enabled and hasattr(dataloader.sampler, 'set_epoch'):
            self.dataloader_sampler = dataloader.sampler
        else:
            self.dataloader_sampler = None

        optimizer = torch.optim.Adam(
            self.ft.parameters(), lr=self.learning_rate)
        scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        X_num_val_dev = X_cat_val_dev = y_val_dev = w_val_dev = None
        val_dataloader = None
        if has_val:
            val_dataset = TabularDataset(
                X_num_val, X_cat_val, y_val_tensor, w_val_tensor
            )
            val_bs = accum_steps * dataloader.batch_size

            if os.name == 'nt':
                val_workers = 0
            else:
                val_workers = min(4, os.cpu_count() or 1)
            if getattr(self, "is_ddp_enabled", False):
                val_workers = 0  # no extra workers under DDP, to avoid subprocess clashes

            val_dataloader = DataLoader(
                val_dataset,
                batch_size=val_bs,
                shuffle=False,
                num_workers=val_workers,
                pin_memory=(self.device.type == 'cuda'),
                persistent_workers=val_workers > 0,
            )

        is_data_parallel = isinstance(self.ft, nn.DataParallel)

        def forward_fn(batch):
            X_num_b, X_cat_b, y_b, w_b = batch

            if not is_data_parallel:
                X_num_b = X_num_b.to(self.device, non_blocking=True)
                X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                y_b = y_b.to(self.device, non_blocking=True)
                w_b = w_b.to(self.device, non_blocking=True)

            y_pred = self.ft(X_num_b, X_cat_b)
            return y_pred, y_b, w_b

        def val_forward_fn():
            total_loss = 0.0
            total_weight = 0.0
            for batch in val_dataloader:
                X_num_b, X_cat_b, y_b, w_b = batch
                if not is_data_parallel:
                    X_num_b = X_num_b.to(self.device, non_blocking=True)
                    X_cat_b = X_cat_b.to(self.device, non_blocking=True)
                    y_b = y_b.to(self.device, non_blocking=True)
                    w_b = w_b.to(self.device, non_blocking=True)

                y_pred = self.ft(X_num_b, X_cat_b)

                # Compute the validation loss by hand
                task = getattr(self, "task_type", "regression")
                if task == 'classification':
                    loss_fn = nn.BCEWithLogitsLoss(reduction='none')
                    losses = loss_fn(y_pred, y_b).view(-1)
                else:
                    # The model output already went through Softplus; do not apply it again
                    y_pred_clamped = torch.clamp(y_pred, min=1e-6)
                    power = getattr(self, "tw_power", 1.5)
                    losses = tweedie_loss(
                        y_pred_clamped, y_b, p=power).view(-1)

                batch_weight_sum = torch.clamp(w_b.sum(), min=EPS)
                batch_weighted_loss_sum = (losses * w_b.view(-1)).sum()

                total_loss += batch_weighted_loss_sum.item()
                total_weight += batch_weight_sum.item()

            return total_loss / max(total_weight, EPS)

        clip_fn = None
        if self.device.type == 'cuda':
            def clip_fn(): return (scaler.unscale_(optimizer),
                                   clip_grad_norm_(self.ft.parameters(), max_norm=1.0))

        best_state = self._train_model(
            self.ft,
            dataloader,
            accum_steps,
            optimizer,
            scaler,
            forward_fn,
            val_forward_fn if has_val else None,
            apply_softplus=False,
            clip_fn=clip_fn,
            trial=trial
        )

        if has_val and best_state is not None:
            self.ft.load_state_dict(best_state)

    def predict(self, X_test):
        # X_test must contain every numeric and categorical column

        self.ft.eval()
        X_num, X_cat, _, _, _ = self._tensorize_split(
            X_test, None, None, allow_none=True)

        with torch.no_grad():
            X_num = X_num.to(self.device, non_blocking=True)
            X_cat = X_cat.to(self.device, non_blocking=True)
            y_pred = self.ft(X_num, X_cat).cpu().numpy()

        if self.task_type == 'classification':
            # Convert logits to probabilities
            y_pred = 1 / (1 + np.exp(-y_pred))
        else:
            # The model already applies softplus; if needed, log-exp smoothing
            # could be applied here: y_pred = log(1 + exp(y_pred))
            y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.ravel()

    def set_params(self, params: dict):

        # Kept consistent with the sklearn convention.
        # Note: changes to structural parameters (e.g. d_model/n_heads) only
        # take effect after re-fitting.

        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Parameter {key} not found in model.")
        return self
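

# --- Illustrative sketch, not part of the original module ---
# Minimal end-to-end use of FTTransformerSklearn on toy data; the column names
# and sizes are hypothetical, and the sketch assumes the module-level imports
# above plus the TorchTrainerMixin training helpers defined in this file. Note
# how predict() tolerates a category never seen during training: _encode_cats
# maps it to the reserved "unknown" index.
def _example_ft_transformer_usage():
    rng = np.random.default_rng(0)
    train = pd.DataFrame({
        'age': rng.normal(40.0, 10.0, 256),
        'region': rng.choice(['north', 'south'], 256),
    })
    y = pd.Series(rng.gamma(2.0, 100.0, 256))
    model = FTTransformerSklearn(
        model_nme='demo', num_cols=['age'], cat_cols=['region'],
        d_model=16, n_heads=2, n_layers=1, epochs=2, patience=1,
        use_data_parallel=False,
    )
    model.fit(train, y)
    test = pd.DataFrame({'age': [35.0], 'region': ['east']})  # 'east' is unseen
    return model.predict(test)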


# =============================================================================
# Simplified graph neural network (GNN) implementation
# =============================================================================


class SimpleGraphLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Message passing over the normalized sparse adjacency: A_hat * X * W
        h = torch.sparse.mm(adj, x)
        h = self.linear(h)
        h = self.activation(h)
        return self.dropout(h)


class SimpleGNN(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2,
                 dropout: float = 0.1, task_type: str = 'regression'):
        super().__init__()
        layers = []
        dim_in = input_dim
        for _ in range(max(1, num_layers)):
            layers.append(SimpleGraphLayer(
                dim_in, hidden_dim, dropout=dropout))
            dim_in = hidden_dim
        self.layers = nn.ModuleList(layers)
        self.output = nn.Linear(hidden_dim, 1)
        if task_type == 'classification':
            self.output_act = nn.Identity()
        else:
            self.output_act = nn.Softplus()
        self.task_type = task_type
        # Keep the adjacency in a buffer so DataParallel can replicate it
        self.register_buffer("adj_buffer", torch.empty(0))

    def forward(self, x: torch.Tensor, adj: Optional[torch.Tensor] = None) -> torch.Tensor:
        adj_used = adj if adj is not None else getattr(
            self, "adj_buffer", None)
        if adj_used is None or adj_used.numel() == 0:
            raise RuntimeError("Adjacency is not set for GNN forward.")
        h = x
        for layer in self.layers:
            h = layer(h, adj_used)
        h = torch.sparse.mm(adj_used, h)
        out = self.output(h)
        return self.output_act(out)
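

# --- Illustrative sketch, not part of the original module ---
# Demonstrates the message-passing contract of SimpleGNN on a 3-node path
# graph, applying by hand the same symmetric normalization
# D^-1/2 (A + I) D^-1/2 that GraphNeuralNetSklearn._normalized_adj computes
# below. The graph and feature values are made up.
def _example_simple_gnn_forward():
    # Undirected path 0-1-2 plus self-loops on every node
    edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                               [1, 0, 2, 1, 0, 1, 2]], dtype=torch.long)
    values = torch.ones(edge_index.shape[1])
    adj = torch.sparse_coo_tensor(edge_index, values, (3, 3)).coalesce()
    deg = torch.sparse.sum(adj, dim=1).to_dense()
    d_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
    row, col = adj.indices()
    adj_norm = torch.sparse_coo_tensor(
        adj.indices(), d_inv_sqrt[row] * adj.values() * d_inv_sqrt[col], adj.shape)
    gnn = SimpleGNN(input_dim=4, hidden_dim=8, num_layers=2)
    out = gnn(torch.randn(3, 4), adj_norm)  # shape (3, 1), positive via Softplus
    return out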


class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
    def __init__(self, model_nme: str, input_dim: int, hidden_dim: int = 64,
                 num_layers: int = 2, k_neighbors: int = 10, dropout: float = 0.1,
                 learning_rate: float = 1e-3, epochs: int = 100, patience: int = 10,
                 task_type: str = 'regression', tweedie_power: float = 1.5,
                 use_data_parallel: bool = False, use_ddp: bool = False,
                 use_approx_knn: bool = True, approx_knn_threshold: int = 50000,
                 graph_cache_path: Optional[str] = None) -> None:
        super().__init__()
        self.model_nme = model_nme
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.k_neighbors = max(1, k_neighbors)
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.patience = patience
        self.task_type = task_type
        self.use_approx_knn = use_approx_knn
        self.approx_knn_threshold = approx_knn_threshold
        self.graph_cache_path = Path(
            graph_cache_path) if graph_cache_path else None

        if self.task_type == 'classification':
            self.tw_power = None
        elif 'f' in self.model_nme:
            self.tw_power = 1.0
        elif 's' in self.model_nme:
            self.tw_power = 2.0
        else:
            self.tw_power = tweedie_power

        self.ddp_enabled = False
        self.local_rank = 0
        self.data_parallel_enabled = False

        # DDP is only meaningful on CUDA; fall back to a single device if setup fails
        if use_ddp and torch.cuda.is_available():
            ddp_ok, local_rank, _, _ = DistributedUtils.setup_ddp()
            if ddp_ok:
                self.ddp_enabled = True
                self.local_rank = local_rank
                self.device = torch.device(f'cuda:{local_rank}')
            else:
                self.device = torch.device('cuda')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
        self.use_pyg_knn = self.device.type == 'cuda' and _PYG_AVAILABLE

        self.gnn = SimpleGNN(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=self.dropout,
            task_type=self.task_type
        ).to(self.device)

        # DataParallel: replicate the full graph on every card and split the
        # features; suited to medium-sized graphs
        if (not self.ddp_enabled) and use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
            self.data_parallel_enabled = True
            self.gnn = nn.DataParallel(
                self.gnn, device_ids=list(range(torch.cuda.device_count())))
            self.device = torch.device('cuda')

        if self.ddp_enabled:
            self.gnn = DDP(
                self.gnn,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=False
            )

    def _unwrap_gnn(self) -> nn.Module:
        # Unwrap both DDP and DataParallel so buffers and state land on the
        # underlying SimpleGNN rather than on the wrapper
        if isinstance(self.gnn, (DDP, nn.DataParallel)):
            return self.gnn.module
        return self.gnn

    def _set_adj_buffer(self, adj: torch.Tensor) -> None:
        base = self._unwrap_gnn()
        if hasattr(base, "adj_buffer"):
            base.adj_buffer = adj
        else:
            base.register_buffer("adj_buffer", adj)

    def _load_cached_adj(self) -> Optional[torch.Tensor]:
        if self.graph_cache_path and self.graph_cache_path.exists():
            try:
                adj = torch.load(self.graph_cache_path,
                                 map_location=self.device)
                return adj.to(self.device)
            except Exception as exc:
                print(
                    f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
        return None

    def _build_edge_index_cpu(self, X_np: np.ndarray) -> torch.Tensor:
        n_samples = X_np.shape[0]
        k = min(self.k_neighbors, max(1, n_samples - 1))
        n_neighbors = min(k + 1, n_samples)
        use_approx = (self.use_approx_knn or n_samples >=
                      self.approx_knn_threshold) and _PYNN_AVAILABLE
        indices = None
        if use_approx:
            try:
                nn_index = pynndescent.NNDescent(
                    X_np,
                    n_neighbors=n_neighbors,
                    random_state=0
                )
                indices, _ = nn_index.neighbor_graph
            except Exception as exc:
                print(
                    f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
                use_approx = False

        if indices is None:
            nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto")
            nbrs.fit(X_np)
            _, indices = nbrs.kneighbors(X_np)

        indices = np.asarray(indices)

        rows = []
        cols = []
        for i in range(n_samples):
            for j in indices[i]:
                if i == j:
                    continue
                rows.append(i)
                cols.append(j)
                rows.append(j)
                cols.append(i)

        # Add self-loops so that no node ends up with degree zero
        rows.extend(range(n_samples))
        cols.extend(range(n_samples))

        edge_index = torch.tensor(
            [rows, cols], dtype=torch.long, device=self.device)
        return edge_index

    def _build_edge_index_gpu(self, X_tensor: torch.Tensor) -> torch.Tensor:
        if not self.use_pyg_knn or knn_graph is None or add_self_loops is None or to_undirected is None:
            # Defensive check: callers should verify use_pyg_knn first
            raise RuntimeError(
                "GPU graph builder requested but PyG is unavailable.")

        n_samples = X_tensor.size(0)
        k = min(self.k_neighbors, max(1, n_samples - 1))

        # knn_graph runs on the GPU so that graph construction does not
        # bottleneck on the CPU
        edge_index = knn_graph(
            X_tensor,
            k=k,
            loop=False
        )
        edge_index = to_undirected(edge_index, num_nodes=n_samples)
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_samples)
        return edge_index

    def _normalized_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        values = torch.ones(edge_index.shape[1], device=self.device)
        adj = torch.sparse_coo_tensor(
            edge_index.to(self.device), values, (num_nodes, num_nodes))
        adj = adj.coalesce()

        deg = torch.sparse.sum(adj, dim=1).to_dense()
        deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5)
        row, col = adj.indices()
        norm_values = deg_inv_sqrt[row] * adj.values() * deg_inv_sqrt[col]
        adj_norm = torch.sparse_coo_tensor(
            adj.indices(), norm_values, size=adj.shape)
        return adj_norm

    def _tensorize_split(self, X, y, w, allow_none: bool = False):
        if X is None and allow_none:
            return None, None, None
        X_np = X.values.astype(np.float32)
        X_tensor = torch.tensor(X_np, dtype=torch.float32, device=self.device)
        if y is None:
            y_tensor = None
        else:
            y_tensor = torch.tensor(
                y.values, dtype=torch.float32, device=self.device).view(-1, 1)
        if w is None:
            w_tensor = torch.ones(
                (len(X), 1), dtype=torch.float32, device=self.device)
        else:
            w_tensor = torch.tensor(
                w.values, dtype=torch.float32, device=self.device).view(-1, 1)
        return X_tensor, y_tensor, w_tensor

    def _build_graph_from_df(self, X_df: pd.DataFrame, X_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        if X_tensor is None:
            X_tensor = torch.tensor(
                X_df.values.astype(np.float32),
                dtype=torch.float32,
                device=self.device
            )
        if self.graph_cache_path:
            cached = self._load_cached_adj()
            if cached is not None:
                return cached
        if self.use_pyg_knn:
            edge_index = self._build_edge_index_gpu(X_tensor)
        else:
            edge_index = self._build_edge_index_cpu(
                X_df.values.astype(np.float32))
        adj_norm = self._normalized_adj(edge_index, X_df.shape[0])
        if self.graph_cache_path:
            try:
                IOUtils.ensure_parent_dir(str(self.graph_cache_path))
                torch.save(adj_norm.cpu(), self.graph_cache_path)
            except Exception as exc:
                print(
                    f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
        return adj_norm

    def fit(self, X_train, y_train, w_train=None,
            X_val=None, y_val=None, w_val=None,
            trial: Optional[optuna.trial.Trial] = None):

        X_train_tensor, y_train_tensor, w_train_tensor = self._tensorize_split(
            X_train, y_train, w_train, allow_none=False)
        has_val = X_val is not None and y_val is not None
        if has_val:
            X_val_tensor, y_val_tensor, w_val_tensor = self._tensorize_split(
                X_val, y_val, w_val, allow_none=False)
        else:
            X_val_tensor = y_val_tensor = w_val_tensor = None

        adj_train = self._build_graph_from_df(X_train, X_train_tensor)
        adj_val = self._build_graph_from_df(
            X_val, X_val_tensor) if has_val else None
        # DataParallel needs the adjacency cached on the model so it is not
        # scattered along with the inputs
        self._set_adj_buffer(adj_train)

        base_gnn = self._unwrap_gnn()
        optimizer = torch.optim.Adam(
            base_gnn.parameters(), lr=self.learning_rate)
        scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        best_loss = float('inf')
        best_state = None
        patience_counter = 0

        for epoch in range(1, self.epochs + 1):
            self.gnn.train()
            optimizer.zero_grad()
            with autocast(enabled=(self.device.type == 'cuda')):
                if self.data_parallel_enabled:
                    y_pred = self.gnn(X_train_tensor)
                else:
                    y_pred = self.gnn(X_train_tensor, adj_train)
                loss = self._compute_weighted_loss(
                    y_pred, y_train_tensor, w_train_tensor, apply_softplus=False)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(self.gnn.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            if has_val:
                self.gnn.eval()
                if self.data_parallel_enabled and adj_val is not None:
                    self._set_adj_buffer(adj_val)
                with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
                    if self.data_parallel_enabled:
                        y_val_pred = self.gnn(X_val_tensor)
                    else:
                        y_val_pred = self.gnn(X_val_tensor, adj_val)
                    val_loss = self._compute_weighted_loss(
                        y_val_pred, y_val_tensor, w_val_tensor, apply_softplus=False)
                if self.data_parallel_enabled:
                    # Restore the training adjacency
                    self._set_adj_buffer(adj_train)

                best_loss, best_state, patience_counter, stop_training = self._early_stop_update(
                    val_loss, best_loss, best_state, patience_counter, base_gnn,
                    ignore_keys=["adj_buffer"])

                if trial is not None:
                    trial.report(val_loss, epoch)
                    if trial.should_prune():
                        raise optuna.TrialPruned()
                if stop_training:
                    break

        if best_state is not None:
            base_gnn.load_state_dict(best_state, strict=False)

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        self.gnn.eval()
        X_tensor, _, _ = self._tensorize_split(
            X, None, None, allow_none=False)
        adj = self._build_graph_from_df(X, X_tensor)
        if self.data_parallel_enabled:
            self._set_adj_buffer(adj)
        with torch.no_grad():
            if self.data_parallel_enabled:
                y_pred = self.gnn(X_tensor).cpu().numpy()
            else:
                y_pred = self.gnn(X_tensor, adj).cpu().numpy()
        if self.task_type == 'classification':
            y_pred = 1 / (1 + np.exp(-y_pred))
        else:
            y_pred = np.clip(y_pred, 1e-6, None)
        return y_pred.ravel()

    def set_params(self, params: Dict[str, Any]):
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Parameter {key} not found in GNN model.")
        # Rebuild the backbone after structural parameters change
        self.gnn = SimpleGNN(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=self.dropout,
            task_type=self.task_type
        ).to(self.device)
        return self
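

# --- Illustrative sketch, not part of the original module ---
# Minimal fit/predict round trip for GraphNeuralNetSklearn on toy data. The
# exact-kNN path is forced (use_approx_knn=False) so the sketch does not depend
# on pynndescent or PyG; it still assumes the TorchTrainerMixin helpers
# (_compute_weighted_loss, _early_stop_update) defined in this file. All names
# and sizes are made up.
def _example_gnn_usage():
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(64, 4)), columns=list('abcd'))
    y = pd.Series(rng.gamma(2.0, 100.0, 64))
    model = GraphNeuralNetSklearn(
        model_nme='demo', input_dim=4, hidden_dim=8, num_layers=2,
        k_neighbors=5, epochs=2, use_approx_knn=False,
    )
    model.fit(X, y)
    return model.predict(X)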


# ===== Core components and training wrappers ==================================

# =============================================================================
# Configuration, preprocessing, and trainer base classes
# =============================================================================
@dataclass
class BayesOptConfig:
    model_nme: str
    resp_nme: str
    weight_nme: str
    factor_nmes: List[str]
    task_type: str = 'regression'
    binary_resp_nme: Optional[str] = None
    cate_list: Optional[List[str]] = None
    prop_test: float = 0.25
    rand_seed: Optional[int] = None
    epochs: int = 100
    use_gpu: bool = True
    use_resn_data_parallel: bool = False
    use_ft_data_parallel: bool = False
    use_resn_ddp: bool = False
    use_ft_ddp: bool = False
    use_gnn_data_parallel: bool = False
    use_gnn_ddp: bool = False
    gnn_use_approx_knn: bool = True
    gnn_approx_knn_threshold: int = 50000
    gnn_graph_cache: Optional[str] = None
    output_dir: Optional[str] = None
    optuna_storage: Optional[str] = None
    optuna_study_prefix: Optional[str] = None
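

# --- Illustrative sketch, not part of the original module ---
# A typical BayesOptConfig for a frequency model; every column name here is
# hypothetical. The 'f' in model_nme follows the convention used above, where
# an 'f' marks a frequency (Poisson-like) model when choosing tw_power.
def _example_bayesopt_config() -> BayesOptConfig:
    return BayesOptConfig(
        model_nme='motor_f',
        resp_nme='claim_freq',
        weight_nme='exposure',
        factor_nmes=['age', 'vehicle_power', 'region'],
        cate_list=['region'],
        prop_test=0.25,
        rand_seed=42,
        epochs=50,
        use_gpu=False,
    )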


class OutputManager:
    # Centralizes the output paths for results, plots, and saved models

    def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
        self.root = Path(root or os.getcwd())
        self.model_name = model_name
        self.plot_dir = self.root / 'plot'
        self.result_dir = self.root / 'Results'
        self.model_dir = self.root / 'model'

    def _prepare(self, path: Path) -> str:
        ensure_parent_dir(str(path))
        return str(path)

    def plot_path(self, filename: str) -> str:
        return self._prepare(self.plot_dir / filename)

    def result_path(self, filename: str) -> str:
        return self._prepare(self.result_dir / filename)

    def model_path(self, filename: str) -> str:
        return self._prepare(self.model_dir / filename)


class DatasetPreprocessor:
    # Prepares the common train/test data views shared by every trainer

    def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
                 config: BayesOptConfig) -> None:
        self.config = config
        self.train_data = train_df.copy(deep=True)
        self.test_data = test_df.copy(deep=True)
        self.num_features: List[str] = []
        self.train_oht_data: Optional[pd.DataFrame] = None
        self.test_oht_data: Optional[pd.DataFrame] = None
        self.train_oht_scl_data: Optional[pd.DataFrame] = None
        self.test_oht_scl_data: Optional[pd.DataFrame] = None
        self.var_nmes: List[str] = []
        self.cat_categories_for_shap: Dict[str, List[Any]] = {}

    def run(self) -> "DatasetPreprocessor":
        """Execute preprocessing: binarize categories, clip targets, and scale numerics."""
        cfg = self.config
        # Precompute the weighted actuals; plotting and validation downstream
        # both rely on this column
        self.train_data.loc[:, 'w_act'] = self.train_data[cfg.resp_nme] * \
            self.train_data[cfg.weight_nme]
        self.test_data.loc[:, 'w_act'] = self.test_data[cfg.resp_nme] * \
            self.test_data[cfg.weight_nme]
        if cfg.binary_resp_nme:
            self.train_data.loc[:, 'w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
                self.train_data[cfg.weight_nme]
            self.test_data.loc[:, 'w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
                self.test_data[cfg.weight_nme]
        # Clip at the 99.9th percentile to absorb outliers; without it, extreme
        # points dominate the loss
        q99 = self.train_data[cfg.resp_nme].quantile(0.999)
        self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
            upper=q99)
        cate_list = list(cfg.cate_list or [])
        if cate_list:
            for cate in cate_list:
                self.train_data[cate] = self.train_data[cate].astype(
                    'category')
                self.test_data[cate] = self.test_data[cate].astype('category')
                cats = self.train_data[cate].cat.categories
                self.cat_categories_for_shap[cate] = list(cats)
        self.num_features = [
            nme for nme in cfg.factor_nmes if nme not in cate_list]
        train_oht = self.train_data[cfg.factor_nmes +
                                    [cfg.weight_nme] + [cfg.resp_nme]].copy()
        test_oht = self.test_data[cfg.factor_nmes +
                                  [cfg.weight_nme] + [cfg.resp_nme]].copy()
        train_oht = pd.get_dummies(
            train_oht,
            columns=cate_list,
            drop_first=True,
            dtype=np.int8
        )
        test_oht = pd.get_dummies(
            test_oht,
            columns=cate_list,
            drop_first=True,
            dtype=np.int8
        )

        # Reindex fills missing dummy columns with zeros so the test set's
        # columns match the training set
        test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)

        # Keep the unscaled one-hot data so cross-validation can standardize
        # within each fold and avoid leakage
        self.train_oht_data = train_oht.copy(deep=True)
        self.test_oht_data = test_oht.copy(deep=True)

        train_oht_scaled = train_oht.copy(deep=True)
        test_oht_scaled = test_oht.copy(deep=True)
        for num_chr in self.num_features:
            # Column-wise standardization keeps every feature on the same
            # scale; otherwise the neural networks struggle to converge
            scaler = StandardScaler()
            train_oht_scaled[num_chr] = scaler.fit_transform(
                train_oht_scaled[num_chr].values.reshape(-1, 1))
            test_oht_scaled[num_chr] = scaler.transform(
                test_oht_scaled[num_chr].values.reshape(-1, 1))
        # Reindex again so the scaled test set's columns match training
        test_oht_scaled = test_oht_scaled.reindex(
            columns=train_oht_scaled.columns, fill_value=0)
        self.train_oht_scl_data = train_oht_scaled
        self.test_oht_scl_data = test_oht_scaled
        self.var_nmes = list(
            set(list(train_oht_scaled.columns)) -
            set([cfg.weight_nme, cfg.resp_nme])
        )
        return self
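

# --- Illustrative sketch, not part of the original module ---
# Why run() reindexes the test dummies against the training columns: a
# category absent from the test split would otherwise drop its dummy column
# entirely. The column name and values are made up.
def _example_dummy_alignment():
    train = pd.get_dummies(pd.DataFrame({'region': ['north', 'south']}),
                           columns=['region'], drop_first=True, dtype=np.int8)
    test = pd.get_dummies(pd.DataFrame({'region': ['north']}),
                          columns=['region'], drop_first=True, dtype=np.int8)
    # train has column 'region_south'; test has none until reindexed
    test = test.reindex(columns=train.columns, fill_value=0)
    return list(test.columns)  # ['region_south'], filled with zeros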


# =============================================================================
# Trainer hierarchy
# =============================================================================


class TrainerBase:
    def __init__(self, context: "BayesOptModel", label: str, model_name_prefix: str) -> None:
        self.ctx = context
        self.label = label
        self.model_name_prefix = model_name_prefix
        self.model = None
        self.best_params: Optional[Dict[str, Any]] = None
        self.best_trial = None

    @property
    def config(self) -> BayesOptConfig:
        return self.ctx.config

    @property
    def output(self) -> OutputManager:
        return self.ctx.output_manager

    def _get_model_filename(self) -> str:
        ext = 'pkl' if self.label in ['Xgboost', 'GLM'] else 'pth'
        return f'01_{self.ctx.model_nme}_{self.model_name_prefix}.{ext}'

    def tune(self, max_evals: int, objective_fn=None) -> None:
        # Shared Optuna tuning loop.
        if objective_fn is None:
            # Default to cross_val as the optimization objective when the
            # subclass does not supply one explicitly
            objective_fn = self.cross_val

        total_trials = max(1, int(max_evals))
        progress_counter = {"count": 0}

        def objective_wrapper(trial: optuna.trial.Trial) -> float:
            should_log = DistributedUtils.is_main_process()
            if should_log:
                current_idx = progress_counter["count"] + 1
                print(
                    f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
                    f"(trial_id={trial.number})."
                )
            try:
                result = objective_fn(trial)
            except RuntimeError as exc:
                if "out of memory" in str(exc).lower():
                    print(
                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
                    )
                    self._clean_gpu()
                    raise optuna.TrialPruned() from exc
                raise
            finally:
                self._clean_gpu()
                if should_log:
                    progress_counter["count"] = progress_counter["count"] + 1
                    trial_state = getattr(trial, "state", None)
                    state_repr = getattr(trial_state, "name", "OK")
                    print(
                        f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
                        f"(status={state_repr})."
                    )
            return result

        study = optuna.create_study(
            direction='minimize',
            sampler=optuna.samplers.TPESampler(seed=self.ctx.rand_seed)
        )
        study.optimize(objective_wrapper, n_trials=max_evals)
        self.best_params = study.best_params
        self.best_trial = study.best_trial

        # Persist the best parameters to CSV for reproducibility
        params_path = self.output.result_path(
            f'{self.ctx.model_nme}_bestparams_{self.label.lower()}.csv'
        )
        pd.DataFrame(self.best_params, index=[0]).to_csv(params_path)

    def train(self) -> None:
        raise NotImplementedError

    def save(self) -> None:
        if self.model is None:
            print(f"[save] Warning: No model to save for {self.label}")
            return

        path = self.output.model_path(self._get_model_filename())
        if self.label in ['Xgboost', 'GLM']:
            joblib.dump(self.model, path)
        else:
            # Torch models may persist either just the state_dict or the whole
            # serialized object. Match historical behavior: ResNetTrainer saves
            # the state_dict, FTTrainer saves the full object
            if hasattr(self.model, 'resnet'):  # ResNetSklearn
                torch.save(self.model.resnet.state_dict(), path)
            else:  # FTTransformerSklearn or others
                torch.save(self.model, path)

    def load(self) -> None:
        path = self.output.model_path(self._get_model_filename())
        if not os.path.exists(path):
            print(f"[load] Warning: Model file not found: {path}")
            return

        if self.label in ['Xgboost', 'GLM']:
            self.model = joblib.load(path)
        else:
            # Torch models load differently depending on how they were saved
            if self.label == 'ResNet' or self.label == 'ResNetClassifier':
                # ResNet must rebuild its skeleton; its structural parameters
                # depend on ctx, so the subclass handles loading
                pass
            else:
                # The FT-Transformer serialized the whole object; load it
                # directly and then move it to the target device
                loaded = torch.load(path, map_location='cpu')
                self._move_to_device(loaded)
                self.model = loaded

    def _move_to_device(self, model_obj):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if hasattr(model_obj, 'device'):
            model_obj.device = device
        if hasattr(model_obj, 'to'):
            model_obj.to(device)
        # If the object embeds ft/resnet/gnn submodules, migrate them too
        if hasattr(model_obj, 'ft'):
            model_obj.ft.to(device)
        if hasattr(model_obj, 'resnet'):
            model_obj.resnet.to(device)
        if hasattr(model_obj, 'gnn'):
            model_obj.gnn.to(device)

    def _clean_gpu(self):
        gc.collect()
        if torch.cuda.is_available():
            device = None
            try:
                device = getattr(self, "device", None)
            except Exception:
                device = None
            if isinstance(device, torch.device):
                try:
                    torch.cuda.set_device(device)
                except Exception:
                    pass
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

    def _standardize_fold(self,
                          X_train: pd.DataFrame,
                          X_val: pd.DataFrame,
                          columns: Optional[List[str]] = None
                          ) -> Tuple[pd.DataFrame, pd.DataFrame, StandardScaler]:
        """Fit a StandardScaler on the training fold and transform both splits.

        Args:
            X_train: Training features.
            X_val: Validation features.
            columns: Columns to scale. Defaults to all columns.

        Returns:
            Scaled training/validation features and the fitted scaler.
        """
        scaler = StandardScaler()
        cols = list(columns) if columns else list(X_train.columns)
        X_train_scaled = X_train.copy(deep=True)
        X_val_scaled = X_val.copy(deep=True)
        if cols:
            scaler.fit(X_train_scaled[cols])
            X_train_scaled[cols] = scaler.transform(X_train_scaled[cols])
            X_val_scaled[cols] = scaler.transform(X_val_scaled[cols])
        return X_train_scaled, X_val_scaled, scaler

    def cross_val_generic(
            self,
            trial: optuna.trial.Trial,
            hyperparameter_space: Dict[str, Callable[[optuna.trial.Trial], Any]],
            data_provider: Callable[[], Tuple[pd.DataFrame, pd.Series, Optional[pd.Series]]],
            model_builder: Callable[[Dict[str, Any]], Any],
            metric_fn: Callable[[pd.Series, np.ndarray, Optional[pd.Series]], float],
            sample_limit: Optional[int] = None,
            preprocess_fn: Optional[Callable[[
                pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]] = None,
            fit_predict_fn: Optional[
                Callable[[Any, pd.DataFrame, pd.Series, Optional[pd.Series],
                          pd.DataFrame, pd.Series, Optional[pd.Series],
                          optuna.trial.Trial], np.ndarray]
            ] = None,
            cleanup_fn: Optional[Callable[[Any], None]] = None,
            splitter: Optional[Iterable[Tuple[np.ndarray, np.ndarray]]] = None) -> float:
        """Generic hold-out/K-fold cross-validation helper to reduce duplication.

        Args:
            trial: Active Optuna trial.
            hyperparameter_space: Dict of parameter samplers keyed by param name.
            data_provider: Callable returning X, y, sample_weight.
            model_builder: Callable constructing a fresh model per fold.
            metric_fn: Callable computing loss/score given y_true, y_pred, weight.
            sample_limit: Optional cap on rows; samples randomly if exceeded.
            preprocess_fn: Optional callable to transform (X_train, X_val) per fold.
            fit_predict_fn: Optional custom trainer returning validation predictions.
            cleanup_fn: Optional callable invoked with the trained model per fold.
            splitter: Optional iterable of (train_idx, val_idx); defaults to a single ShuffleSplit.

        Returns:
            Average validation metric across folds.
        """
        params = {name: sampler(trial)
                  for name, sampler in hyperparameter_space.items()}
        X_all, y_all, w_all = data_provider()
        if sample_limit is not None and len(X_all) > sample_limit:
            sampled_idx = X_all.sample(
                n=sample_limit,
                random_state=self.ctx.rand_seed
            ).index
            X_all = X_all.loc[sampled_idx]
            y_all = y_all.loc[sampled_idx]
            w_all = w_all.loc[sampled_idx] if w_all is not None else None

        splits = splitter or [next(ShuffleSplit(
            n_splits=int(1 / self.ctx.prop_test),
            test_size=self.ctx.prop_test,
            random_state=self.ctx.rand_seed
        ).split(X_all))]

        losses: List[float] = []
        for train_idx, val_idx in splits:
            X_train = X_all.iloc[train_idx]
            y_train = y_all.iloc[train_idx]
            X_val = X_all.iloc[val_idx]
            y_val = y_all.iloc[val_idx]
            w_train = w_all.iloc[train_idx] if w_all is not None else None
            w_val = w_all.iloc[val_idx] if w_all is not None else None

            if preprocess_fn:
                X_train, X_val = preprocess_fn(X_train, X_val)

            model = model_builder(params)
            try:
                if fit_predict_fn:
                    y_pred = fit_predict_fn(
                        model, X_train, y_train, w_train,
                        X_val, y_val, w_val, trial
                    )
                else:
                    fit_kwargs = {}
                    if w_train is not None:
                        fit_kwargs["sample_weight"] = w_train
                    model.fit(X_train, y_train, **fit_kwargs)
                    y_pred = model.predict(X_val)
                losses.append(metric_fn(y_val, y_pred, w_val))
            finally:
                if cleanup_fn:
                    cleanup_fn(model)
                self._clean_gpu()

        return float(np.mean(losses))

    # Prediction + caching logic
    def _predict_and_cache(self,
                           model,
                           pred_prefix: str,
                           use_oht: bool = False,
                           design_fn=None) -> None:
        if design_fn:
            X_train = design_fn(train=True)
            X_test = design_fn(train=False)
        elif use_oht:
            X_train = self.ctx.train_oht_scl_data[self.ctx.var_nmes]
            X_test = self.ctx.test_oht_scl_data[self.ctx.var_nmes]
        else:
            X_train = self.ctx.train_data[self.ctx.factor_nmes]
            X_test = self.ctx.test_data[self.ctx.factor_nmes]

        preds_train = model.predict(X_train)
        preds_test = model.predict(X_test)

        self.ctx.train_data[f'pred_{pred_prefix}'] = preds_train
        self.ctx.test_data[f'pred_{pred_prefix}'] = preds_test
        self.ctx.train_data[f'w_pred_{pred_prefix}'] = (
            self.ctx.train_data[f'pred_{pred_prefix}'] *
            self.ctx.train_data[self.ctx.weight_nme]
        )
        self.ctx.test_data[f'w_pred_{pred_prefix}'] = (
            self.ctx.test_data[f'pred_{pred_prefix}'] *
            self.ctx.test_data[self.ctx.weight_nme]
        )

    def _fit_predict_cache(self,
                           model,
                           X_train,
                           y_train,
                           sample_weight,
                           pred_prefix: str,
                           use_oht: bool = False,
                           design_fn=None,
                           fit_kwargs: Optional[Dict[str, Any]] = None,
                           sample_weight_arg: Optional[str] = 'sample_weight') -> None:
        fit_kwargs = fit_kwargs.copy() if fit_kwargs else {}
        if sample_weight is not None and sample_weight_arg:
            fit_kwargs.setdefault(sample_weight_arg, sample_weight)
        model.fit(X_train, y_train, **fit_kwargs)
        self.ctx.model_label.append(self.label)
        self._predict_and_cache(
            model, pred_prefix, use_oht=use_oht, design_fn=design_fn)
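

# --- Illustrative sketch, not part of the original module ---
# The default split that cross_val_generic falls back to: a single ShuffleSplit
# hold-out of size prop_test, scored with a weighted metric. Shown here with a
# plain sklearn estimator standing in for model_builder; the data, prop_test
# value, and estimator choice are all made up.
def _example_holdout_like_cross_val_generic():
    from sklearn.linear_model import PoissonRegressor
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(200, 3)), columns=['x1', 'x2', 'x3'])
    y = pd.Series(rng.poisson(2.0, 200).astype(float))
    w = pd.Series(np.ones(200))
    prop_test = 0.25
    train_idx, val_idx = next(ShuffleSplit(
        n_splits=int(1 / prop_test), test_size=prop_test, random_state=0).split(X))
    model = PoissonRegressor().fit(X.iloc[train_idx], y.iloc[train_idx],
                                   sample_weight=w.iloc[train_idx])
    y_pred = model.predict(X.iloc[val_idx])
    # Weighted Poisson deviance (power=1), clipped away from zero like metric_fn
    return mean_tweedie_deviance(y.iloc[val_idx], np.maximum(y_pred, EPS),
                                 sample_weight=w.iloc[val_idx], power=1.0)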


class XGBTrainer(TrainerBase):
    def __init__(self, context: "BayesOptModel") -> None:
        super().__init__(context, 'Xgboost', 'Xgboost')
        self.model: Optional[xgb.XGBRegressor] = None

    def _build_estimator(self) -> xgb.XGBRegressor:
        params = dict(
            objective=self.ctx.obj,
            random_state=self.ctx.rand_seed,
            subsample=0.9,
            tree_method='gpu_hist' if self.ctx.use_gpu else 'hist',
            enable_categorical=True,
            predictor='gpu_predictor' if self.ctx.use_gpu else 'cpu_predictor'
        )
        if self.ctx.use_gpu:
            params['gpu_id'] = 0
            print(">>> XGBoost using GPU ID: 0 (Single GPU Mode)")
        return xgb.XGBRegressor(**params)

    def cross_val(self, trial: optuna.trial.Trial) -> float:
        learning_rate = trial.suggest_float(
            'learning_rate', 1e-5, 1e-1, log=True)
        gamma = trial.suggest_float('gamma', 0, 10000)
        max_depth = trial.suggest_int('max_depth', 3, 25)
        n_estimators = trial.suggest_int('n_estimators', 10, 500, step=10)
        min_child_weight = trial.suggest_int(
            'min_child_weight', 100, 10000, step=100)
        reg_alpha = trial.suggest_float('reg_alpha', 1e-10, 1, log=True)
        reg_lambda = trial.suggest_float('reg_lambda', 1e-10, 1, log=True)
        if self.ctx.obj == 'reg:tweedie':
            tweedie_variance_power = trial.suggest_float(
                'tweedie_variance_power', 1, 2)
        elif self.ctx.obj == 'count:poisson':
            tweedie_variance_power = 1
        elif self.ctx.obj == 'reg:gamma':
            tweedie_variance_power = 2
        else:
            tweedie_variance_power = 1.5
        clf = self._build_estimator()
        params = {
            'learning_rate': learning_rate,
            'gamma': gamma,
            'max_depth': max_depth,
            'n_estimators': n_estimators,
            'min_child_weight': min_child_weight,
            'reg_alpha': reg_alpha,
            'reg_lambda': reg_lambda
        }
        if self.ctx.obj == 'reg:tweedie':
            params['tweedie_variance_power'] = tweedie_variance_power
        clf.set_params(**params)
        n_jobs = 1 if self.ctx.use_gpu else int(1 / self.ctx.prop_test)
        acc = cross_val_score(
            clf,
            self.ctx.train_data[self.ctx.factor_nmes],
            self.ctx.train_data[self.ctx.resp_nme].values,
            fit_params=self.ctx.fit_params,
            cv=self.ctx.cv,
            scoring=make_scorer(
                mean_tweedie_deviance,
                power=tweedie_variance_power,
                greater_is_better=False),
            error_score='raise',
            n_jobs=n_jobs
        ).mean()
        return -acc

    def train(self) -> None:
        if not self.best_params:
            raise RuntimeError('Run tune() first to obtain the best XGB parameters.')
        self.model = self._build_estimator()
        self.model.set_params(**self.best_params)
        self._fit_predict_cache(
            self.model,
            self.ctx.train_data[self.ctx.factor_nmes],
            self.ctx.train_data[self.ctx.resp_nme].values,
            sample_weight=None,
            pred_prefix='xgb',
            fit_kwargs=self.ctx.fit_params,
            sample_weight_arg=None  # sample weights are already passed via fit_kwargs
        )
        self.ctx.xgb_best = self.model
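

# --- Illustrative sketch, not part of the original module ---
# The scorer pattern XGBTrainer.cross_val builds: Tweedie deviance wrapped so
# that lower deviance is better (hence greater_is_better=False and the sign
# flip on the returned score). Data and parameter values below are made up.
def _example_tweedie_scorer():
    rng = np.random.default_rng(0)
    X = rng.normal(size=(300, 4))
    y = rng.gamma(2.0, 100.0, 300)
    scorer = make_scorer(mean_tweedie_deviance, power=1.5,
                         greater_is_better=False)
    est = xgb.XGBRegressor(objective='reg:tweedie', tweedie_variance_power=1.5,
                           n_estimators=20, max_depth=3)
    scores = cross_val_score(est, X, y, scoring=scorer, cv=3)
    return -scores.mean()  # mean Tweedie deviance across folds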


class GLMTrainer(TrainerBase):
    def __init__(self, context: "BayesOptModel") -> None:
        super().__init__(context, 'GLM', 'GLM')
        self.model = None

    def _select_family(self, tweedie_power: Optional[float] = None):
        if self.ctx.task_type == 'classification':
            return sm.families.Binomial()
        if self.ctx.obj == 'count:poisson':
            return sm.families.Poisson()
        if self.ctx.obj == 'reg:gamma':
            return sm.families.Gamma()
        power = tweedie_power if tweedie_power is not None else 1.5
        return sm.families.Tweedie(var_power=power, link=sm.families.links.log())

    def _prepare_design(self, data: pd.DataFrame) -> pd.DataFrame:
        # Add an intercept to the statsmodels design matrix
        X = data[self.ctx.var_nmes]
        return sm.add_constant(X, has_constant='add')

    def _metric_power(self, family, tweedie_power: Optional[float]) -> float:
        if isinstance(family, sm.families.Poisson):
            return 1.0
        if isinstance(family, sm.families.Gamma):
            return 2.0
        if isinstance(family, sm.families.Tweedie):
            return tweedie_power if tweedie_power is not None else getattr(family, 'var_power', 1.5)
        return 1.5

    def cross_val(self, trial: optuna.trial.Trial) -> float:
        param_space = {
            "alpha": lambda t: t.suggest_float('alpha', 1e-6, 1e2, log=True),
            "l1_ratio": lambda t: t.suggest_float('l1_ratio', 0.0, 1.0)
        }
        if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
            param_space["tweedie_power"] = lambda t: t.suggest_float(
                'tweedie_power', 1.0, 2.0)

        def data_provider():
            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
            assert data is not None, "Preprocessed training data is missing."
            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]

        def preprocess_fn(X_train, X_val):
            X_train_s, X_val_s, _ = self._standardize_fold(
                X_train, X_val, self.ctx.num_features)
            return self._prepare_design(X_train_s), self._prepare_design(X_val_s)

        metric_ctx: Dict[str, Any] = {}

        def model_builder(params):
            family = self._select_family(params.get("tweedie_power"))
            metric_ctx["family"] = family
            metric_ctx["tweedie_power"] = params.get("tweedie_power")
            return {
                "family": family,
                "alpha": params["alpha"],
                "l1_ratio": params["l1_ratio"],
                "tweedie_power": params.get("tweedie_power")
            }

        def fit_predict(model_cfg, X_train, y_train, w_train, X_val, y_val, w_val, _trial):
            glm = sm.GLM(y_train, X_train,
                         family=model_cfg["family"],
                         freq_weights=w_train)
            result = glm.fit_regularized(
                alpha=model_cfg["alpha"],
                L1_wt=model_cfg["l1_ratio"],
                maxiter=200
            )
            return result.predict(X_val)

        def metric_fn(y_true, y_pred, weight):
            if self.ctx.task_type == 'classification':
                y_pred_clipped = np.clip(y_pred, EPS, 1 - EPS)
                return log_loss(y_true, y_pred_clipped, sample_weight=weight)
            y_pred_safe = np.maximum(y_pred, EPS)
            return mean_tweedie_deviance(
                y_true,
                y_pred_safe,
                sample_weight=weight,
                power=self._metric_power(
                    metric_ctx.get("family"), metric_ctx.get("tweedie_power"))
            )

        return self.cross_val_generic(
            trial=trial,
            hyperparameter_space=param_space,
            data_provider=data_provider,
            model_builder=model_builder,
            metric_fn=metric_fn,
            preprocess_fn=preprocess_fn,
            fit_predict_fn=fit_predict,
            splitter=self.ctx.cv.split(self.ctx.train_oht_data[self.ctx.var_nmes]
                                       if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data[self.ctx.var_nmes])
        )

    def train(self) -> None:
        if not self.best_params:
            raise RuntimeError('Run tune() first to obtain the best GLM parameters.')
        tweedie_power = self.best_params.get('tweedie_power')
        family = self._select_family(tweedie_power)

        X_train = self._prepare_design(self.ctx.train_oht_scl_data)
        y_train = self.ctx.train_oht_scl_data[self.ctx.resp_nme]
        w_train = self.ctx.train_oht_scl_data[self.ctx.weight_nme]

        glm = sm.GLM(y_train, X_train, family=family,
                     freq_weights=w_train)
        self.model = glm.fit_regularized(
            alpha=self.best_params['alpha'],
            L1_wt=self.best_params['l1_ratio'],
            maxiter=300
        )

        self.ctx.glm_best = self.model
        self.ctx.model_label += [self.label]
        self._predict_and_cache(
            self.model,
            'glm',
            design_fn=lambda train: self._prepare_design(
                self.ctx.train_oht_scl_data if train else self.ctx.test_oht_scl_data
            )
        )
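

# --- Illustrative sketch, not part of the original module ---
# The statsmodels pattern GLMTrainer relies on: a log-link Tweedie family fit
# with elastic-net regularization, weighted by exposure. All data values are
# made up; note that fit_regularized returns point estimates without the usual
# inference tables.
def _example_regularized_tweedie_glm():
    rng = np.random.default_rng(0)
    X = sm.add_constant(pd.DataFrame(rng.normal(size=(200, 2)),
                                     columns=['x1', 'x2']), has_constant='add')
    y = pd.Series(rng.gamma(2.0, 100.0, 200))
    w = pd.Series(np.ones(200))
    family = sm.families.Tweedie(var_power=1.5, link=sm.families.links.log())
    result = sm.GLM(y, X, family=family, freq_weights=w).fit_regularized(
        alpha=1e-3, L1_wt=0.5, maxiter=200)
    return result.predict(X)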
|
|
2482
|
+
|
|
+class ResNetTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'ResNetClassifier', 'ResNet')
+        else:
+            super().__init__(context, 'ResNet', 'ResNet')
+        self.model: Optional[ResNetSklearn] = None
+
+    # ========= Cross-validation (used by BayesOpt) =========
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        # CV flow tailored to ResNet; the priority is keeping GPU memory in check:
+        # - build a fresh ResNetSklearn per fold and release its resources as soon as the fold ends;
+        # - after each fold, move the model to CPU, delete it, and call gc/empty_cache;
+        # - optional: sample only part of the training set during BayesOpt to reduce memory pressure.
+
+        base_tw_power = None
+        if self.ctx.task_type == 'regression':
+            if self.ctx.obj == 'count:poisson':
+                base_tw_power = 1.0
+            elif self.ctx.obj == 'reg:gamma':
+                base_tw_power = 2.0
+            else:
+                base_tw_power = 1.5
+
+        def data_provider():
+            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+            assert data is not None, "Preprocessed training data is missing."
+            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        metric_ctx: Dict[str, Any] = {}
+
+        def model_builder(params):
+            power = params.get("tw_power", base_tw_power)
+            metric_ctx["tw_power"] = power
+            return ResNetSklearn(
+                model_nme=self.ctx.model_nme,
+                input_dim=len(self.ctx.var_nmes),
+                hidden_dim=params["hidden_dim"],
+                block_num=params["block_num"],
+                task_type=self.ctx.task_type,
+                epochs=self.ctx.epochs,
+                tweedie_power=power,
+                learning_rate=params["learning_rate"],
+                patience=5,
+                use_layernorm=True,
+                dropout=0.1,
+                residual_scale=0.1,
+                use_data_parallel=self.ctx.config.use_resn_data_parallel,
+                use_ddp=self.ctx.config.use_resn_ddp
+            )
+
+        def preprocess_fn(X_train, X_val):
+            X_train_s, X_val_s, _ = self._standardize_fold(
+                X_train, X_val, self.ctx.num_features)
+            return X_train_s, X_val_s
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            model.fit(
+                X_train, y_train, w_train,
+                X_val, y_val, w_val,
+                trial=trial_obj
+            )
+            return model.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'regression':
+                return mean_tweedie_deviance(
+                    y_true,
+                    y_pred,
+                    sample_weight=weight,
+                    power=metric_ctx.get("tw_power", base_tw_power)
+                )
+            return log_loss(y_true, y_pred, sample_weight=weight)
+
+        sample_cap = data_provider()[0]
+        max_rows_for_resnet_bo = min(100000, int(len(sample_cap)/5))
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space={
+                "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-6, 1e-2, log=True),
+                "hidden_dim": lambda t: t.suggest_int('hidden_dim', 8, 32, step=2),
+                "block_num": lambda t: t.suggest_int('block_num', 2, 10),
+                **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie' else {})
+            },
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=max_rows_for_resnet_bo if len(
+                sample_cap) > max_rows_for_resnet_bo > 0 else None,
+            preprocess_fn=preprocess_fn,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "resnet", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
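The `cleanup_fn` above only moves the fitted network back to the CPU; the generic CV loop is expected to drop its reference afterwards. A minimal sketch of the fold-level release pattern the comments describe, assuming the wrapper exposes a PyTorch module as `.resnet`:

```python
import gc
import torch

def release_fold_model(model) -> None:
    # Move any underlying torch module to CPU so CUDA buffers can be freed,
    # then drop this reference (the caller should drop its own, too) and
    # flush the CUDA caching allocator.
    net = getattr(model, "resnet", None)
    if net is not None:
        net.to("cpu")
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```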
+    # ========= Train the final ResNet with the best hyperparameters =========
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError('Run tune() first to obtain the best ResNet parameters.')
+
+        self.model = ResNetSklearn(
+            model_nme=self.ctx.model_nme,
+            input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+            task_type=self.ctx.task_type,
+            use_data_parallel=self.ctx.config.use_resn_data_parallel,
+            use_ddp=self.ctx.config.use_resn_ddp
+        )
+        self.model.set_params(self.best_params)
+
+        self._fit_predict_cache(
+            self.model,
+            self.ctx.train_oht_scl_data[self.ctx.var_nmes],
+            self.ctx.train_oht_scl_data[self.ctx.resp_nme],
+            sample_weight=self.ctx.train_oht_scl_data[self.ctx.weight_nme],
+            pred_prefix='resn',
+            use_oht=True,
+            sample_weight_arg='w_train'
+        )
+
+        # Expose the fitted model for external callers
+        self.ctx.resn_best = self.model
+
+    # ========= Save / Load =========
+    # ResNet is saved as a state_dict and needs dedicated load logic, so load() is kept here;
+    # save() is already handled in TrainerBase (it checks for the .resnet attribute).
+
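The state_dict convention the comment refers to is the standard PyTorch round trip: persist only the weights, rebuild the module, then load and move to the target device. A minimal sketch with a stand-in module (file name and architecture hypothetical):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
torch.save(net.state_dict(), "resnet_demo.pt")            # weights only, not the class

reloaded = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
state = torch.load("resnet_demo.pt", map_location="cpu")  # land on CPU first
reloaded.load_state_dict(state)                           # then move to the device
reloaded.to("cuda" if torch.cuda.is_available() else "cpu")
```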
+    def load(self) -> None:
+        # Load the ResNet weights from disk onto the current device, consistent with the context.
+        path = self.output.model_path(self._get_model_filename())
+        if os.path.exists(path):
+            resn_loaded = ResNetSklearn(
+                model_nme=self.ctx.model_nme,
+                input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+                task_type=self.ctx.task_type,
+                use_data_parallel=self.ctx.config.use_resn_data_parallel,
+                use_ddp=self.ctx.config.use_resn_ddp
+            )
+            state_dict = torch.load(path, map_location='cpu')
+            resn_loaded.resnet.load_state_dict(state_dict)
+
+            self._move_to_device(resn_loaded)
+            self.model = resn_loaded
+            self.ctx.resn_best = self.model
+        else:
+            print(f"[ResNetTrainer.load] Model file not found: {path}")
+
+
+class FTTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
+        else:
+            super().__init__(context, 'FTTransformer', 'FTTransformer')
+        self.model: Optional[FTTransformerSklearn] = None
+
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        # Cross-validation for the FT-Transformer; again the focus is GPU-memory control:
+        # - shrink the hyperparameter search space to rule out needlessly huge models;
+        # - release GPU memory right after each fold so the next trial can proceed cleanly.
+        # The search space is kept deliberately tight to avoid very large models.
+        param_space: Dict[str, Callable[[optuna.trial.Trial], Any]] = {
+            "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
+            "d_model": lambda t: t.suggest_int('d_model', 32, 256, step=32),
+            "n_heads": lambda t: t.suggest_categorical('n_heads', [2, 4, 8]),
+            "n_layers": lambda t: t.suggest_int('n_layers', 2, 8),
+            "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.2)
+        }
+        if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie':
+            param_space["tw_power"] = lambda t: t.suggest_float(
+                'tw_power', 1.0, 2.0)
+
+        metric_ctx: Dict[str, Any] = {}
+
+        def data_provider():
+            data = self.ctx.train_data
+            return data[self.ctx.factor_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        def model_builder(params):
+            d_model = params["d_model"]
+            n_layers = params["n_layers"]
+            approx_units = d_model * n_layers * \
+                max(1, len(self.ctx.factor_nmes))
+            if approx_units > 1_200_000:
+                print(
+                    f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
+                raise optuna.TrialPruned(
+                    "config exceeds safe memory budget; prune before training")
+
+            tw_power = params.get("tw_power")
+            if self.ctx.task_type == 'regression':
+                if self.ctx.obj == 'count:poisson':
+                    tw_power = 1.0
+                elif self.ctx.obj == 'reg:gamma':
+                    tw_power = 2.0
+                elif tw_power is None:
+                    tw_power = 1.5
+            metric_ctx["tw_power"] = tw_power
+
+            return FTTransformerSklearn(
+                model_nme=self.ctx.model_nme,
+                num_cols=self.ctx.num_features,
+                cat_cols=self.ctx.cate_list,
+                d_model=d_model,
+                n_heads=params["n_heads"],
+                n_layers=n_layers,
+                dropout=params["dropout"],
+                task_type=self.ctx.task_type,
+                epochs=self.ctx.epochs,
+                tweedie_power=tw_power,
+                learning_rate=params["learning_rate"],
+                patience=5,
+                use_data_parallel=self.ctx.config.use_ft_data_parallel,
+                use_ddp=self.ctx.config.use_ft_ddp
+            )
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            model.fit(
+                X_train, y_train, w_train,
+                X_val, y_val, w_val,
+                trial=trial_obj
+            )
+            return model.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'regression':
+                return mean_tweedie_deviance(
+                    y_true,
+                    y_pred,
+                    sample_weight=weight,
+                    power=metric_ctx.get("tw_power", 1.5)
+                )
+            return log_loss(y_true, y_pred, sample_weight=weight)
+
+        data_for_cap = data_provider()[0]
+        max_rows_for_ft_bo = min(1000000, int(len(data_for_cap)/2))
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space=param_space,
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=max_rows_for_ft_bo if len(
+                data_for_cap) > max_rows_for_ft_bo > 0 else None,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "ft", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
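`model_builder` rejects oversized configurations before any training happens, which Optuna records as a pruned (not failed) trial. A standalone sketch of the same guard — the `1_200_000` budget mirrors the code above, while the objective and the feature-count stand-in are hypothetical:

```python
import optuna

BUDGET = 1_200_000  # same rough size budget as FTTrainer uses

def objective(trial: optuna.trial.Trial) -> float:
    d_model = trial.suggest_int("d_model", 32, 256, step=32)
    n_layers = trial.suggest_int("n_layers", 2, 8)
    if d_model * n_layers * 40 > BUDGET:   # 40 stands in for the feature count
        raise optuna.TrialPruned("config exceeds safe memory budget")
    return float(d_model) / n_layers       # placeholder metric

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
# Pruned trials are kept in the study history but never consumed GPU time.
print(len(study.get_trials(states=[optuna.trial.TrialState.PRUNED])))
```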
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError('Run tune() first to obtain the best FT-Transformer parameters.')
+        self.model = FTTransformerSklearn(
+            model_nme=self.ctx.model_nme,
+            num_cols=self.ctx.num_features,
+            cat_cols=self.ctx.cate_list,
+            task_type=self.ctx.task_type,
+            use_data_parallel=self.ctx.config.use_ft_data_parallel,
+            use_ddp=self.ctx.config.use_ft_ddp
+        )
+        self.model.set_params(self.best_params)
+        self._fit_predict_cache(
+            self.model,
+            self.ctx.train_data[self.ctx.factor_nmes],
+            self.ctx.train_data[self.ctx.resp_nme],
+            sample_weight=self.ctx.train_data[self.ctx.weight_nme],
+            pred_prefix='ft',
+            sample_weight_arg='w_train'
+        )
+        self.ctx.ft_best = self.model
+
+
+class GNNTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'GNNClassifier', 'GNN')
+        else:
+            super().__init__(context, 'GNN', 'GNN')
+        self.model: Optional[GraphNeuralNetSklearn] = None
+
+    def cross_val(self, trial: optuna.trial.Trial) -> float:
+        base_tw_power = None
+        if self.ctx.task_type == 'regression':
+            if self.ctx.obj == 'count:poisson':
+                base_tw_power = 1.0
+            elif self.ctx.obj == 'reg:gamma':
+                base_tw_power = 2.0
+            else:
+                base_tw_power = 1.5
+
+        metric_ctx: Dict[str, Any] = {}
+
+        def data_provider():
+            data = self.ctx.train_oht_data if self.ctx.train_oht_data is not None else self.ctx.train_oht_scl_data
+            assert data is not None, "Preprocessed training data is missing."
+            return data[self.ctx.var_nmes], data[self.ctx.resp_nme], data[self.ctx.weight_nme]
+
+        def preprocess_fn(X_train, X_val):
+            X_train_s, X_val_s, _ = self._standardize_fold(
+                X_train, X_val, self.ctx.num_features)
+            return X_train_s, X_val_s
+
+        def model_builder(params):
+            power = params.get("tw_power", base_tw_power)
+            metric_ctx["tw_power"] = power
+            return GraphNeuralNetSklearn(
+                model_nme=self.ctx.model_nme,
+                input_dim=len(self.ctx.var_nmes),
+                hidden_dim=params["hidden_dim"],
+                num_layers=params["num_layers"],
+                k_neighbors=params["k_neighbors"],
+                dropout=params["dropout"],
+                learning_rate=params["learning_rate"],
+                epochs=self.ctx.epochs,
+                patience=5,
+                task_type=self.ctx.task_type,
+                tweedie_power=power if power is not None else 1.5,
+                use_ddp=False,  # single GPU during BO by default, avoiding multiprocessing clashes with Optuna
+                use_approx_knn=self.ctx.config.gnn_use_approx_knn,
+                approx_knn_threshold=self.ctx.config.gnn_approx_knn_threshold
+            )
+
+        def fit_predict(model, X_train, y_train, w_train, X_val, y_val, w_val, trial_obj):
+            model.fit(
+                X_train, y_train, w_train,
+                X_val, y_val, w_val,
+                trial=trial_obj
+            )
+            return model.predict(X_val)
+
+        def metric_fn(y_true, y_pred, weight):
+            if self.ctx.task_type == 'regression':
+                return mean_tweedie_deviance(
+                    y_true,
+                    y_pred,
+                    sample_weight=weight,
+                    power=metric_ctx.get("tw_power", 1.5)
+                )
+            return log_loss(y_true, y_pred, sample_weight=weight)
+
+        base_X = data_provider()[0]
+        max_rows_for_gnn_bo = min(50_000, max(1, int(len(base_X) / 5)))
+
+        return self.cross_val_generic(
+            trial=trial,
+            hyperparameter_space={
+                "learning_rate": lambda t: t.suggest_float('learning_rate', 1e-5, 5e-3, log=True),
+                "hidden_dim": lambda t: t.suggest_int('hidden_dim', 16, 128, step=16),
+                "num_layers": lambda t: t.suggest_int('num_layers', 1, 4),
+                "k_neighbors": lambda t: t.suggest_int('k_neighbors', 5, 20),
+                "dropout": lambda t: t.suggest_float('dropout', 0.0, 0.3),
+                **({"tw_power": lambda t: t.suggest_float('tw_power', 1.0, 2.0)} if self.ctx.task_type == 'regression' and self.ctx.obj == 'reg:tweedie' else {})
+            },
+            data_provider=data_provider,
+            model_builder=model_builder,
+            metric_fn=metric_fn,
+            sample_limit=max_rows_for_gnn_bo if len(
+                base_X) > max_rows_for_gnn_bo > 0 else None,
+            preprocess_fn=preprocess_fn,
+            fit_predict_fn=fit_predict,
+            cleanup_fn=lambda m: getattr(
+                getattr(m, "gnn", None), "to", lambda *_args, **_kwargs: None)("cpu")
+        )
+
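The `k_neighbors`, `use_approx_knn` and `approx_knn_threshold` options suggest the GNN builds its graph from a k-nearest-neighbour search over the feature rows; the actual construction lives in the models module, so the following is only a sketch of the exact (non-approximate) variant using scikit-learn, not the package's implementation:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_edge_index(X: np.ndarray, k: int = 10) -> np.ndarray:
    # Connect each row to its k nearest neighbours (excluding itself);
    # returns a (2, N*k) edge index in the layout common GNN libraries expect.
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X)
    _, idx = nn.kneighbors(X)                  # idx[:, 0] is the point itself
    src = np.repeat(np.arange(X.shape[0]), k)
    dst = idx[:, 1:].reshape(-1)
    return np.stack([src, dst])

edges = knn_edge_index(np.random.rand(100, 8), k=5)
print(edges.shape)  # (2, 500)
```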
+    def train(self) -> None:
+        if not self.best_params:
+            raise RuntimeError('Run tune() first to obtain the best GNN parameters.')
+
+        self.model = GraphNeuralNetSklearn(
+            model_nme=self.ctx.model_nme,
+            input_dim=self.ctx.train_oht_scl_data[self.ctx.var_nmes].shape[1],
+            task_type=self.ctx.task_type,
+            use_data_parallel=self.ctx.config.use_gnn_data_parallel,
+            use_ddp=self.ctx.config.use_gnn_ddp,
+            use_approx_knn=self.ctx.config.gnn_use_approx_knn,
+            approx_knn_threshold=self.ctx.config.gnn_approx_knn_threshold,
+            graph_cache_path=self.ctx.config.gnn_graph_cache
+        )
+        self.model.set_params(self.best_params)
+
+        self._fit_predict_cache(
+            self.model,
+            self.ctx.train_oht_scl_data[self.ctx.var_nmes],
+            self.ctx.train_oht_scl_data[self.ctx.resp_nme],
+            sample_weight=self.ctx.train_oht_scl_data[self.ctx.weight_nme],
+            pred_prefix='gnn',
+            use_oht=True,
+            sample_weight_arg='w_train'
+        )
+        self.ctx.gnn_best = self.model
+
+
+# =============================================================================
+# BayesOpt orchestration & SHAP utilities
+# =============================================================================
+class BayesOptModel:
+    def __init__(self, train_data, test_data,
+                 model_nme, resp_nme, weight_nme, factor_nmes, task_type='regression',
+                 binary_resp_nme=None,
+                 cate_list=None, prop_test=0.25, rand_seed=None,
+                 epochs=100, use_gpu=True,
+                 use_resn_data_parallel: bool = False, use_ft_data_parallel: bool = False,
+                 use_gnn_data_parallel: bool = False,
+                 use_resn_ddp: bool = False, use_ft_ddp: bool = False,
+                 use_gnn_ddp: bool = False,
+                 output_dir: Optional[str] = None,
+                 gnn_use_approx_knn: bool = True,
+                 gnn_approx_knn_threshold: int = 50000,
+                 gnn_graph_cache: Optional[str] = None):
+        """BayesOpt orchestration layer for all trainers.
+
+        Args:
+            train_data: Training dataframe.
+            test_data: Testing dataframe.
+            model_nme: Model name prefix for artifacts.
+            resp_nme: Target column name.
+            weight_nme: Sample-weight column name.
+            factor_nmes: Feature column names.
+            task_type: 'regression' or 'classification'.
+            binary_resp_nme: Optional binary response for conversion plots.
+            cate_list: Categorical feature names.
+            prop_test: Validation proportion for CV splits.
+            rand_seed: Optional random seed.
+            epochs: Epochs for neural models.
+            use_gpu: Whether to prefer GPU for supported models.
+            use_resn_data_parallel: Enable DataParallel for ResNet.
+            use_ft_data_parallel: Enable DataParallel for FTTransformer.
+            use_gnn_data_parallel: Enable DataParallel for GNN.
+            use_resn_ddp: Enable DDP for ResNet.
+            use_ft_ddp: Enable DDP for FTTransformer.
+            use_gnn_ddp: Enable DDP for GNN.
+            output_dir: Root directory for models/results/plots.
+            gnn_use_approx_knn: Use approximate k-NN when available.
+            gnn_approx_knn_threshold: Row threshold to trigger approximate k-NN.
+            gnn_graph_cache: Optional path to persist/load cached adjacency.
+        """
+        cfg = BayesOptConfig(
+            model_nme=model_nme,
+            task_type=task_type,
+            resp_nme=resp_nme,
+            weight_nme=weight_nme,
+            factor_nmes=list(factor_nmes),
+            binary_resp_nme=binary_resp_nme,
+            cate_list=list(cate_list) if cate_list else None,
+            prop_test=prop_test,
+            rand_seed=rand_seed,
+            epochs=epochs,
+            use_gpu=use_gpu,
+            use_resn_data_parallel=use_resn_data_parallel,
+            use_ft_data_parallel=use_ft_data_parallel,
+            use_resn_ddp=use_resn_ddp,
+            use_gnn_data_parallel=use_gnn_data_parallel,
+            use_ft_ddp=use_ft_ddp,
+            use_gnn_ddp=use_gnn_ddp,
+            gnn_use_approx_knn=gnn_use_approx_knn,
+            gnn_approx_knn_threshold=gnn_approx_knn_threshold,
+            gnn_graph_cache=gnn_graph_cache,
+            output_dir=output_dir
+        )
+        self.config = cfg
+        self.model_nme = cfg.model_nme
+        self.task_type = cfg.task_type
+        self.resp_nme = cfg.resp_nme
+        self.weight_nme = cfg.weight_nme
+        self.factor_nmes = cfg.factor_nmes
+        self.binary_resp_nme = cfg.binary_resp_nme
+        self.cate_list = list(cfg.cate_list or [])
+        self.prop_test = cfg.prop_test
+        self.epochs = cfg.epochs
+        self.rand_seed = cfg.rand_seed if cfg.rand_seed is not None else np.random.randint(
+            1, 10000)
+        self.use_gpu = bool(cfg.use_gpu and torch.cuda.is_available())
+        self.output_manager = OutputManager(
+            cfg.output_dir or os.getcwd(), self.model_nme)
+
+        preprocessor = DatasetPreprocessor(train_data, test_data, cfg).run()
+        self.train_data = preprocessor.train_data
+        self.test_data = preprocessor.test_data
+        self.train_oht_data = preprocessor.train_oht_data
+        self.test_oht_data = preprocessor.test_oht_data
+        self.train_oht_scl_data = preprocessor.train_oht_scl_data
+        self.test_oht_scl_data = preprocessor.test_oht_scl_data
+        self.var_nmes = preprocessor.var_nmes
+        self.num_features = preprocessor.num_features
+        self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
+
+        self.cv = ShuffleSplit(n_splits=int(1/self.prop_test),
+                               test_size=self.prop_test,
+                               random_state=self.rand_seed)
+        if self.task_type == 'classification':
+            self.obj = 'binary:logistic'
+        else:  # regression
+            if 'f' in self.model_nme:
+                self.obj = 'count:poisson'
+            elif 's' in self.model_nme:
+                self.obj = 'reg:gamma'
+            elif 'bc' in self.model_nme:
+                self.obj = 'reg:tweedie'
+            else:
+                self.obj = 'reg:tweedie'
+        self.fit_params = {
+            'sample_weight': self.train_data[self.weight_nme].values
+        }
+        self.model_label: List[str] = []
+        self.optuna_storage = cfg.optuna_storage
+        self.optuna_study_prefix = cfg.optuna_study_prefix or "bayesopt"
+
+        # Register every model trainer; they are accessed by key, which keeps new models easy to add.
+        self.trainers: Dict[str, TrainerBase] = {
+            'glm': GLMTrainer(self),
+            'xgb': XGBTrainer(self),
+            'resn': ResNetTrainer(self),
+            'gnn': GNNTrainer(self),
+            'ft': FTTrainer(self)
+        }
+        self.xgb_best = None
+        self.resn_best = None
+        self.glm_best = None
+        self.ft_best = None
+        self.gnn_best = None
+        self.best_xgb_params = None
+        self.best_resn_params = None
+        self.best_ft_params = None
+        self.best_gnn_params = None
+        self.best_xgb_trial = None
+        self.best_resn_trial = None
+        self.best_ft_trial = None
+        self.best_gnn_trial = None
+        self.best_glm_params = None
+        self.best_glm_trial = None
+        self.xgb_load = None
+        self.resn_load = None
+        self.ft_load = None
+        self.gnn_load = None
+
|
+
# 定义单因素画图函数
|
|
3020
|
+
def plot_oneway(self, n_bins=10):
|
|
3021
|
+
for c in self.factor_nmes:
|
|
3022
|
+
fig = plt.figure(figsize=(7, 5))
|
|
3023
|
+
if c in self.cate_list:
|
|
3024
|
+
group_col = c
|
|
3025
|
+
plot_source = self.train_data
|
|
3026
|
+
else:
|
|
3027
|
+
group_col = f'{c}_bins'
|
|
3028
|
+
bins = pd.qcut(
|
|
3029
|
+
self.train_data[c],
|
|
3030
|
+
n_bins,
|
|
3031
|
+
duplicates='drop' # 注意:如果分位数重复会丢 bin,避免异常终止
|
|
3032
|
+
)
|
|
3033
|
+
plot_source = self.train_data.assign(**{group_col: bins})
|
|
3034
|
+
plot_data = plot_source.groupby(
|
|
3035
|
+
[group_col], observed=True).sum(numeric_only=True)
|
|
3036
|
+
plot_data.reset_index(inplace=True)
|
|
3037
|
+
plot_data['act_v'] = plot_data['w_act'] / \
|
|
3038
|
+
plot_data[self.weight_nme]
|
|
3039
|
+
plot_data.head()
|
|
3040
|
+
ax = fig.add_subplot(111)
|
|
3041
|
+
ax.plot(plot_data.index, plot_data['act_v'],
|
|
3042
|
+
label='Actual', color='red')
|
|
3043
|
+
ax.set_title(
|
|
3044
|
+
'Analysis of %s : Train Data' % group_col,
|
|
3045
|
+
fontsize=8)
|
|
3046
|
+
plt.xticks(plot_data.index,
|
|
3047
|
+
list(plot_data[group_col].astype(str)),
|
|
3048
|
+
rotation=90)
|
|
3049
|
+
if len(list(plot_data[group_col].astype(str))) > 50:
|
|
3050
|
+
plt.xticks(fontsize=3)
|
|
3051
|
+
else:
|
|
3052
|
+
plt.xticks(fontsize=6)
|
|
3053
|
+
plt.yticks(fontsize=6)
|
|
3054
|
+
ax2 = ax.twinx()
|
|
3055
|
+
ax2.bar(plot_data.index,
|
|
3056
|
+
plot_data[self.weight_nme],
|
|
3057
|
+
alpha=0.5, color='seagreen')
|
|
3058
|
+
plt.yticks(fontsize=6)
|
|
3059
|
+
plt.margins(0.05)
|
|
3060
|
+
plt.subplots_adjust(wspace=0.3)
|
|
3061
|
+
save_path = self.output_manager.plot_path(
|
|
3062
|
+
f'00_{self.model_nme}_{group_col}_oneway.png')
|
|
3063
|
+
plt.savefig(save_path, dpi=300)
|
|
3064
|
+
plt.close(fig)
|
|
3065
|
+
|
|
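The `duplicates='drop'` guard matters for heavily skewed factors, where several quantile edges coincide. A tiny standalone demonstration:

```python
import pandas as pd

s = pd.Series([0, 0, 0, 0, 0, 1, 2, 3, 4, 5])   # mass at zero -> repeated edges
bins = pd.qcut(s, 5, duplicates='drop')         # would raise ValueError without duplicates='drop'
print(bins.cat.categories)                      # fewer than 5 bins survive
```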
+    # Generic tune-then-train entry point
+    def optimize_model(self, model_key: str, max_evals: int = 100):
+        if model_key not in self.trainers:
+            print(f"Warning: Unknown model key: {model_key}")
+            return
+
+        trainer = self.trainers[model_key]
+        trainer.tune(max_evals)
+        trainer.train()
+
+        # Update context attributes for backward compatibility
+        setattr(self, f"{model_key}_best", trainer.model)
+        setattr(self, f"best_{model_key}_params", trainer.best_params)
+        setattr(self, f"best_{model_key}_trial", trainer.best_trial)
+
+    # Bayesian optimisation for the GLM
+    def bayesopt_glm(self, max_evals=50):
+        self.optimize_model('glm', max_evals)
+
+    # Bayesian optimisation for XGBoost
+    def bayesopt_xgb(self, max_evals=100):
+        self.optimize_model('xgb', max_evals)
+
+    # Bayesian optimisation for ResNet
+    def bayesopt_resnet(self, max_evals=100):
+        self.optimize_model('resn', max_evals)
+
+    # Bayesian optimisation for the GNN
+    def bayesopt_gnn(self, max_evals=50):
+        self.optimize_model('gnn', max_evals)
+
+    # Bayesian optimisation for the FT-Transformer
+    def bayesopt_ft(self, max_evals=50):
+        self.optimize_model('ft', max_evals)
+
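Each wrapper is just `optimize_model` with a fixed model key, so tuning and training a subset of models is a few calls. Continuing the hypothetical `bo` instance from the construction example above:

```python
bo.bayesopt_glm(max_evals=30)    # tune + fit the elastic-net GLM
bo.bayesopt_xgb(max_evals=100)   # tune + fit XGBoost
bo.optimize_model('resn', 50)    # equivalent long-form call for ResNet

print(bo.best_xgb_params)        # populated by optimize_model for compatibility
bo.save_model('xgb')             # persist a single trainer's model
```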
+    # Plot lift curves
+    def plot_lift(self, model_label, pred_nme, n_bins=10):
+        model_map = {
+            'Xgboost': 'pred_xgb',
+            'ResNet': 'pred_resn',
+            'ResNetClassifier': 'pred_resn',
+            'FTTransformer': 'pred_ft',
+            'FTTransformerClassifier': 'pred_ft',
+            'GLM': 'pred_glm',
+            'GNN': 'pred_gnn',
+            'GNNClassifier': 'pred_gnn'
+        }
+        for k, v in model_map.items():
+            if model_label.startswith(k):
+                pred_nme = v
+                break
+
+        fig = plt.figure(figsize=(11, 5))
+        for pos, (title, data) in zip([121, 122],
+                                      [('Lift Chart on Train Data', self.train_data),
+                                       ('Lift Chart on Test Data', self.test_data)]):
+            lift_df = pd.DataFrame({
+                'pred': data[pred_nme].values,
+                'w_pred': data[f'w_{pred_nme}'].values,
+                'act': data['w_act'].values,
+                'weight': data[self.weight_nme].values
+            })
+            plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
+            denom = np.maximum(plot_data['weight'], EPS)
+            plot_data['exp_v'] = plot_data['w_pred'] / denom
+            plot_data['act_v'] = plot_data['act'] / denom
+            plot_data = plot_data.reset_index()
+
+            ax = fig.add_subplot(pos)
+            PlotUtils.plot_lift_ax(ax, plot_data, title)
+
+        plt.subplots_adjust(wspace=0.3)
+        save_path = self.output_manager.plot_path(
+            f'01_{self.model_nme}_{model_label}_lift.png')
+        plt.savefig(save_path, dpi=300)
+        plt.show()
+        plt.close(fig)
+
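`PlotUtils.split_data` lives in the plotting utilities; the lift chart itself only needs weighted, prediction-sorted bins. A self-contained sketch of that binning idea — my own construction, not the package's exact implementation:

```python
import numpy as np
import pandas as pd

def weighted_lift_table(pred, actual, weight, n_bins=10) -> pd.DataFrame:
    df = pd.DataFrame({'pred': np.asarray(pred),
                       'act': np.asarray(actual),
                       'weight': np.asarray(weight)}).sort_values('pred')
    # Bin on cumulative weight so each bucket carries similar exposure.
    cum_w = df['weight'].cumsum()
    df['bin'] = np.minimum((cum_w / cum_w.iloc[-1] * n_bins).astype(int),
                           n_bins - 1)
    df['w_pred'] = df['pred'] * df['weight']
    df['w_act'] = df['act'] * df['weight']
    out = df.groupby('bin')[['w_pred', 'w_act', 'weight']].sum()
    out['exp_v'] = out['w_pred'] / out['weight']   # weighted avg prediction
    out['act_v'] = out['w_act'] / out['weight']    # weighted avg actual
    return out.reset_index()

print(weighted_lift_table(np.random.rand(100), np.random.rand(100), np.ones(100)).head())
```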
+    # Plot double-lift curves
+    def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
+        # Double-lift chart comparing two models bin by bin.
+        # Args:
+        #     model_comp: short names of the two models to compare
+        #                 (e.g. ['xgb', 'resn']; any key in model_name_map below is supported).
+        #     n_bins: number of bins, controlling the granularity of the lift curves.
+        if len(model_comp) != 2:
+            raise ValueError("`model_comp` must contain exactly two models to compare.")
+
+        model_name_map = {
+            'xgb': 'Xgboost',
+            'resn': 'ResNet',
+            'ft': 'FTTransformer',
+            'glm': 'GLM',
+            'gnn': 'GNN'
+        }
+
+        name1, name2 = model_comp
+        if name1 not in model_name_map or name2 not in model_name_map:
+            raise ValueError(f"Unsupported model short name. Choose from {list(model_name_map.keys())}.")
+
+        fig, axes = plt.subplots(1, 2, figsize=(11, 5))
+        datasets = {
+            'Train Data': self.train_data,
+            'Test Data': self.test_data
+        }
+
+        for ax, (data_name, data) in zip(axes, datasets.items()):
+            pred1_col = f'w_pred_{name1}'
+            pred2_col = f'w_pred_{name2}'
+
+            if pred1_col not in data.columns or pred2_col not in data.columns:
+                print(
+                    f"Warning: prediction column {pred1_col} or {pred2_col} not found in {data_name}; skipping plot.")
+                continue
+
+            lift_data = pd.DataFrame({
+                'pred1': data[pred1_col].values,
+                'pred2': data[pred2_col].values,
+                'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
+                'act': data['w_act'].values,
+                'weight': data[self.weight_nme].values
+            })
+            plot_data = PlotUtils.split_data(
+                lift_data, 'diff_ly', 'weight', n_bins)
+            denom = np.maximum(plot_data['act'], EPS)
+            plot_data['exp_v1'] = plot_data['pred1'] / denom
+            plot_data['exp_v2'] = plot_data['pred2'] / denom
+            plot_data['act_v'] = plot_data['act'] / denom
+            plot_data.reset_index(inplace=True)
+
+            label1 = model_name_map[name1]
+            label2 = model_name_map[name2]
+
+            PlotUtils.plot_dlift_ax(
+                ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)
+
+        plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
+        save_path = self.output_manager.plot_path(
+            f'02_{self.model_nme}_dlift_{name1}_vs_{name2}.png')
+        plt.savefig(save_path, dpi=300)
+        plt.show()
+        plt.close(fig)
+
+    # Plot conversion-rate lift curves
+    def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
+        if not self.binary_resp_nme:
+            print("Error: `binary_resp_nme` was not provided when BayesOptModel was initialised; cannot plot conversion-rate curves.")
+            return
+
+        fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
+        datasets = {
+            'Train Data': self.train_data,
+            'Test Data': self.test_data
+        }
+
+        for ax, (data_name, data) in zip(axes, datasets.items()):
+            if model_pred_col not in data.columns:
+                print(f"Warning: prediction column '{model_pred_col}' not found in {data_name}; skipping plot.")
+                continue
+
+            # Sort by model score and bin on cumulative weight
+            plot_data = data.sort_values(by=model_pred_col).copy()
+            plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
+            total_weight = plot_data[self.weight_nme].sum()
+
+            if total_weight > EPS:
+                plot_data['bin'] = pd.cut(
+                    plot_data['cum_weight'],
+                    bins=n_bins,
+                    labels=False,
+                    right=False
+                )
+            else:
+                plot_data['bin'] = 0
+
+            # Aggregate per bin
+            lift_agg = plot_data.groupby('bin').agg(
+                total_weight=(self.weight_nme, 'sum'),
+                actual_conversions=(self.binary_resp_nme, 'sum'),
+                weighted_conversions=('w_binary_act', 'sum'),
+                avg_pred=(model_pred_col, 'mean')
+            ).reset_index()
+
+            # Conversion rate per bin
+            lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
+                lift_agg['total_weight']
+
+            # Overall average conversion rate
+            overall_conversion_rate = data['w_binary_act'].sum(
+            ) / data[self.weight_nme].sum()
+            ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
+                       label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')
+
+            ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
+                    marker='o', linestyle='-', label='Actual Conversion Rate')
+            ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
+            ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
+            ax.set_ylabel('Conversion Rate')
+            ax.grid(True, linestyle='--', alpha=0.6)
+            ax.legend()
+
+        plt.tight_layout()
+        plt.show()
+        plt.close(fig)  # release the figure, matching the other plotting helpers
+
+    # Save models
+    def save_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].save()
+            else:
+                if model_name:  # Only warn if a specific model was requested
+                    print(f"[save_model] Warning: Unknown model key {key}")
+
+    def load_model(self, model_name=None):
+        keys = [model_name] if model_name else self.trainers.keys()
+        for key in keys:
+            if key in self.trainers:
+                self.trainers[key].load()
+                # Update context attributes
+                trainer = self.trainers[key]
+                if trainer.model is not None:
+                    setattr(self, f"{key}_best", trainer.model)
+                    # Also update xxx_load for backward compatibility if needed;
+                    # the original code had xgb_load, resn_load, ft_load but not glm_load.
+                    if key in ['xgb', 'resn', 'ft']:
+                        setattr(self, f"{key}_load", trainer.model)
+            else:
+                if model_name:
+                    print(f"[load_model] Warning: Unknown model key {key}")
+
+    def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
+        if len(data) == 0:
+            return data
+        return data.sample(min(len(data), n), random_state=self.rand_seed)
+
+    @staticmethod
+    def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
+        min_needed = arr.shape[1] + 2
+        return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
+
+    def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
+        matrices = []
+        for col in self.factor_nmes:
+            s = data[col]
+            if col in self.cate_list:
+                cats = pd.Categorical(
+                    s,
+                    categories=self.cat_categories_for_shap[col]
+                )
+                codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
+                matrices.append(codes)
+            else:
+                vals = pd.to_numeric(s, errors="coerce")
+                arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
+                matrices.append(arr)
+        X_mat = np.concatenate(matrices, axis=1)  # (N, F)
+        return X_mat
+
+    def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
+        data_dict = {}
+        for j, col in enumerate(self.factor_nmes):
+            col_vals = X_mat[:, j]
+            if col in self.cate_list:
+                cats = self.cat_categories_for_shap[col]
+                codes = np.round(col_vals).astype(int)
+                codes = np.clip(codes, -1, len(cats) - 1)
+                cat_series = pd.Categorical.from_codes(
+                    codes,
+                    categories=cats
+                )
+                data_dict[col] = cat_series
+            else:
+                data_dict[col] = col_vals.astype(float)
+
+        df = pd.DataFrame(data_dict, columns=self.factor_nmes)
+        for col in self.cate_list:
+            if col in df.columns:
+                df[col] = df[col].astype("category")
+        return df
+
+    def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
+        X = data[self.var_nmes]
+        return sm.add_constant(X, has_constant='add')
+
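KernelExplainer needs a purely numeric matrix, so the helpers above send categoricals through as integer codes and decode them back before prediction. A small standalone round trip of that encode/decode idea (not the package's helpers; vocabulary and values are made up):

```python
import numpy as np
import pandas as pd

cats = ['north', 'south', 'west']                    # fixed category vocabulary
s = pd.Series(['south', 'north', 'west', 'south'])

codes = pd.Categorical(s, categories=cats).codes     # -> [1, 0, 2, 1]
as_float = np.asarray(codes, dtype=np.float64)       # numeric for the explainer

# SHAP may perturb values off-grid; round and clip before decoding (-1 maps to NaN).
decoded = pd.Categorical.from_codes(
    np.clip(np.round(as_float).astype(int), -1, len(cats) - 1),
    categories=cats)
print(list(decoded))                                 # ['south', 'north', 'west', 'south']
```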
+    def _compute_shap_core(self,
+                           model_key: str,
+                           n_background: int,
+                           n_samples: int,
+                           on_train: bool,
+                           X_df: pd.DataFrame,
+                           prep_fn,
+                           predict_fn,
+                           cleanup_fn=None):
+        if model_key not in self.trainers or self.trainers[model_key].model is None:
+            raise RuntimeError(f"Model {model_key} not trained.")
+        if cleanup_fn:
+            cleanup_fn()
+        bg_df = self._sample_rows(X_df, n_background)
+        bg_mat = prep_fn(bg_df)
+        explainer = shap.KernelExplainer(predict_fn, bg_mat)
+        ex_df = self._sample_rows(X_df, n_samples)
+        ex_mat = prep_fn(ex_df)
+        nsample_eff = self._shap_nsamples(ex_mat)
+        shap_values = explainer.shap_values(ex_mat, nsamples=nsample_eff)
+        bg_pred = predict_fn(bg_mat)
+        base_value = float(np.asarray(bg_pred).mean())
+
+        return {
+            "explainer": explainer,
+            "X_explain": ex_df,
+            "shap_values": shap_values,
+            "base_value": base_value
+        }
+
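KernelExplainer fits a weighted linear model over sampled feature coalitions, which is why `_shap_nsamples` floors the evaluation count at `n_features + 2` — roughly one coalition per coefficient plus the two trivial all-off/all-on coalitions (that reading of the floor is my gloss, not documented in the package). A minimal end-to-end use of the same API on a toy model:

```python
import numpy as np
import shap

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 4))
explain_pts = rng.normal(size=(5, 4))

predict = lambda X: X @ np.array([0.5, -1.0, 0.0, 2.0])  # toy linear model

explainer = shap.KernelExplainer(predict, background)
nsamples = max(4 + 2, min(300, explain_pts.shape[0] * explain_pts.shape[1]))
values = explainer.shap_values(explain_pts, nsamples=nsamples)
print(np.asarray(values).shape)  # one attribution per feature per explained row
```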
+    # ========= GLM SHAP =========
+    def compute_shap_glm(self, n_background: int = 500,
+                         n_samples: int = 200,
+                         on_train: bool = True):
+        data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+        design_all = self._build_glm_design(data)
+        design_cols = list(design_all.columns)
+
+        def predict_wrapper(x_np):
+            x_df = pd.DataFrame(x_np, columns=design_cols)
+            y_pred = self.glm_best.predict(x_df)
+            return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+        self.shap_glm = self._compute_shap_core(
+            'glm', n_background, n_samples, on_train,
+            X_df=design_all,
+            prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+            predict_fn=predict_wrapper
+        )
+        return self.shap_glm
+
+    # ========= XGBoost SHAP =========
+    def compute_shap_xgb(self, n_background: int = 500,
+                         n_samples: int = 200,
+                         on_train: bool = True):
+        data = self.train_data if on_train else self.test_data
+        X_raw = data[self.factor_nmes]
+
+        def predict_wrapper(x_mat):
+            df_input = self._decode_ft_shap_matrix_to_df(x_mat)
+            return self.xgb_best.predict(df_input)
+
+        self.shap_xgb = self._compute_shap_core(
+            'xgb', n_background, n_samples, on_train,
+            X_df=X_raw,
+            prep_fn=lambda df: self._build_ft_shap_matrix(
+                df).astype(np.float64),
+            predict_fn=predict_wrapper
+        )
+        return self.shap_xgb
+
+    # ========= ResNet SHAP =========
+    def _resn_predict_wrapper(self, X_np):
+        model = self.resn_best.resnet.to("cpu")
+        with torch.no_grad():
+            X_tensor = torch.tensor(X_np, dtype=torch.float32)
+            y_pred = model(X_tensor).cpu().numpy()
+            y_pred = np.clip(y_pred, 1e-6, None)
+        return y_pred.reshape(-1)
+
+    def compute_shap_resn(self, n_background: int = 500,
+                          n_samples: int = 200,
+                          on_train: bool = True):
+        data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+        X = data[self.var_nmes]
+
+        def cleanup():
+            self.resn_best.device = torch.device("cpu")
+            self.resn_best.resnet.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        self.shap_resn = self._compute_shap_core(
+            'resn', n_background, n_samples, on_train,
+            X_df=X,
+            prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+            predict_fn=lambda x: self._resn_predict_wrapper(x),
+            cleanup_fn=cleanup
+        )
+        return self.shap_resn
+
+    # ========= GNN SHAP =========
+    def _gnn_predict_wrapper(self, X_np: np.ndarray) -> np.ndarray:
+        X_df = pd.DataFrame(X_np, columns=self.var_nmes)
+        y_pred = self.gnn_best.predict(X_df)
+        return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+    def compute_shap_gnn(self, n_background: int = 300,
+                         n_samples: int = 150,
+                         on_train: bool = True):
+        data = self.train_oht_scl_data if on_train else self.test_oht_scl_data
+        if data is None:
+            raise RuntimeError("One-hot scaled data is not ready; cannot compute GNN SHAP.")
+        X = data[self.var_nmes]
+
+        def cleanup():
+            self.gnn_best.device = torch.device("cpu")
+            self.gnn_best.gnn.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        self.shap_gnn = self._compute_shap_core(
+            'gnn', n_background, n_samples, on_train,
+            X_df=X,
+            prep_fn=lambda df: df.to_numpy(dtype=np.float64),
+            predict_fn=lambda x: self._gnn_predict_wrapper(x),
+            cleanup_fn=cleanup
+        )
+        return self.shap_gnn
+
+    # ========= FT-Transformer SHAP =========
+    def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
+        df_input = self._decode_ft_shap_matrix_to_df(X_mat)
+        y_pred = self.ft_best.predict(df_input)
+        return np.asarray(y_pred, dtype=np.float64).reshape(-1)
+
+    def compute_shap_ft(self, n_background: int = 500,
+                        n_samples: int = 200,
+                        on_train: bool = True):
+        data = self.train_data if on_train else self.test_data
+        X_raw = data[self.factor_nmes]
+
+        def cleanup():
+            self.ft_best.device = torch.device("cpu")
+            self.ft_best.ft.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        self.shap_ft = self._compute_shap_core(
+            'ft', n_background, n_samples, on_train,
+            X_df=X_raw,
+            prep_fn=lambda df: self._build_ft_shap_matrix(
+                df).astype(np.float64),
+            predict_fn=self._ft_shap_predict_wrapper,
+            cleanup_fn=cleanup
+        )
+        return self.shap_ft
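Each `compute_shap_*` method returns the explainer, the explained rows, the SHAP values and a base value, which plug straight into shap's plotting helpers. A typical follow-up, continuing the hypothetical `bo` instance used in the earlier examples:

```python
import shap

res = bo.compute_shap_glm(n_background=200, n_samples=100)
shap.summary_plot(res["shap_values"],
                  res["X_explain"],
                  show=False)   # set show=True in an interactive session
```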