ins-pricing 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,829 @@
+ import numpy as np  # 1.26.2
+ import pandas as pd  # 2.2.3
+ import torch  # 1.10.1+cu111
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import optuna  # 4.3.0
+ import xgboost as xgb  # 1.7.0
+ import matplotlib.pyplot as plt
+ import os
+ import joblib
+ import torch.utils.checkpoint as cp
+
+ from torch.utils.data import DataLoader, TensorDataset
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.utils import clip_grad_norm_
+ from sklearn.model_selection import KFold, ShuffleSplit, cross_val_score  # 1.2.2
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import make_scorer, mean_tweedie_deviance
+
+ # Tweedie deviance loss for torch tensors
+ # Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
+
+
+ def tweedie_loss(pred, target, p=1.5):
+     # Ensure predictions are positive for stability
+     eps = 1e-6
+     pred_clamped = torch.clamp(pred, min=eps)
+     # Compute Tweedie deviance components
+     if p == 1:
+         # Poisson case: 2 * (y * log(y / mu) - y + mu)
+         term1 = target * torch.log(target / pred_clamped + eps)
+         term2 = target - pred_clamped
+         term3 = 0
+     elif p == 0:
+         # Gaussian case: (y - mu)^2
+         term1 = 0.5 * torch.pow(target - pred_clamped, 2)
+         term2 = 0
+         term3 = 0
+     elif p == 2:
+         # Gamma case: 2 * (log(mu / y) + y / mu - 1)
+         term1 = torch.log(pred_clamped / target + eps)
+         term2 = -target / pred_clamped + 1
+         term3 = 0
+     else:
+         term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
+         term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
+         term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
+     # Tweedie deviance (negative log-likelihood up to a constant)
+     return 2 * (term1 - term2 + term3)
+
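+ # Illustrative check (editor's sketch, not part of the original file):
+ # for p=1 the loss reduces to the Poisson deviance, e.g.
+ #     >>> pred = torch.tensor([1.0, 2.0])
+ #     >>> target = torch.tensor([0.0, 2.0])
+ #     >>> tweedie_loss(pred, target, p=1)  # ~ tensor([2.0, 0.0])
+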
+ # Weighted binning helper
+
+
+ def split_data(data, col_nme, wgt_nme, n_bins=10):
+     data.sort_values(by=col_nme, ascending=True, inplace=True)
+     data['cum_weight'] = data[wgt_nme].cumsum()
+     w_sum = data[wgt_nme].sum()
+     data.loc[:, 'bins'] = np.floor(data['cum_weight'] * float(n_bins) / w_sum)
+     data.loc[(data['bins'] == n_bins), 'bins'] = n_bins - 1
+     return data.groupby(['bins'], observed=True).sum(numeric_only=True)
+
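+ # Illustrative usage (editor's sketch, not part of the original file):
+ #     >>> df = pd.DataFrame({'pred': [0.1, 0.4, 0.2, 0.9], 'w': [1.0, 1.0, 1.0, 1.0]})
+ #     >>> split_data(df, 'pred', 'w', n_bins=2)  # groups rows into weight-based bins
+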
+ # Lift chart plotting helper
+
+
+ def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                    weight_list, tgt_nme, n_bins=10,
+                    fig_nme='Lift Chart'):
+     lift_data = pd.DataFrame()
+     lift_data.loc[:, 'pred'] = pred_model
+     lift_data.loc[:, 'w_pred'] = w_pred_list
+     lift_data.loc[:, 'act'] = w_act_list
+     lift_data.loc[:, 'weight'] = weight_list
+     plot_data = split_data(lift_data, 'pred', 'weight', n_bins)
+     plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+     plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+     plot_data.reset_index(inplace=True)
+     fig = plt.figure(figsize=(7, 5))
+     ax = fig.add_subplot(111)
+     ax.plot(plot_data.index, plot_data['act_v'],
+             label='Actual', color='red')
+     ax.plot(plot_data.index, plot_data['exp_v'],
+             label='Predicted', color='blue')
+     ax.set_title('Lift Chart of %s' % tgt_nme, fontsize=8)
+     plt.xticks(plot_data.index, plot_data.index,
+                rotation=90, fontsize=6)
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper left', fontsize=5, frameon=False)
+     plt.margins(0.05)
+     ax2 = ax.twinx()
+     ax2.bar(plot_data.index, plot_data['weight'],
+             alpha=0.5, color='seagreen',
+             label='Earned Exposure')
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper right', fontsize=5, frameon=False)
+     plt.subplots_adjust(wspace=0.3)
+     save_path = os.path.join(
+         os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+     plt.savefig(save_path, dpi=300)
+     plt.close(fig)
+
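+ # Illustrative usage (editor's sketch, not part of the original file;
+ # pred, act, w are hypothetical aligned 1-D arrays):
+ #     >>> os.makedirs('plot', exist_ok=True)  # plots are written to ./plot
+ #     >>> plot_lift_list(pred, pred * w, act * w, w, tgt_nme='loss_cost')
+ # The saved file is plot/05_loss_cost_Lift Chart.png.
+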
+ # Double lift chart plotting helper
+
+
+ def plot_dlift_list(pred_model_1, pred_model_2,
+                     model_nme_1, model_nme_2,
+                     tgt_nme,
+                     w_list, w_act_list, n_bins=10,
+                     fig_nme='Double Lift Chart'):
+     lift_data = pd.DataFrame()
+     lift_data.loc[:, 'pred1'] = pred_model_1
+     lift_data.loc[:, 'pred2'] = pred_model_2
+     lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+     lift_data.loc[:, 'act'] = w_act_list
+     lift_data.loc[:, 'weight'] = w_list
+     lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
+     lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
+     plot_data = split_data(lift_data, 'diff_ly', 'weight', n_bins)
+     plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
+     plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
+     plot_data['act_v'] = plot_data['act'] / plot_data['act']
+     plot_data.reset_index(inplace=True)
+     fig = plt.figure(figsize=(7, 5))
+     ax = fig.add_subplot(111)
+     ax.plot(plot_data.index, plot_data['act_v'],
+             label='Actual', color='red')
+     ax.plot(plot_data.index, plot_data['exp_v1'],
+             label=model_nme_1, color='blue')
+     ax.plot(plot_data.index, plot_data['exp_v2'],
+             label=model_nme_2, color='black')
+     ax.set_title('Double Lift Chart of %s' % tgt_nme, fontsize=8)
+     plt.xticks(plot_data.index, plot_data.index,
+                rotation=90, fontsize=6)
+     plt.xlabel('%s / %s' % (model_nme_1, model_nme_2), fontsize=6)
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper left', fontsize=5, frameon=False)
+     plt.margins(0.1)
+     plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+     ax2 = ax.twinx()
+     ax2.bar(plot_data.index, plot_data['weight'],
+             alpha=0.5, color='seagreen',
+             label='Earned Exposure')
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper right', fontsize=5, frameon=False)
+     plt.subplots_adjust(wspace=0.3)
+     save_path = os.path.join(
+         os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+     plt.savefig(save_path, dpi=300)
+     plt.close(fig)
+
+ # Residual block: Linear + BatchNorm with a skip connection, then ReLU
+ # ResBlock subclasses nn.Module
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, dim):
+         super(ResBlock, self).__init__()
+         self.block = nn.Sequential(
+             nn.Linear(dim, dim),
+             nn.BatchNorm1d(dim)
+         )
+
+     def forward(self, x):
+         # Transform the input, add the skip connection, then apply ReLU;
+         # gradient checkpointing trades recomputation for memory
+         out = cp.checkpoint(self.block, x)
+         # return F.relu(self.block(x) + x)
+         return F.relu(out + x)
+
+ # ResNetSequential subclasses nn.Module and defines the whole network
+
+
+ class ResNetSequential(nn.Module):
+     # The network is an nn.Sequential chain: input -> ResBlock * block_num -> output
+     def __init__(self, input_dim, hidden_dim=64, block_num=2):
+         super(ResNetSequential, self).__init__()
+         self.net = nn.Sequential()
+         self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))
+         self.net.add_module('bn1', nn.BatchNorm1d(hidden_dim))
+         self.net.add_module('ReLU1', nn.ReLU())
+         for i in range(block_num):
+             self.net.add_module('ResBlk_' + str(i + 1), ResBlock(hidden_dim))
+         self.net.add_module('fc2', nn.Linear(hidden_dim, 1))
+         self.net.add_module('softplus', nn.Softplus())
+
+     def forward(self, x):
+         return self.net(x)
+
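+ # Illustrative usage (editor's sketch, not part of the original file):
+ #     >>> net = ResNetSequential(input_dim=12, hidden_dim=64, block_num=2)
+ #     >>> net(torch.randn(32, 12)).shape  # torch.Size([32, 1]), positive via Softplus
+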
+ # Scikit-learn style wrapper around the ResNet (fit / predict / set_params)
+
+
+ class ResNetScikitLearn(nn.Module):
+     def __init__(self, model_nme, input_dim, hidden_dim=64,
+                  block_num=2, batch_num=100, epochs=100,
+                  tweedie_power=1.5, learning_rate=0.01,
+                  patience=10, accumulation_steps=2):
+         super(ResNetScikitLearn, self).__init__()
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.block_num = block_num
+         if torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         elif torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         else:
+             self.device = torch.device('cpu')
+         self.resnet = ResNetSequential(
+             self.input_dim,
+             self.hidden_dim,
+             self.block_num
+         ).to(self.device)
+         if torch.cuda.device_count() > 1:
+             self.resnet = nn.DataParallel(
+                 self.resnet,
+                 device_ids=list(range(torch.cuda.device_count()))
+             )
+         self.batch_num = batch_num
+         self.epochs = epochs
+         self.model_nme = model_nme
+         # Infer the Tweedie power from the model name:
+         # 'f' -> frequency (Poisson), 's' -> severity (Gamma), otherwise as given
+         if self.model_nme.find('f') != -1:
+             self.tw_power = 1
+         elif self.model_nme.find('s') != -1:
+             self.tw_power = 2
+         else:
+             self.tw_power = tweedie_power
+         self.learning_rate = learning_rate
+         self.patience = patience  # Early stopping patience
+         self.accumulation_steps = accumulation_steps  # Gradient accumulation steps
+
+     def fit(self, X_train, y_train, w_train=None, X_val=None, y_val=None, w_val=None):
+         # Convert the training data to PyTorch tensors
+         X_tensor = torch.tensor(
+             X_train.values, dtype=torch.float32).to(self.device)
+         y_tensor = torch.tensor(
+             y_train.values, dtype=torch.float32).view(-1, 1).to(self.device)
+         w_tensor = torch.tensor(
+             w_train.values, dtype=torch.float32).view(-1, 1).to(self.device) \
+             if w_train is not None else torch.ones_like(y_tensor)
+
+         # Validation tensors
+         if X_val is not None:
+             X_val_tensor = torch.tensor(
+                 X_val.values, dtype=torch.float32).to(self.device)
+             y_val_tensor = torch.tensor(
+                 y_val.values, dtype=torch.float32).view(-1, 1).to(self.device)
+             w_val_tensor = torch.tensor(
+                 w_val.values, dtype=torch.float32).view(-1, 1).to(self.device) \
+                 if w_val is not None else torch.ones_like(y_val_tensor)
+
+         # Build the dataset and data loader
+         dataset = TensorDataset(X_tensor, y_tensor, w_tensor)
+         dataloader = DataLoader(
+             dataset,
+             batch_size=max(1, int((self.learning_rate / 1e-4) ** 0.5 *
+                                   (X_train.shape[0] / self.batch_num))),
+             shuffle=True
+             # num_workers=4
+             # pin_memory=(self.device.type == 'cuda')
+         )
+         # Optimizer and AMP gradient scaler
+         optimizer = torch.optim.Adam(
+             self.resnet.parameters(), lr=self.learning_rate)
+         scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         # Early stopping state
+         best_loss, patience_counter = float('inf'), 0
+         best_model_state = None  # Initialize best_model_state
+
+         # Training loop
+         for epoch in range(1, self.epochs + 1):
+             self.resnet.train()
+             for X_batch, y_batch, w_batch in dataloader:
+                 optimizer.zero_grad()
+                 # Mixed precision is enabled only when running on CUDA
+                 with autocast(enabled=(self.device.type == 'cuda')):
+                     X_batch, y_batch, w_batch = X_batch.to(self.device), y_batch.to(
+                         self.device), w_batch.to(self.device)
+                     y_pred = self.resnet(X_batch)
+                     y_pred = torch.clamp(y_pred, min=1e-6)
+                     losses = tweedie_loss(
+                         y_pred, y_batch, p=self.tw_power).view(-1)
+                     weighted_loss = (losses * w_batch.view(-1)
+                                      ).sum() / w_batch.sum()
+                 scaler.scale(weighted_loss).backward()
+                 # Gradient clipping: under AMP the gradients must be
+                 # unscaled before they are clipped
+                 if self.device.type == 'cuda':
+                     scaler.unscale_(optimizer)
+                     clip_grad_norm_(
+                         self.resnet.parameters(),
+                         max_norm=1.0
+                     )
+                 scaler.step(optimizer)
+                 scaler.update()
+
+             # Validation loss
+             if X_val is not None and y_val is not None:
+                 self.resnet.eval()
+                 with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
+                     y_val_pred = self.resnet(X_val_tensor)
+                     val_loss_values = tweedie_loss(
+                         y_val_pred, y_val_tensor, p=self.tw_power).view(-1)
+                     val_weighted_loss = (
+                         val_loss_values * w_val_tensor.view(-1)).sum() / w_val_tensor.sum()
+
+                 # Early stopping check
+                 if val_weighted_loss < best_loss:
+                     best_loss, patience_counter = val_weighted_loss, 0
+                     # Keep the current best weights
+                     best_model_state = self.resnet.state_dict()
+                 else:
+                     patience_counter += 1
+                     if patience_counter >= self.patience:
+                         # Restore the best weights and stop
+                         self.resnet.load_state_dict(best_model_state)
+                         break
+
+     def predict(self, X_test):
+         self.resnet.eval()
+         with torch.no_grad():
+             X_tensor = torch.tensor(
+                 X_test.values, dtype=torch.float32).to(self.device)
+             y_pred = self.resnet(X_tensor).cpu().numpy()
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.flatten()
+
+     def set_params(self, params):
+         # Set model hyperparameters from a dict
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         # Rebuild the network so architecture parameters
+         # (hidden_dim, block_num) actually take effect
+         self.resnet = ResNetSequential(
+             self.input_dim, self.hidden_dim, self.block_num).to(self.device)
+
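+ # Illustrative usage (editor's sketch, not part of the original file;
+ # 'freq_f' is a hypothetical frequency-model name, so tw_power becomes 1):
+ #     >>> Xf = pd.DataFrame(np.random.rand(256, 4).astype('float32'))
+ #     >>> yf = pd.Series(np.random.poisson(1.0, 256).astype('float32'))
+ #     >>> net = ResNetScikitLearn('freq_f', input_dim=4, epochs=5)
+ #     >>> net.fit(Xf, yf)
+ #     >>> net.predict(Xf)[:3]  # strictly positive frequencies
+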
+ # Bayesian-optimization model class wrapping the XGBoost and ResNet models
+
+
+ class BayesOptModel:
+     def __init__(self, train_data, test_data,
+                  model_nme, resp_nme, weight_nme, factor_nmes,
+                  cate_list=[], prop_test=0.25, rand_seed=None, epochs=100):
+         # train_data / test_data: training and test sets (DataFrames)
+         # model_nme: model name
+         # resp_nme: response column, weight_nme: weight column
+         # factor_nmes: list of predictor names
+         # cate_list: list of categorical predictors
+         # prop_test: test-set proportion, rand_seed: random seed
+         self.train_data = train_data
+         self.test_data = test_data
+         self.resp_nme = resp_nme
+         self.weight_nme = weight_nme
+         self.train_data.loc[:, 'w_act'] = self.train_data[self.resp_nme] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_act'] = self.test_data[self.resp_nme] * \
+             self.test_data[self.weight_nme]
+         self.factor_nmes = factor_nmes
+         self.cate_list = cate_list
+         self.rand_seed = rand_seed if rand_seed is not None else np.random.randint(
+             1, 10000)
+         if self.cate_list != []:
+             for cate in self.cate_list:
+                 self.train_data[cate] = self.train_data[cate].astype(
+                     'category')
+                 self.test_data[cate] = self.test_data[cate].astype('category')
+         self.prop_test = prop_test
+         self.cv = ShuffleSplit(n_splits=int(1 / self.prop_test),
+                                test_size=self.prop_test,
+                                random_state=self.rand_seed)
+         self.model_nme = model_nme
+         # Objective by model name: 'f' -> frequency (Poisson),
+         # 's' -> severity (Gamma), anything else (e.g. 'bc') -> Tweedie
+         if self.model_nme.find('f') != -1:
+             self.obj = 'count:poisson'
+         elif self.model_nme.find('s') != -1:
+             self.obj = 'reg:gamma'
+         else:
+             self.obj = 'reg:tweedie'
+         self.fit_params = {
+             'sample_weight': self.train_data[self.weight_nme].values
+         }
+         self.num_features = [
+             nme for nme in self.factor_nmes if nme not in self.cate_list]
+         self.train_oht_scl_data = self.train_data[self.factor_nmes +
+                                                   [self.weight_nme] + [self.resp_nme]].copy()
+         self.test_oht_scl_data = self.test_data[self.factor_nmes +
+                                                 [self.weight_nme] + [self.resp_nme]].copy()
+         self.train_oht_scl_data = pd.get_dummies(
+             self.train_oht_scl_data,
+             columns=self.cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         self.test_oht_scl_data = pd.get_dummies(
+             self.test_oht_scl_data,
+             columns=self.cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         for num_chr in self.num_features:
+             scaler = StandardScaler()
+             self.train_oht_scl_data[num_chr] = scaler.fit_transform(
+                 self.train_oht_scl_data[num_chr].values.reshape(-1, 1))
+             self.test_oht_scl_data[num_chr] = scaler.transform(
+                 self.test_oht_scl_data[num_chr].values.reshape(-1, 1))
+         # Align the test columns with the training columns
+         self.test_oht_scl_data = self.test_oht_scl_data.reindex(
+             columns=self.train_oht_scl_data.columns,
+             fill_value=0
+         )
+         self.var_nmes = list(
+             set(list(self.train_oht_scl_data.columns)) -
+             set([self.weight_nme, self.resp_nme])
+         )
+         self.epochs = epochs
+         self.model_label = []
+
+     # One-way plotting helper
+     def plot_oneway(self, n_bins=10):
+         for c in self.factor_nmes:
+             fig = plt.figure(figsize=(7, 5))
+             if c in self.cate_list:
+                 strs = c
+             else:
+                 strs = c + '_bins'
+                 self.train_data.loc[:, strs] = pd.qcut(
+                     self.train_data[c],
+                     n_bins,
+                     duplicates='drop'
+                 )
+             plot_data = self.train_data.groupby(
+                 [strs], observed=True).sum(numeric_only=True)
+             plot_data.reset_index(inplace=True)
+             plot_data['act_v'] = plot_data['w_act'] / \
+                 plot_data[self.weight_nme]
+             ax = fig.add_subplot(111)
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.set_title('Analysis of %s : Train Data' % strs, fontsize=8)
+             plt.xticks(plot_data.index,
+                        list(plot_data[strs].astype(str)),
+                        rotation=90)
+             if len(list(plot_data[strs].astype(str))) > 50:
+                 plt.xticks(fontsize=3)
+             else:
+                 plt.xticks(fontsize=6)
+             plt.yticks(fontsize=6)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index,
+                     plot_data[self.weight_nme],
+                     alpha=0.5, color='seagreen')
+             plt.yticks(fontsize=6)
+             plt.margins(0.05)
+             plt.subplots_adjust(wspace=0.3)
+             save_path = os.path.join(
+                 os.getcwd(), 'plot',
+                 f'00_{self.model_nme}_{strs}_oneway.png')
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+
+     # XGBoost cross-validation objective for Optuna
+     def cross_val_xgb(self, trial):
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-5, 1e-1, log=True)
+         gamma = trial.suggest_float('gamma', 0, 10000)
+         max_depth = trial.suggest_int('max_depth', 3, 25)
+         n_estimators = trial.suggest_int('n_estimators', 10, 500, step=10)
+         min_child_weight = trial.suggest_int(
+             'min_child_weight', 100, 10000, step=100)
+         reg_alpha = trial.suggest_float('reg_alpha', 1e-10, 1, log=True)
+         reg_lambda = trial.suggest_float('reg_lambda', 1e-10, 1, log=True)
+         if self.obj == 'reg:tweedie':
+             tweedie_variance_power = trial.suggest_float(
+                 'tweedie_variance_power', 1, 2)
+         elif self.obj == 'count:poisson':
+             tweedie_variance_power = 1
+         elif self.obj == 'reg:gamma':
+             tweedie_variance_power = 2
+         # Only request GPU training when a GPU is actually available
+         use_gpu = torch.cuda.is_available()
+         clf = xgb.XGBRegressor(
+             objective=self.obj,
+             random_state=self.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if use_gpu else 'hist',
+             enable_categorical=True,
+             **({'gpu_id': 0, 'predictor': 'gpu_predictor'} if use_gpu else {})
+         )
+         params = {
+             'learning_rate': learning_rate,
+             'gamma': gamma,
+             'max_depth': max_depth,
+             'n_estimators': n_estimators,
+             'min_child_weight': min_child_weight,
+             'reg_alpha': reg_alpha,
+             'reg_lambda': reg_lambda
+         }
+         if self.obj == 'reg:tweedie':
+             params['tweedie_variance_power'] = tweedie_variance_power
+         clf.set_params(**params)
+         acc = cross_val_score(
+             clf,
+             self.train_data[self.factor_nmes],
+             self.train_data[self.resp_nme].values,
+             fit_params=self.fit_params,
+             cv=self.cv,
+             scoring=make_scorer(
+                 mean_tweedie_deviance,
+                 power=tweedie_variance_power,
+                 greater_is_better=False),
+             error_score='raise',
+             n_jobs=int(1 / self.prop_test)).mean()
+         return -acc
+
+     # XGBoost Bayesian-optimization driver
+     def bayesopt_xgb(self, max_evals=100):
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.rand_seed))
+         study.optimize(self.cross_val_xgb, n_trials=max_evals)
+         self.best_xgb_params = study.best_params
+         os.makedirs(os.path.join(os.getcwd(), 'Results'), exist_ok=True)
+         pd.DataFrame(self.best_xgb_params, index=[0]).to_csv(
+             os.path.join(os.getcwd(), 'Results',
+                          self.model_nme + '_bestparams_xgb.csv'))
+         self.best_xgb_trial = study.best_trial
+         use_gpu = torch.cuda.is_available()
+         self.xgb_best = xgb.XGBRegressor(
+             objective=self.obj,
+             random_state=self.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if use_gpu else 'hist',
+             enable_categorical=True,
+             **({'gpu_id': 0, 'predictor': 'gpu_predictor'} if use_gpu else {})
+         )
+         self.xgb_best.set_params(**self.best_xgb_params)
+         self.xgb_best.fit(self.train_data[self.factor_nmes],
+                           self.train_data[self.resp_nme].values,
+                           **self.fit_params)
+         self.model_label += ['Xgboost']
+         self.train_data['pred_xgb'] = self.xgb_best.predict(
+             self.train_data[self.factor_nmes])
+         self.test_data['pred_xgb'] = self.xgb_best.predict(
+             self.test_data[self.factor_nmes])
+         self.train_data.loc[:, 'w_pred_xgb'] = self.train_data['pred_xgb'] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_pred_xgb'] = self.test_data['pred_xgb'] * \
+             self.test_data[self.weight_nme]
+
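+ # Illustrative usage (editor's sketch, not part of the original file;
+ # train_df/test_df and the column names are hypothetical):
+ #     >>> bom = BayesOptModel(train_df, test_df, model_nme='freq_f',
+ #     ...                     resp_nme='freq', weight_nme='exposure',
+ #     ...                     factor_nmes=['age', 'region'], cate_list=['region'])
+ #     >>> bom.bayesopt_xgb(max_evals=20)  # writes Results/freq_f_bestparams_xgb.csv
+ #     >>> bom.plot_lift('Xgboost')        # writes plot/01_freq_f_Xgboost_lift.png
+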
+     # ResNet cross-validation objective for Optuna
+     def cross_val_resn(self, trial):
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-6, 1e-2, log=True)
+         hidden_dim = trial.suggest_int('hidden_dim', 32, 256, step=16)
+         block_num = trial.suggest_int('block_num', 3, 10)
+         batch_num = trial.suggest_int(
+             'batch_num',
+             10 if self.obj == 'reg:gamma' else 100,
+             100 if self.obj == 'reg:gamma' else 1000,
+             step=10)
+         if self.obj == 'reg:tweedie':
+             tw_power = trial.suggest_float('tw_power', 1.0, 2.0)
+         elif self.obj == 'count:poisson':
+             tw_power = 1
+         elif self.obj == 'reg:gamma':
+             tw_power = 2
+         fold_num = int(1 / self.prop_test)
+         kf = KFold(n_splits=fold_num, shuffle=True,
+                    random_state=self.rand_seed)
+         loss = 0
+         for fold, (train_idx, test_idx) in enumerate(
+                 kf.split(self.train_oht_scl_data[self.var_nmes])):
+             # Build the model for this fold
+             cv_net = ResNetScikitLearn(
+                 model_nme=self.model_nme,
+                 input_dim=self.train_oht_scl_data[self.var_nmes].shape[1],
+                 epochs=self.epochs,
+                 learning_rate=learning_rate,
+                 hidden_dim=hidden_dim,
+                 block_num=block_num,
+                 # keep the weight variance unchanged
+                 batch_num=batch_num,
+                 tweedie_power=tw_power
+             )
+             # Train on the fold
+             cv_net.fit(
+                 self.train_oht_scl_data[self.var_nmes].iloc[train_idx],
+                 self.train_oht_scl_data[self.resp_nme].iloc[train_idx],
+                 self.train_oht_scl_data[self.weight_nme].iloc[train_idx],
+                 self.train_oht_scl_data[self.var_nmes].iloc[test_idx],
+                 self.train_oht_scl_data[self.resp_nme].iloc[test_idx],
+                 self.train_oht_scl_data[self.weight_nme].iloc[test_idx]
+             )
+             # Predict on the held-out fold
+             y_pred_fold = cv_net.predict(
+                 self.train_oht_scl_data[self.var_nmes].iloc[test_idx]
+             )
+             # Accumulate the held-out deviance
+             loss += mean_tweedie_deviance(
+                 self.train_oht_scl_data[self.resp_nme].iloc[test_idx],
+                 y_pred_fold,
+                 sample_weight=self.train_oht_scl_data[self.weight_nme].iloc[test_idx],
+                 power=tw_power
+             )
+         return loss / fold_num
+
+     # ResNet Bayesian-optimization driver
+     def bayesopt_resnet(self, max_evals=100):
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.rand_seed))
+         study.optimize(self.cross_val_resn, n_trials=max_evals)
+         self.best_resn_params = study.best_params
+         os.makedirs(os.path.join(os.getcwd(), 'Results'), exist_ok=True)
+         pd.DataFrame(self.best_resn_params, index=[0]).to_csv(
+             os.path.join(os.getcwd(), 'Results',
+                          self.model_nme + '_bestparams_resn.csv'))
+         self.best_resn_trial = study.best_trial
+         self.resn_best = ResNetScikitLearn(
+             model_nme=self.model_nme,
+             input_dim=self.train_oht_scl_data[self.var_nmes].shape[1]
+         )
+         self.resn_best.set_params(self.best_resn_params)
+         self.resn_best.fit(self.train_oht_scl_data[self.var_nmes],
+                            self.train_oht_scl_data[self.resp_nme],
+                            self.train_oht_scl_data[self.weight_nme])
+         self.model_label += ['ResNet']
+         self.train_data['pred_resn'] = self.resn_best.predict(
+             self.train_oht_scl_data[self.var_nmes])
+         self.test_data['pred_resn'] = self.resn_best.predict(
+             self.test_oht_scl_data[self.var_nmes])
+         self.train_data.loc[:, 'w_pred_resn'] = self.train_data['pred_resn'] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_pred_resn'] = self.test_data['pred_resn'] * \
+             self.test_data[self.weight_nme]
+
+     # Binning helper (same logic as the module-level split_data)
+     def _split_data(self, data, col_nme, wgt_nme, n_bins=10):
+         data.sort_values(by=col_nme, ascending=True, inplace=True)
+         data['cum_weight'] = data[wgt_nme].cumsum()
+         w_sum = data[wgt_nme].sum()
+         data.loc[:, 'bins'] = np.floor(
+             data['cum_weight'] * float(n_bins) / w_sum)
+         data.loc[(data['bins'] == n_bins), 'bins'] = n_bins - 1
+         return data.groupby(['bins'], observed=True).sum(numeric_only=True)
+
+     # Build the binned data behind a lift chart
+     def _plot_data_lift(self,
+                         pred_list, w_pred_list,
+                         w_act_list, weight_list, n_bins=10):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred'] = pred_list
+         lift_data.loc[:, 'w_pred'] = w_pred_list
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = weight_list
+         plot_data = self._split_data(lift_data, 'pred', 'weight', n_bins)
+         plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+         plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+         plot_data.reset_index(inplace=True)
+         return plot_data
+
+     # Lift chart plotting
+     def plot_lift(self, model_label, n_bins=10):
+         # Plot the train and test sets side by side
+         figpos_list = [121, 122]
+         plot_dict = {
+             121: self.train_data,
+             122: self.test_data
+         }
+         name_list = {
+             121: 'Train Data',
+             122: 'Test Data'
+         }
+         fig = plt.figure(figsize=(11, 5))
+         if model_label == 'Xgboost':
+             pred_nme = 'pred_xgb'
+         elif model_label == 'ResNet':
+             pred_nme = 'pred_resn'
+
+         for figpos in figpos_list:
+             plot_data = self._plot_data_lift(
+                 plot_dict[figpos][pred_nme].values,
+                 plot_dict[figpos]['w_' + pred_nme].values,
+                 plot_dict[figpos]['w_act'].values,
+                 plot_dict[figpos][self.weight_nme].values,
+                 n_bins)
+             ax = fig.add_subplot(figpos)
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.plot(plot_data.index, plot_data['exp_v'],
+                     label='Predicted', color='blue')
+             ax.set_title('Lift Chart on %s' % name_list[figpos], fontsize=8)
+             plt.xticks(plot_data.index, plot_data.index,
+                        rotation=90, fontsize=6)
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper left', fontsize=5, frameon=False)
+             plt.margins(0.05)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index, plot_data['weight'],
+                     alpha=0.5, color='seagreen',
+                     label='Earned Exposure')
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper right', fontsize=5, frameon=False)
+         plt.subplots_adjust(wspace=0.3)
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'01_{self.model_nme}_{model_label}_lift.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Build the binned data behind a double lift chart
+     def _plot_data_dlift(self,
+                          pred_list_model1, pred_list_model2,
+                          w_list, w_act_list, n_bins=10):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred1'] = pred_list_model1
+         lift_data.loc[:, 'pred2'] = pred_list_model2
+         lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = w_list
+         plot_data = self._split_data(lift_data, 'diff_ly', 'weight', n_bins)
+         plot_data['exp_v1'] = plot_data['pred1'] / plot_data['act']
+         plot_data['exp_v2'] = plot_data['pred2'] / plot_data['act']
+         plot_data['act_v'] = plot_data['act'] / plot_data['act']
+         plot_data.reset_index(inplace=True)
+         return plot_data
+
+     # Double lift chart plotting
+     def plot_dlift(self, n_bins=10):
+         # Plot the train and test sets side by side
+         figpos_list = [121, 122]
+         plot_dict = {
+             121: self.train_data,
+             122: self.test_data
+         }
+         name_list = {
+             121: 'Train Data',
+             122: 'Test Data'
+         }
+         fig = plt.figure(figsize=(11, 5))
+         for figpos in figpos_list:
+             plot_data = self._plot_data_dlift(
+                 plot_dict[figpos]['w_pred_xgb'].values,
+                 plot_dict[figpos]['w_pred_resn'].values,
+                 plot_dict[figpos][self.weight_nme].values,
+                 plot_dict[figpos]['w_act'].values,
+                 n_bins)
+             ax = fig.add_subplot(figpos)
+             tt1 = 'Xgboost'
+             tt2 = 'ResNet'
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.plot(plot_data.index, plot_data['exp_v1'],
+                     label=tt1, color='blue')
+             ax.plot(plot_data.index, plot_data['exp_v2'],
+                     label=tt2, color='black')
+             ax.set_title('Double Lift Chart on %s' % name_list[figpos], fontsize=8)
+             plt.xticks(plot_data.index, plot_data.index,
+                        rotation=90, fontsize=6)
+             plt.xlabel('%s / %s' % (tt1, tt2), fontsize=6)
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper left', fontsize=5, frameon=False)
+             plt.margins(0.1)
+             plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index, plot_data['weight'],
+                     alpha=0.5, color='seagreen',
+                     label='Earned Exposure')
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper right', fontsize=5, frameon=False)
+         plt.subplots_adjust(wspace=0.3)
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'02_{self.model_nme}_dlift.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Save the fitted models
+     def save_model(self, model_name=None):
+         save_path_xgb = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_Xgboost.pkl')
+         save_path_resn = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_ResNet.pth')
+         os.makedirs(os.path.dirname(save_path_xgb), exist_ok=True)
+         # self.xgb_best.save_model(save_path_xgb)
+         if model_name != 'resn':
+             joblib.dump(self.xgb_best, save_path_xgb)
+         if model_name != 'xgb':
+             torch.save(self.resn_best.resnet.state_dict(), save_path_resn)
+
+     def load_model(self, model_name=None):
+         # model_name can be 'xgb', 'resn', or None (load both)
+         save_path_xgb = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_Xgboost.pkl')
+         save_path_resn = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_ResNet.pth')
+         if model_name != 'resn':
+             self.xgb_load = joblib.load(save_path_xgb)
+         if model_name != 'xgb':
+             self.resn_load = ResNetScikitLearn(
+                 model_nme=self.model_nme,
+                 input_dim=self.train_oht_scl_data[self.var_nmes].shape[1]
+             )
+             self.resn_load.resnet.load_state_dict(
+                 torch.load(save_path_resn, map_location=self.resn_load.device))
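+
+ # Illustrative save/load round trip (editor's sketch, not part of the
+ # original file; assumes the `bom` instance above has been fitted):
+ #     >>> bom.save_model()        # writes model/01_<name>_Xgboost.pkl and _ResNet.pth
+ #     >>> bom.load_model('resn')  # reloads only the ResNet weights into bom.resn_load
+ #     >>> bom.resn_load.predict(bom.test_oht_scl_data[bom.var_nmes])[:3]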