ins_pricing-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,830 @@
+ import copy
+ import numpy as np  # 1.26.2
+ import pandas as pd  # 2.2.3
+ import torch  # 1.10.1+cu111
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import optuna  # 4.3.0
+ import xgboost as xgb  # 1.7.0
+ import matplotlib.pyplot as plt
+ import os
+ import joblib
+ import torch.utils.checkpoint as cp
+
+ from torch.utils.data import DataLoader, TensorDataset
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn.utils import clip_grad_norm_
+ from sklearn.model_selection import KFold, ShuffleSplit, cross_val_score  # 1.2.2
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import make_scorer, mean_tweedie_deviance
+
+ # Tweedie deviance loss for PyTorch.
+ # Reference: https://scikit-learn.org/stable/modules/model_evaluation.html#mean-poisson-gamma-and-tweedie-deviances
+
+
+ def tweedie_loss(pred, target, p=1.5):
+     # Ensure predictions are positive for numerical stability.
+     eps = 1e-6
+     pred_clamped = torch.clamp(pred, min=eps)
+     # Compute the Tweedie deviance components.
+     if p == 1:
+         # Poisson case: d = 2 * (y * log(y / mu) - y + mu)
+         term1 = target * torch.log(target / pred_clamped + eps)
+         term2 = target - pred_clamped
+         term3 = 0
+     elif p == 0:
+         # Gaussian case: d = (y - mu) ** 2
+         term1 = 0.5 * torch.pow(target - pred_clamped, 2)
+         term2 = 0
+         term3 = 0
+     elif p == 2:
+         # Gamma case: d = 2 * (log(mu / y) + y / mu - 1)
+         term1 = torch.log(pred_clamped / target + eps)
+         term2 = -target / pred_clamped + 1
+         term3 = 0
+     else:
+         term1 = torch.pow(target, 2 - p) / ((1 - p) * (2 - p))
+         term2 = target * torch.pow(pred_clamped, 1 - p) / (1 - p)
+         term3 = torch.pow(pred_clamped, 2 - p) / (2 - p)
+     # Tweedie deviance (negative log-likelihood up to a constant).
+     return 2 * (term1 - term2 + term3)
+
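+ # Illustrative check (editor's sketch, not part of the package): for positive
+ # targets, the unweighted mean of tweedie_loss should match sklearn's
+ # mean_tweedie_deviance up to the eps clamp, e.g.:
+ #     y = torch.tensor([1.0, 2.0, 3.0])
+ #     mu = torch.tensor([1.5, 1.5, 2.5])
+ #     assert abs(tweedie_loss(mu, y, p=1.5).mean().item()
+ #                - mean_tweedie_deviance(y.numpy(), mu.numpy(), power=1.5)) < 1e-4
+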
+ # Weighted binning helper: sort by a column and split rows into n_bins
+ # buckets of (approximately) equal total weight.
+
+
+ def split_data(data, col_nme, wgt_nme, n_bins=10):
+     data.sort_values(by=col_nme, ascending=True, inplace=True)
+     data['cum_weight'] = data[wgt_nme].cumsum()
+     w_sum = data[wgt_nme].sum()
+     data.loc[:, 'bins'] = np.floor(data['cum_weight'] * float(n_bins) / w_sum)
+     data.loc[(data['bins'] == n_bins), 'bins'] = n_bins - 1
+     return data.groupby(['bins'], observed=True).sum(numeric_only=True)
+
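+ # Usage sketch (editor's illustration, hypothetical column names):
+ #     df = pd.DataFrame({'pred': np.random.rand(1000),
+ #                        'weight': np.ones(1000)})
+ #     binned = split_data(df, 'pred', 'weight', n_bins=10)
+ #     # len(binned) == 10; each bin carries ~1/10 of the total weight,
+ #     # with all numeric columns summed per bin.
+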
+ # Lift chart plotting helper.
+
+
+ def plot_lift_list(pred_model, w_pred_list, w_act_list,
+                    weight_list, tgt_nme, n_bins=10,
+                    fig_nme='Lift Chart'):
+     lift_data = pd.DataFrame()
+     lift_data.loc[:, 'pred'] = pred_model
+     lift_data.loc[:, 'w_pred'] = w_pred_list
+     lift_data.loc[:, 'act'] = w_act_list
+     lift_data.loc[:, 'weight'] = weight_list
+     plot_data = split_data(lift_data, 'pred', 'weight', n_bins)
+     plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+     plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+     plot_data.reset_index(inplace=True)
+     fig = plt.figure(figsize=(7, 5))
+     ax = fig.add_subplot(111)
+     ax.plot(plot_data.index, plot_data['act_v'],
+             label='Actual', color='red')
+     ax.plot(plot_data.index, plot_data['exp_v'],
+             label='Predicted', color='blue')
+     ax.set_title(
+         'Lift Chart of %s' % tgt_nme, fontsize=8)
+     plt.xticks(plot_data.index,
+                plot_data.index,
+                rotation=90, fontsize=6)
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper left',
+                fontsize=5, frameon=False)
+     plt.margins(0.05)
+     ax2 = ax.twinx()
+     ax2.bar(plot_data.index, plot_data['weight'],
+             alpha=0.5, color='seagreen',
+             label='Earned Exposure')
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper right',
+                fontsize=5, frameon=False)
+     plt.subplots_adjust(wspace=0.3)
+     save_path = os.path.join(
+         os.getcwd(), 'plot', f'05_{tgt_nme}_{fig_nme}.png')
+     plt.savefig(save_path, dpi=300)
+     plt.close(fig)
+
+ # Double lift chart plotting helper.
+
+
+ def plot_dlift_list(pred_model_1, pred_model_2,
+                     model_nme_1, model_nme_2,
+                     tgt_nme,
+                     w_list, w_act_list, n_bins=10,
+                     fig_nme='Double Lift Chart'):
+     lift_data = pd.DataFrame()
+     lift_data.loc[:, 'pred1'] = pred_model_1
+     lift_data.loc[:, 'pred2'] = pred_model_2
+     lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+     lift_data.loc[:, 'act'] = w_act_list
+     lift_data.loc[:, 'weight'] = w_list
+     lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
+     lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
+     plot_data = split_data(lift_data, 'diff_ly', 'weight', n_bins)
+     plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
+     plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
+     plot_data['act_v'] = plot_data['act'] / plot_data['act']
+     plot_data.reset_index(inplace=True)
+     fig = plt.figure(figsize=(7, 5))
+     ax = fig.add_subplot(111)
+     ax.plot(plot_data.index, plot_data['act_v'],
+             label='Actual', color='red')
+     ax.plot(plot_data.index, plot_data['exp_v1'],
+             label=model_nme_1, color='blue')
+     ax.plot(plot_data.index, plot_data['exp_v2'],
+             label=model_nme_2, color='black')
+     ax.set_title(
+         'Double Lift Chart of %s' % tgt_nme, fontsize=8)
+     plt.xticks(plot_data.index,
+                plot_data.index,
+                rotation=90, fontsize=6)
+     plt.xlabel('%s / %s' % (model_nme_1, model_nme_2), fontsize=6)
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper left',
+                fontsize=5, frameon=False)
+     plt.margins(0.1)
+     plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+     ax2 = ax.twinx()
+     ax2.bar(plot_data.index, plot_data['weight'],
+             alpha=0.5, color='seagreen',
+             label='Earned Exposure')
+     plt.yticks(fontsize=6)
+     plt.legend(loc='upper right',
+                fontsize=5, frameon=False)
+     plt.subplots_adjust(wspace=0.3)
+     save_path = os.path.join(
+         os.getcwd(), 'plot', f'06_{tgt_nme}_{fig_nme}.png')
+     plt.savefig(save_path, dpi=300)
+     plt.close(fig)
+
+ # Residual block: two linear layers with BatchNorm, plus a skip connection.
+ # ResBlock subclasses nn.Module.
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, dim):
+         super(ResBlock, self).__init__()
+         self.block = nn.Sequential(
+             nn.Linear(dim, dim),
+             nn.BatchNorm1d(dim),
+             nn.ReLU(),
+             nn.Linear(dim, dim),
+             nn.BatchNorm1d(dim)
+         )
+
+     def forward(self, x):
+         # Original input + two-layer transform, then ReLU.
+         # Uncomment the checkpoint variant to trade compute for memory:
+         # out = cp.checkpoint(self.block, x)
+         # return F.relu(out + x)
+         return F.relu(self.block(x) + x)
+
+ # ResNetSequential subclasses nn.Module and defines the full network.
+
+
+ class ResNetSequential(nn.Module):
+     # The whole network is chained with nn.Sequential:
+     # input -> ResBlock * block_num -> output.
+     def __init__(self, input_dim, hidden_dim=64, block_num=2):
+         super(ResNetSequential, self).__init__()
+         self.net = nn.Sequential()
+         self.net.add_module('fc1', nn.Linear(input_dim, hidden_dim))
+         self.net.add_module('bn1', nn.BatchNorm1d(hidden_dim))
+         self.net.add_module('ReLU1', nn.ReLU())
+         for i in range(block_num):
+             self.net.add_module('ResBlk_' + str(i + 1), ResBlock(hidden_dim))
+         self.net.add_module('fc2', nn.Linear(hidden_dim, 1))
+         self.net.add_module('softplus', nn.Softplus())
+
+     def forward(self, x):
+         return self.net(x)
+
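+ # Illustrative shape check (editor's sketch, not part of the package): the
+ # Softplus head keeps outputs strictly positive, as the Tweedie loss requires:
+ #     net = ResNetSequential(input_dim=8, hidden_dim=64, block_num=2)
+ #     net.eval()  # BatchNorm1d layers need eval mode outside training
+ #     out = net(torch.randn(4, 8))
+ #     assert out.shape == (4, 1) and (out > 0).all()
+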
+ # Scikit-learn style wrapper around the ResNet: exposes fit/predict/set_params
+ # so the network can be tuned like any other estimator.
+
+
+ class ResNetScikitLearn(nn.Module):
+     def __init__(self, model_nme, input_dim, hidden_dim=64,
+                  block_num=2, batch_num=100, epochs=100,
+                  tweedie_power=1.5, learning_rate=0.01,
+                  patience=10, accumulation_steps=2):
+         super(ResNetScikitLearn, self).__init__()
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.block_num = block_num
+         if torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         # The mps backend only exists on newer torch builds, so guard the check.
+         elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         else:
+             self.device = torch.device('cpu')
+         self.resnet = ResNetSequential(
+             self.input_dim,
+             self.hidden_dim,
+             self.block_num
+         ).to(self.device)
+         if torch.cuda.device_count() > 1:
+             self.resnet = nn.DataParallel(
+                 self.resnet,
+                 device_ids=list(range(torch.cuda.device_count()))
+             )
+         self.batch_num = batch_num
+         self.epochs = epochs
+         self.model_nme = model_nme
+         # Infer the Tweedie power from the model name:
+         # 'f' (frequency) -> Poisson, 's' (severity) -> gamma.
+         if self.model_nme.find('f') != -1:
+             self.tw_power = 1
+         elif self.model_nme.find('s') != -1:
+             self.tw_power = 2
+         else:
+             self.tw_power = tweedie_power
+         self.learning_rate = learning_rate
+         self.patience = patience  # Early stopping patience
+         self.accumulation_steps = accumulation_steps  # Gradient accumulation steps (not used in fit below)
+
+     def fit(self, X_train, y_train, w_train=None, X_val=None, y_val=None, w_val=None):
+         # Convert the data to PyTorch tensors.
+         X_tensor = torch.tensor(
+             X_train.values, dtype=torch.float32).to(self.device)
+         y_tensor = torch.tensor(
+             y_train.values, dtype=torch.float32).view(-1, 1).to(self.device)
+         w_tensor = torch.tensor(
+             w_train.values, dtype=torch.float32).view(-1, 1).to(self.device) if w_train is not None else torch.ones_like(y_tensor)
+
+         # Validation-set tensors.
+         if X_val is not None:
+             X_val_tensor = torch.tensor(
+                 X_val.values, dtype=torch.float32).to(self.device)
+             y_val_tensor = torch.tensor(
+                 y_val.values, dtype=torch.float32).view(-1, 1).to(self.device)
+             w_val_tensor = torch.tensor(
+                 w_val.values, dtype=torch.float32).view(-1, 1).to(self.device) if w_val is not None else torch.ones_like(y_val_tensor)
+
+         # Build the dataset and data loader.
+         dataset = TensorDataset(
+             X_tensor, y_tensor, w_tensor
+         )
+         dataloader = DataLoader(
+             dataset,
+             batch_size=max(1, int((self.learning_rate / (1e-4)) ** 0.5 *
+                                   (X_train.shape[0] / self.batch_num))),
+             shuffle=True
+             # num_workers=4
+             # pin_memory=(self.device.type == 'cuda')
+         )
+         # Set up the optimizer and AMP grad scaler (the loss is tweedie_loss).
+         optimizer = torch.optim.Adam(
+             self.resnet.parameters(), lr=self.learning_rate)
+         scaler = GradScaler(enabled=(self.device.type == 'cuda'))
+
+         # Early stopping state.
+         best_loss, patience_counter = float('inf'), 0
+         best_model_state = None  # Initialize best_model_state
+
+         # Training loop.
+         for epoch in range(1, self.epochs + 1):
+             self.resnet.train()
+             for X_batch, y_batch, w_batch in dataloader:
+                 optimizer.zero_grad()
+                 # Enable mixed precision when running on CUDA.
+                 with autocast(enabled=(self.device.type == 'cuda')):
+                     X_batch, y_batch, w_batch = X_batch.to(self.device), y_batch.to(
+                         self.device), w_batch.to(self.device)
+                     y_pred = self.resnet(X_batch)
+                     y_pred = torch.clamp(y_pred, min=1e-6)
+                     losses = tweedie_loss(
+                         y_pred, y_batch, p=self.tw_power).view(-1)
+                     weighted_loss = (losses * w_batch.view(-1)
+                                      ).sum() / w_batch.sum()
+                 scaler.scale(weighted_loss).backward()
+                 # Gradient clipping: with AMP, gradients must be unscaled
+                 # before they are clipped.
+                 if self.device.type == 'cuda':
+                     scaler.unscale_(optimizer)
+                     clip_grad_norm_(
+                         self.resnet.parameters(),
+                         max_norm=1.0
+                     )
+                 scaler.step(optimizer)
+                 scaler.update()
+                 optimizer.zero_grad()
+
+             # Validation loss.
+             if X_val is not None and y_val is not None:
+                 self.resnet.eval()
+                 with torch.no_grad(), autocast(enabled=(self.device.type == 'cuda')):
+                     y_val_pred = self.resnet(X_val_tensor)
+                     val_loss_values = tweedie_loss(
+                         y_val_pred, y_val_tensor, p=self.tw_power).view(-1)
+                     val_weighted_loss = (
+                         val_loss_values * w_val_tensor.view(-1)).sum() / w_val_tensor.sum()
+
+                 # Early stopping check.
+                 if val_weighted_loss < best_loss:
+                     best_loss, patience_counter = val_weighted_loss, 0
+                     # Snapshot the current best weights; deep copy, otherwise
+                     # later updates would mutate the saved state in place.
+                     best_model_state = copy.deepcopy(self.resnet.state_dict())
+                 else:
+                     patience_counter += 1
+                     if patience_counter >= self.patience:
+                         self.resnet.load_state_dict(best_model_state)  # Restore the best model
+                         break
+
+     def predict(self, X_test):
+         self.resnet.eval()
+         with torch.no_grad():
+             X_tensor = torch.tensor(
+                 X_test.values, dtype=torch.float32).to(self.device)
+             y_pred = self.resnet(X_tensor).cpu().numpy()
+             y_pred = np.clip(y_pred, 1e-6, None)
+         return y_pred.flatten()
+
+     def set_params(self, params):
+         # Set model attributes from a parameter dict, then rebuild the
+         # network so architecture changes (hidden_dim, block_num) take effect.
+         for key, value in params.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Parameter {key} not found in model.")
+         self.resnet = ResNetSequential(
+             self.input_dim,
+             self.hidden_dim,
+             self.block_num
+         ).to(self.device)
+
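+ # Usage sketch (editor's illustration, synthetic data): train on a weighted
+ # frequency-like target and score a frame; 'f' in the name selects Poisson:
+ #     X = pd.DataFrame(np.random.randn(512, 8))
+ #     y = pd.Series(np.random.poisson(1.0, 512).astype(float))
+ #     w = pd.Series(np.ones(512))
+ #     net = ResNetScikitLearn('freq_model', input_dim=8, epochs=5)
+ #     net.fit(X, y, w)
+ #     preds = net.predict(X)  # positive values, shape (512,)
+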
+ # Bayesian optimization wrapper holding both the XGBoost and ResNet models.
+
+
+ class BayesOptModel:
+     def __init__(self, train_data, test_data,
+                  model_nme, resp_nme, weight_nme, factor_nmes,
+                  cate_list=[], prop_test=0.25, rand_seed=None, epochs=100):
+         # Data setup.
+         # train_data / test_data: training and test sets (DataFrames).
+         # model_nme: model name.
+         # resp_nme: response column, weight_nme: weight column.
+         # factor_nmes: list of factor names.
+         # cate_list: list of categorical variables.
+         # prop_test: test-set proportion, rand_seed: random seed.
+         self.train_data = train_data
+         self.test_data = test_data
+         self.resp_nme = resp_nme
+         self.weight_nme = weight_nme
+         self.train_data.loc[:, 'w_act'] = self.train_data[self.resp_nme] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_act'] = self.test_data[self.resp_nme] * \
+             self.test_data[self.weight_nme]
+         self.factor_nmes = factor_nmes
+         self.cate_list = cate_list
+         self.rand_seed = rand_seed if rand_seed is not None else np.random.randint(
+             1, 10000)
+         if self.cate_list != []:
+             for cate in self.cate_list:
+                 self.train_data[cate] = self.train_data[cate].astype(
+                     'category')
+                 self.test_data[cate] = self.test_data[cate].astype('category')
+         self.prop_test = prop_test
+         self.cv = ShuffleSplit(n_splits=int(1 / self.prop_test),
+                                test_size=self.prop_test,
+                                random_state=self.rand_seed)
+         self.model_nme = model_nme
+         # Map the model name to an XGBoost objective: 'f' -> frequency
+         # (Poisson), 's' -> severity (gamma), 'bc' -> burning cost (Tweedie);
+         # default to Tweedie when the name encodes no line of business.
+         if self.model_nme.find('f') != -1:
+             self.obj = 'count:poisson'
+         elif self.model_nme.find('s') != -1:
+             self.obj = 'reg:gamma'
+         else:
+             self.obj = 'reg:tweedie'
+         self.fit_params = {
+             'sample_weight': self.train_data[self.weight_nme].values
+         }
+         self.num_features = [
+             nme for nme in self.factor_nmes if nme not in self.cate_list]
+         self.train_oht_scl_data = self.train_data[self.factor_nmes +
+                                                   [self.weight_nme] + [self.resp_nme]].copy()
+         self.test_oht_scl_data = self.test_data[self.factor_nmes +
+                                                 [self.weight_nme] + [self.resp_nme]].copy()
+         # One-hot encode categoricals for the ResNet.
+         self.train_oht_scl_data = pd.get_dummies(
+             self.train_oht_scl_data,
+             columns=self.cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         self.test_oht_scl_data = pd.get_dummies(
+             self.test_oht_scl_data,
+             columns=self.cate_list,
+             drop_first=True,
+             dtype=np.int8
+         )
+         # Standardize numeric features (fit on train, apply to test).
+         for num_chr in self.num_features:
+             scaler = StandardScaler()
+             self.train_oht_scl_data[num_chr] = scaler.fit_transform(
+                 self.train_oht_scl_data[num_chr].values.reshape(-1, 1))
+             self.test_oht_scl_data[num_chr] = scaler.transform(
+                 self.test_oht_scl_data[num_chr].values.reshape(-1, 1))
+         # Align test-set columns with the training set.
+         self.test_oht_scl_data = self.test_oht_scl_data.reindex(
+             columns=self.train_oht_scl_data.columns,
+             fill_value=0
+         )
+         self.var_nmes = list(
+             set(list(self.train_oht_scl_data.columns)) -
+             set([self.weight_nme, self.resp_nme])
+         )
+         self.epochs = epochs
+         self.model_label = []
+
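+ # Constructor sketch (editor's illustration; column names are hypothetical):
+ #     bo = BayesOptModel(train_df, test_df,
+ #                        model_nme='motor_f',  # 'f' -> Poisson frequency
+ #                        resp_nme='clm_cnt', weight_nme='exposure',
+ #                        factor_nmes=['age', 'region'], cate_list=['region'])
+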
+     # One-way plot for each factor.
+
+     def plot_oneway(self, n_bins=10):
+         for c in self.factor_nmes:
+             fig = plt.figure(figsize=(7, 5))
+             if c in self.cate_list:
+                 strs = c
+             else:
+                 strs = c + '_bins'
+                 self.train_data.loc[:, strs] = pd.qcut(
+                     self.train_data[c],
+                     n_bins,
+                     duplicates='drop'
+                 )
+             plot_data = self.train_data.groupby(
+                 [strs], observed=True).sum(numeric_only=True)
+             plot_data.reset_index(inplace=True)
+             plot_data['act_v'] = plot_data['w_act'] / \
+                 plot_data[self.weight_nme]
+             ax = fig.add_subplot(111)
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.set_title(
+                 'Analysis of %s : Train Data' % strs,
+                 fontsize=8)
+             plt.xticks(plot_data.index,
+                        list(plot_data[strs].astype(str)),
+                        rotation=90)
+             if len(list(plot_data[strs].astype(str))) > 50:
+                 plt.xticks(fontsize=3)
+             else:
+                 plt.xticks(fontsize=6)
+             plt.yticks(fontsize=6)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index,
+                     plot_data[self.weight_nme],
+                     alpha=0.5, color='seagreen')
+             plt.yticks(fontsize=6)
+             plt.margins(0.05)
+             plt.subplots_adjust(wspace=0.3)
+             save_path = os.path.join(
+                 os.getcwd(), 'plot',
+                 f'00_{self.model_nme}_{strs}_oneway.png')
+             plt.savefig(save_path, dpi=300)
+             plt.close(fig)
+
+     # XGBoost cross-validation objective for Optuna.
+     def cross_val_xgb(self, trial):
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-5, 1e-1, log=True)
+         gamma = trial.suggest_float(
+             'gamma', 0, 10000)
+         max_depth = trial.suggest_int(
+             'max_depth', 3, 25)
+         n_estimators = trial.suggest_int(
+             'n_estimators', 10, 500, step=10)
+         min_child_weight = trial.suggest_int(
+             'min_child_weight', 100, 10000, step=100)
+         reg_alpha = trial.suggest_float(
+             'reg_alpha', 1e-10, 1, log=True)
+         reg_lambda = trial.suggest_float(
+             'reg_lambda', 1e-10, 1, log=True)
+         if self.obj == 'reg:tweedie':
+             tweedie_variance_power = trial.suggest_float(
+                 'tweedie_variance_power', 1, 2)
+         elif self.obj == 'count:poisson':
+             tweedie_variance_power = 1
+         elif self.obj == 'reg:gamma':
+             tweedie_variance_power = 2
+         clf = xgb.XGBRegressor(
+             objective=self.obj,
+             random_state=self.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if torch.cuda.is_available() else 'hist',
+             gpu_id=0,
+             enable_categorical=True,
+             predictor='gpu_predictor'
+         )
+         params = {
+             'learning_rate': learning_rate,
+             'gamma': gamma,
+             'max_depth': max_depth,
+             'n_estimators': n_estimators,
+             'min_child_weight': min_child_weight,
+             'reg_alpha': reg_alpha,
+             'reg_lambda': reg_lambda
+         }
+         if self.obj == 'reg:tweedie':
+             params['tweedie_variance_power'] = tweedie_variance_power
+         clf.set_params(**params)
+         acc = cross_val_score(
+             clf,
+             self.train_data[self.factor_nmes],
+             self.train_data[self.resp_nme].values,
+             fit_params=self.fit_params,
+             cv=self.cv,
+             scoring=make_scorer(
+                 mean_tweedie_deviance,
+                 power=tweedie_variance_power,
+                 greater_is_better=False),
+             error_score='raise',
+             n_jobs=int(1 / self.prop_test)).mean()
+         # make_scorer flips the sign, so negate back to a positive deviance.
+         return -acc
+
+     # Run Bayesian optimization for XGBoost and refit the best model.
+     def bayesopt_xgb(self, max_evals=100):
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.rand_seed))
+         study.optimize(self.cross_val_xgb, n_trials=max_evals)
+         self.best_xgb_params = study.best_params
+         pd.DataFrame(self.best_xgb_params, index=[0]).to_csv(
+             os.getcwd() + '/Results/' + self.model_nme + '_bestparams_xgb.csv')
+         self.best_xgb_trial = study.best_trial
+         self.xgb_best = xgb.XGBRegressor(
+             objective=self.obj,
+             random_state=self.rand_seed,
+             subsample=0.9,
+             tree_method='gpu_hist' if torch.cuda.is_available() else 'hist',
+             gpu_id=0,
+             enable_categorical=True,
+             predictor='gpu_predictor'
+         )
+         self.xgb_best.set_params(**self.best_xgb_params)
+         self.xgb_best.fit(self.train_data[self.factor_nmes],
+                           self.train_data[self.resp_nme].values,
+                           **self.fit_params)
+         self.model_label += ['Xgboost']
+         self.train_data['pred_xgb'] = self.xgb_best.predict(
+             self.train_data[self.factor_nmes])
+         self.test_data['pred_xgb'] = self.xgb_best.predict(
+             self.test_data[self.factor_nmes])
+         self.train_data.loc[:, 'w_pred_xgb'] = self.train_data['pred_xgb'] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_pred_xgb'] = self.test_data['pred_xgb'] * \
+             self.test_data[self.weight_nme]
+
+     # ResNet cross-validation objective for Optuna.
+     def cross_val_resn(self, trial):
+         learning_rate = trial.suggest_float(
+             'learning_rate', 1e-6, 1e-2, log=True)
+         hidden_dim = trial.suggest_int(
+             'hidden_dim', 32, 256, step=16)
+         block_num = trial.suggest_int(
+             'block_num', 3, 10)
+         batch_num = trial.suggest_int(
+             'batch_num',
+             10 if self.obj == 'reg:gamma' else 100,
+             100 if self.obj == 'reg:gamma' else 1000,
+             step=10)
+         if self.obj == 'reg:tweedie':
+             tw_power = trial.suggest_float(
+                 'tw_power', 1, 2)
+         elif self.obj == 'count:poisson':
+             tw_power = 1
+         elif self.obj == 'reg:gamma':
+             tw_power = 2
+         loss = 0
+         for fold, (train_idx, test_idx) in enumerate(self.cv.split(self.train_oht_scl_data[self.var_nmes])):
+             # Build the model.
+             cv_net = ResNetScikitLearn(
+                 model_nme=self.model_nme,
+                 input_dim=self.train_oht_scl_data[self.var_nmes].shape[1],
+                 epochs=self.epochs,
+                 learning_rate=learning_rate,
+                 hidden_dim=hidden_dim,
+                 block_num=block_num,
+                 # Keep the weight variance unchanged across batch counts.
+                 batch_num=batch_num,
+                 tweedie_power=tw_power if self.obj == 'reg:tweedie' and tw_power != 1 else tw_power + 1e-6
+             )
+             # Train.
+             cv_net.fit(
+                 self.train_oht_scl_data[self.var_nmes].iloc[train_idx],
+                 self.train_oht_scl_data[self.resp_nme].iloc[train_idx],
+                 self.train_oht_scl_data[self.weight_nme].iloc[train_idx],
+                 self.train_oht_scl_data[self.var_nmes].iloc[test_idx],
+                 self.train_oht_scl_data[self.resp_nme].iloc[test_idx],
+                 self.train_oht_scl_data[self.weight_nme].iloc[test_idx]
+             )
+             # Predict on the held-out fold.
+             y_pred_fold = cv_net.predict(
+                 self.train_oht_scl_data[self.var_nmes].iloc[test_idx]
+             )
+             # Accumulate the fold deviance.
+             loss += mean_tweedie_deviance(
+                 self.train_oht_scl_data[self.resp_nme].iloc[test_idx],
+                 y_pred_fold,
+                 sample_weight=self.train_oht_scl_data[self.weight_nme].iloc[test_idx],
+                 power=tw_power
+             )
+         return loss / int(1 / self.prop_test)
+
+     # Run Bayesian optimization for the ResNet and refit the best model.
+     def bayesopt_resnet(self, max_evals=100):
+         study = optuna.create_study(
+             direction='minimize',
+             sampler=optuna.samplers.TPESampler(seed=self.rand_seed))
+         study.optimize(self.cross_val_resn, n_trials=max_evals)
+         self.best_resn_params = study.best_params
+         pd.DataFrame(self.best_resn_params, index=[0]).to_csv(
+             os.getcwd() + '/Results/' + self.model_nme + '_bestparams_resn.csv')
+         self.best_resn_trial = study.best_trial
+         self.resn_best = ResNetScikitLearn(
+             model_nme=self.model_nme,
+             input_dim=self.train_oht_scl_data[self.var_nmes].shape[1]
+         )
+         self.resn_best.set_params(self.best_resn_params)
+         self.resn_best.fit(self.train_oht_scl_data[self.var_nmes],
+                            self.train_oht_scl_data[self.resp_nme],
+                            self.train_oht_scl_data[self.weight_nme])
+         self.model_label += ['ResNet']
+         self.train_data['pred_resn'] = self.resn_best.predict(
+             self.train_oht_scl_data[self.var_nmes])
+         self.test_data['pred_resn'] = self.resn_best.predict(
+             self.test_oht_scl_data[self.var_nmes])
+         self.train_data.loc[:, 'w_pred_resn'] = self.train_data['pred_resn'] * \
+             self.train_data[self.weight_nme]
+         self.test_data.loc[:, 'w_pred_resn'] = self.test_data['pred_resn'] * \
+             self.test_data[self.weight_nme]
+
+     # Weighted binning helper (method version of the module-level split_data).
+     def _split_data(self, data, col_nme, wgt_nme, n_bins=10):
+         data.sort_values(by=col_nme, ascending=True, inplace=True)
+         data['cum_weight'] = data[wgt_nme].cumsum()
+         w_sum = data[wgt_nme].sum()
+         data.loc[:, 'bins'] = np.floor(
+             data['cum_weight'] * float(n_bins) / w_sum)
+         data.loc[(data['bins'] == n_bins), 'bins'] = n_bins - 1
+         return data.groupby(['bins'], observed=True).sum(numeric_only=True)
+
+     # Build the binned dataset behind a lift chart.
+     def _plot_data_lift(self,
+                         pred_list, w_pred_list,
+                         w_act_list, weight_list, n_bins=10):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred'] = pred_list
+         lift_data.loc[:, 'w_pred'] = w_pred_list
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = weight_list
+         plot_data = self._split_data(
+             lift_data, 'pred', 'weight', n_bins)
+         plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
+         plot_data['act_v'] = plot_data['act'] / plot_data['weight']
+         plot_data.reset_index(inplace=True)
+         return plot_data
+
+     # Draw lift charts on the train and test sets.
+     def plot_lift(self, model_label, n_bins=10):
+         figpos_list = [121, 122]
+         plot_dict = {
+             121: self.train_data,
+             122: self.test_data
+         }
+         name_list = {
+             121: 'Train Data',
+             122: 'Test Data'
+         }
+         fig = plt.figure(figsize=(11, 5))
+         if model_label == 'Xgboost':
+             pred_nme = 'pred_xgb'
+         elif model_label == 'ResNet':
+             pred_nme = 'pred_resn'
+
+         for figpos in figpos_list:
+             plot_data = self._plot_data_lift(
+                 plot_dict[figpos][pred_nme].values,
+                 plot_dict[figpos]['w_' + pred_nme].values,
+                 plot_dict[figpos]['w_act'].values,
+                 plot_dict[figpos][self.weight_nme].values,
+                 n_bins)
+             ax = fig.add_subplot(figpos)
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.plot(plot_data.index, plot_data['exp_v'],
+                     label='Predicted', color='blue')
+             ax.set_title(
+                 'Lift Chart on %s' % name_list[figpos], fontsize=8)
+             plt.xticks(plot_data.index,
+                        plot_data.index,
+                        rotation=90, fontsize=6)
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper left',
+                        fontsize=5, frameon=False)
+             plt.margins(0.05)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index, plot_data['weight'],
+                     alpha=0.5, color='seagreen',
+                     label='Earned Exposure')
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper right',
+                        fontsize=5, frameon=False)
+         plt.subplots_adjust(wspace=0.3)
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'01_{self.model_nme}_{model_label}_lift.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Build the binned dataset behind a double lift chart.
+     def _plot_data_dlift(self,
+                          pred_list_model1, pred_list_model2,
+                          w_list, w_act_list, n_bins=10):
+         lift_data = pd.DataFrame()
+         lift_data.loc[:, 'pred1'] = pred_list_model1
+         lift_data.loc[:, 'pred2'] = pred_list_model2
+         lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
+         lift_data.loc[:, 'act'] = w_act_list
+         lift_data.loc[:, 'weight'] = w_list
+         plot_data = self._split_data(lift_data, 'diff_ly', 'weight', n_bins)
+         plot_data['exp_v1'] = plot_data['pred1'] / plot_data['act']
+         plot_data['exp_v2'] = plot_data['pred2'] / plot_data['act']
+         plot_data['act_v'] = plot_data['act'] / plot_data['act']
+         plot_data.reset_index(inplace=True)
+         return plot_data
+
+     # Draw the double lift chart (XGBoost vs ResNet) on the train and test sets.
+     def plot_dlift(self, n_bins=10):
+         figpos_list = [121, 122]
+         plot_dict = {
+             121: self.train_data,
+             122: self.test_data
+         }
+         name_list = {
+             121: 'Train Data',
+             122: 'Test Data'
+         }
+         fig = plt.figure(figsize=(11, 5))
+         for figpos in figpos_list:
+             plot_data = self._plot_data_dlift(
+                 plot_dict[figpos]['w_pred_xgb'].values,
+                 plot_dict[figpos]['w_pred_resn'].values,
+                 plot_dict[figpos][self.weight_nme].values,
+                 plot_dict[figpos]['w_act'].values,
+                 n_bins)
+             ax = fig.add_subplot(figpos)
+             tt1 = 'Xgboost'
+             tt2 = 'ResNet'
+             ax.plot(plot_data.index, plot_data['act_v'],
+                     label='Actual', color='red')
+             ax.plot(plot_data.index, plot_data['exp_v1'],
+                     label=tt1, color='blue')
+             ax.plot(plot_data.index, plot_data['exp_v2'],
+                     label=tt2, color='black')
+             ax.set_title(
+                 'Double Lift Chart on %s' % name_list[figpos], fontsize=8)
+             plt.xticks(plot_data.index,
+                        plot_data.index,
+                        rotation=90, fontsize=6)
+             plt.xlabel('%s / %s' % (tt1, tt2), fontsize=6)
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper left',
+                        fontsize=5, frameon=False)
+             plt.margins(0.1)
+             plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
+             ax2 = ax.twinx()
+             ax2.bar(plot_data.index, plot_data['weight'],
+                     alpha=0.5, color='seagreen',
+                     label='Earned Exposure')
+             plt.yticks(fontsize=6)
+             plt.legend(loc='upper right',
+                        fontsize=5, frameon=False)
+         plt.subplots_adjust(wspace=0.3)
+         save_path = os.path.join(
+             os.getcwd(), 'plot', f'02_{self.model_nme}_dlift.png')
+         plt.savefig(save_path, dpi=300)
+         plt.show()
+         plt.close(fig)
+
+     # Persist the fitted models.
+     def save_model(self, model_name=None):
+         # model_name may be 'xgb', 'resn', or None (save both).
+         save_path_xgb = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_Xgboost.pkl')
+         save_path_resn = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_ResNet.pth')
+         if not os.path.exists(os.path.dirname(save_path_xgb)):
+             os.makedirs(os.path.dirname(save_path_xgb))
+         # self.xgb_best.save_model(save_path_xgb)
+         if model_name != 'resn':
+             joblib.dump(self.xgb_best, save_path_xgb)
+         if model_name != 'xgb':
+             torch.save(self.resn_best.resnet.state_dict(), save_path_resn)
+
+     def load_model(self, model_name=None):
+         # model_name may be 'xgb', 'resn', or None (load both).
+         save_path_xgb = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_Xgboost.pkl')
+         save_path_resn = os.path.join(
+             os.getcwd(), 'model', f'01_{self.model_nme}_ResNet.pth')
+         if model_name != 'resn':
+             self.xgb_load = joblib.load(save_path_xgb)
+         if model_name != 'xgb':
+             self.resn_load = ResNetScikitLearn(
+                 model_nme=self.model_nme,
+                 input_dim=self.train_oht_scl_data[self.var_nmes].shape[1]
+             )
+             # Rebuild with the tuned architecture before loading weights,
+             # otherwise the saved state_dict will not match the default net.
+             if hasattr(self, 'best_resn_params'):
+                 self.resn_load.set_params(self.best_resn_params)
+             self.resn_load.resnet.load_state_dict(
+                 torch.load(save_path_resn, map_location=self.resn_load.device))
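+
+ # End-to-end sketch (editor's illustration, hypothetical data): tune both
+ # models, compare them, and persist the results:
+ #     bo = BayesOptModel(train_df, test_df, 'motor_f', 'clm_cnt',
+ #                        'exposure', ['age', 'region'], cate_list=['region'])
+ #     bo.bayesopt_xgb(max_evals=50)
+ #     bo.bayesopt_resnet(max_evals=50)
+ #     bo.plot_lift('Xgboost')   # writes plot/01_motor_f_Xgboost_lift.png
+ #     bo.plot_dlift()           # XGBoost vs ResNet double lift
+ #     bo.save_model()           # joblib .pkl + torch .pth under ./model
+ #     # NB: the ./plot and ./Results directories must exist beforehand.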