ins-pricing 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +60 -0
- ins_pricing/__init__.py +102 -0
- ins_pricing/governance/README.md +18 -0
- ins_pricing/governance/__init__.py +20 -0
- ins_pricing/governance/approval.py +93 -0
- ins_pricing/governance/audit.py +37 -0
- ins_pricing/governance/registry.py +99 -0
- ins_pricing/governance/release.py +159 -0
- ins_pricing/modelling/BayesOpt.py +146 -0
- ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
- ins_pricing/modelling/BayesOpt_entry.py +575 -0
- ins_pricing/modelling/BayesOpt_incremental.py +731 -0
- ins_pricing/modelling/Explain_Run.py +36 -0
- ins_pricing/modelling/Explain_entry.py +539 -0
- ins_pricing/modelling/Pricing_Run.py +36 -0
- ins_pricing/modelling/README.md +33 -0
- ins_pricing/modelling/__init__.py +44 -0
- ins_pricing/modelling/bayesopt/__init__.py +98 -0
- ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
- ins_pricing/modelling/bayesopt/core.py +1476 -0
- ins_pricing/modelling/bayesopt/models.py +2196 -0
- ins_pricing/modelling/bayesopt/trainers.py +2446 -0
- ins_pricing/modelling/bayesopt/utils.py +1021 -0
- ins_pricing/modelling/cli_common.py +136 -0
- ins_pricing/modelling/explain/__init__.py +55 -0
- ins_pricing/modelling/explain/gradients.py +334 -0
- ins_pricing/modelling/explain/metrics.py +176 -0
- ins_pricing/modelling/explain/permutation.py +155 -0
- ins_pricing/modelling/explain/shap_utils.py +146 -0
- ins_pricing/modelling/notebook_utils.py +284 -0
- ins_pricing/modelling/plotting/__init__.py +45 -0
- ins_pricing/modelling/plotting/common.py +63 -0
- ins_pricing/modelling/plotting/curves.py +572 -0
- ins_pricing/modelling/plotting/diagnostics.py +139 -0
- ins_pricing/modelling/plotting/geo.py +362 -0
- ins_pricing/modelling/plotting/importance.py +121 -0
- ins_pricing/modelling/run_logging.py +133 -0
- ins_pricing/modelling/tests/conftest.py +8 -0
- ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing/modelling/tests/test_explain.py +56 -0
- ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing/modelling/tests/test_plotting.py +63 -0
- ins_pricing/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing/modelling/watchdog_run.py +211 -0
- ins_pricing/pricing/README.md +44 -0
- ins_pricing/pricing/__init__.py +27 -0
- ins_pricing/pricing/calibration.py +39 -0
- ins_pricing/pricing/data_quality.py +117 -0
- ins_pricing/pricing/exposure.py +85 -0
- ins_pricing/pricing/factors.py +91 -0
- ins_pricing/pricing/monitoring.py +99 -0
- ins_pricing/pricing/rate_table.py +78 -0
- ins_pricing/production/__init__.py +21 -0
- ins_pricing/production/drift.py +30 -0
- ins_pricing/production/monitoring.py +143 -0
- ins_pricing/production/scoring.py +40 -0
- ins_pricing/reporting/README.md +20 -0
- ins_pricing/reporting/__init__.py +11 -0
- ins_pricing/reporting/report_builder.py +72 -0
- ins_pricing/reporting/scheduler.py +45 -0
- ins_pricing/setup.py +41 -0
- ins_pricing v2/__init__.py +23 -0
- ins_pricing v2/governance/__init__.py +20 -0
- ins_pricing v2/governance/approval.py +93 -0
- ins_pricing v2/governance/audit.py +37 -0
- ins_pricing v2/governance/registry.py +99 -0
- ins_pricing v2/governance/release.py +159 -0
- ins_pricing v2/modelling/Explain_Run.py +36 -0
- ins_pricing v2/modelling/Pricing_Run.py +36 -0
- ins_pricing v2/modelling/__init__.py +151 -0
- ins_pricing v2/modelling/cli_common.py +141 -0
- ins_pricing v2/modelling/config.py +249 -0
- ins_pricing v2/modelling/config_preprocess.py +254 -0
- ins_pricing v2/modelling/core.py +741 -0
- ins_pricing v2/modelling/data_container.py +42 -0
- ins_pricing v2/modelling/explain/__init__.py +55 -0
- ins_pricing v2/modelling/explain/gradients.py +334 -0
- ins_pricing v2/modelling/explain/metrics.py +176 -0
- ins_pricing v2/modelling/explain/permutation.py +155 -0
- ins_pricing v2/modelling/explain/shap_utils.py +146 -0
- ins_pricing v2/modelling/features.py +215 -0
- ins_pricing v2/modelling/model_manager.py +148 -0
- ins_pricing v2/modelling/model_plotting.py +463 -0
- ins_pricing v2/modelling/models.py +2203 -0
- ins_pricing v2/modelling/notebook_utils.py +294 -0
- ins_pricing v2/modelling/plotting/__init__.py +45 -0
- ins_pricing v2/modelling/plotting/common.py +63 -0
- ins_pricing v2/modelling/plotting/curves.py +572 -0
- ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
- ins_pricing v2/modelling/plotting/geo.py +362 -0
- ins_pricing v2/modelling/plotting/importance.py +121 -0
- ins_pricing v2/modelling/run_logging.py +133 -0
- ins_pricing v2/modelling/tests/conftest.py +8 -0
- ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
- ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
- ins_pricing v2/modelling/tests/test_explain.py +56 -0
- ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
- ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
- ins_pricing v2/modelling/tests/test_plotting.py +63 -0
- ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
- ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
- ins_pricing v2/modelling/trainers.py +2447 -0
- ins_pricing v2/modelling/utils.py +1020 -0
- ins_pricing v2/modelling/watchdog_run.py +211 -0
- ins_pricing v2/pricing/__init__.py +27 -0
- ins_pricing v2/pricing/calibration.py +39 -0
- ins_pricing v2/pricing/data_quality.py +117 -0
- ins_pricing v2/pricing/exposure.py +85 -0
- ins_pricing v2/pricing/factors.py +91 -0
- ins_pricing v2/pricing/monitoring.py +99 -0
- ins_pricing v2/pricing/rate_table.py +78 -0
- ins_pricing v2/production/__init__.py +21 -0
- ins_pricing v2/production/drift.py +30 -0
- ins_pricing v2/production/monitoring.py +143 -0
- ins_pricing v2/production/scoring.py +40 -0
- ins_pricing v2/reporting/__init__.py +11 -0
- ins_pricing v2/reporting/report_builder.py +72 -0
- ins_pricing v2/reporting/scheduler.py +45 -0
- ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
- ins_pricing v2/scripts/Explain_entry.py +545 -0
- ins_pricing v2/scripts/__init__.py +1 -0
- ins_pricing v2/scripts/train.py +568 -0
- ins_pricing v2/setup.py +55 -0
- ins_pricing v2/smoke_test.py +28 -0
- ins_pricing-0.1.6.dist-info/METADATA +78 -0
- ins_pricing-0.1.6.dist-info/RECORD +169 -0
- ins_pricing-0.1.6.dist-info/WHEEL +5 -0
- ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
- user_packages/__init__.py +105 -0
- user_packages legacy/BayesOpt.py +5659 -0
- user_packages legacy/BayesOpt_entry.py +513 -0
- user_packages legacy/BayesOpt_incremental.py +685 -0
- user_packages legacy/Pricing_Run.py +36 -0
- user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
- user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
- user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
- user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
- user_packages legacy/Try/BayesOpt legacy.py +3280 -0
- user_packages legacy/Try/BayesOpt.py +838 -0
- user_packages legacy/Try/BayesOptAll.py +1569 -0
- user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
- user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
- user_packages legacy/Try/BayesOptSearch.py +830 -0
- user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
- user_packages legacy/Try/BayesOptV1.py +1911 -0
- user_packages legacy/Try/BayesOptV10.py +2973 -0
- user_packages legacy/Try/BayesOptV11.py +3001 -0
- user_packages legacy/Try/BayesOptV12.py +3001 -0
- user_packages legacy/Try/BayesOptV2.py +2065 -0
- user_packages legacy/Try/BayesOptV3.py +2209 -0
- user_packages legacy/Try/BayesOptV4.py +2342 -0
- user_packages legacy/Try/BayesOptV5.py +2372 -0
- user_packages legacy/Try/BayesOptV6.py +2759 -0
- user_packages legacy/Try/BayesOptV7.py +2832 -0
- user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
- user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
- user_packages legacy/Try/BayesOptV9.py +2927 -0
- user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
- user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
- user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
- user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
- user_packages legacy/Try/xgbbayesopt.py +523 -0
- user_packages legacy/__init__.py +19 -0
- user_packages legacy/cli_common.py +124 -0
- user_packages legacy/notebook_utils.py +228 -0
- user_packages legacy/watchdog_run.py +202 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
from sklearn.model_selection import ShuffleSplit, cross_val_score # 1.2.2
|
|
2
|
+
from hyperopt import plotting, fmin, hp, tpe, Trials, STATUS_OK # 0.2.7
|
|
3
|
+
from sklearn.metrics import make_scorer, mean_tweedie_deviance # 1.2.2
|
|
4
|
+
|
|
5
|
+
import shap # 0.44.1
|
|
6
|
+
import xgboost as xgb # 1.7.0
|
|
7
|
+
import joblib
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np # 1.26.2
|
|
10
|
+
import pandas as pd # 2.2.3
|
|
11
|
+
import os
|
|
12
|
+
import re
|
|
13
|
+
|
|
14
|
+
class xgb_bayesopt:
    """Tune, fit and diagnose an XGBoost regressor with hyperopt.

    Hyper-parameters are searched with hyperopt's TPE algorithm; every
    candidate is scored by shuffle-split cross-validation on the training
    data using the mean Tweedie deviance. Besides the tuned model (``clf``)
    an untuned baseline (``clf_init``) is fitted so both can be compared in
    the lift and double-lift charts.

    The ``train_data`` and ``test_data`` frames passed in are modified in
    place: helper columns such as ``w_act``, ``pred``, ``pred_init``,
    ``w_pred``, ``w_pred_init`` and ``<factor>_bins`` are added to them.
    """

    def __init__(self, train_data, test_data,
                 model_nme, resp_nme, weight_nme,
                 factor_nmes, space_params,
                 int_p_list=None,
                 cate_list=None, prop_test=0.25, rand_seed=None):
        """Store the data, pick the objective and build both regressors.

        Args:
            train_data: Training data as a pandas DataFrame.
            test_data: Test data as a pandas DataFrame.
            model_nme: Model name. It selects the objective: a name
                containing ``'f'`` uses ``count:poisson``, otherwise one
                containing ``'s'`` uses ``reg:gamma``, otherwise one
                containing ``'bc'`` uses ``reg:tweedie``.
            resp_nme: Name of the response column.
            weight_nme: Name of the weight column.
            factor_nmes: List of factor (feature) column names.
            space_params: hyperopt search space, keyed by parameter name.
            int_p_list: Names of parameters that must be cast to ``int``.
                Defaults to ``['n_estimators', 'max_depth']``.
            cate_list: Names of the categorical columns. Defaults to none.
            prop_test: Proportion of the training data held out in each
                cross-validation split; ``int(1 / prop_test)`` splits are
                drawn.
            rand_seed: Random seed; a random one is drawn when omitted.

        Raises:
            ValueError: If ``prop_test`` is not strictly between 0 and 1, or
                if ``model_nme`` does not select any objective.
        """
        # Validate before touching the caller's frames.
        if not 0 < prop_test < 1:
            raise ValueError(
                'prop_test must lie strictly between 0 and 1, '
                f'got {prop_test!r}')
        self.model_nme = model_nme
        self.obj = self._objective_for(model_nme)

        self.train_data = train_data
        self.test_data = test_data
        self.resp_nme = resp_nme
        self.weight_nme = weight_nme
        self.factor_nmes = factor_nmes
        # Weighted actuals, used by every chart.
        for frame in (self.train_data, self.test_data):
            frame.loc[:, 'w_act'] = frame[self.resp_nme] * \
                frame[self.weight_nme]

        # None sentinels avoid sharing one mutable default between instances.
        self.cate_list = [] if cate_list is None else list(cate_list)
        self.int_p_list = (['n_estimators', 'max_depth']
                           if int_p_list is None else list(int_p_list))

        # Work on a copy so the caller's search space is left untouched.
        self.space_params = dict(space_params)
        if self.obj != 'reg:tweedie':
            # Only the Tweedie objective tunes the variance power.
            self.space_params.pop('tweedie_variance_power', None)

        self.rand_seed = (rand_seed if rand_seed is not None
                          else np.random.randint(1, 10000))

        for cate in self.cate_list:
            self.train_data[cate] = self.train_data[cate].astype('category')
            # Reuse the training categories so the category codes agree
            # between fitting and prediction; test values unseen in training
            # become missing.
            self.test_data[cate] = self.test_data[cate].astype(
                self.train_data[cate].dtype)

        self.prop_test = prop_test
        self.cv = ShuffleSplit(n_splits=int(1/self.prop_test),
                               test_size=self.prop_test,
                               random_state=self.rand_seed)

        self.clf_init = self._new_regressor()   # untuned baseline
        self.clf = self._new_regressor()        # receives the tuned params
        self.fit_params = {
            'sample_weight': self.train_data[self.weight_nme].values
        }

    @staticmethod
    def _objective_for(model_nme):
        """Return the XGBoost objective selected by ``model_nme``."""
        # Order matters: 'f' is checked first, then 's', then 'bc'.
        if 'f' in model_nme:
            return 'count:poisson'
        if 's' in model_nme:
            return 'reg:gamma'
        if 'bc' in model_nme:
            return 'reg:tweedie'
        raise ValueError(
            "model_nme must contain 'f', 's' or 'bc' to select an "
            f'objective, got {model_nme!r}')

    def _new_regressor(self):
        """Build an XGBRegressor with the settings shared by both models."""
        return xgb.XGBRegressor(objective=self.obj,
                                random_state=self.rand_seed,
                                subsample=0.9,
                                tree_method='gpu_hist',
                                gpu_id=0,
                                enable_categorical=True,
                                predictor='gpu_predictor')

    @staticmethod
    def _ensure_dir(name):
        """Create ``<cwd>/<name>`` if needed and return its path."""
        path = os.path.join(os.getcwd(), name)
        os.makedirs(path, exist_ok=True)
        return path

    def _datasets(self):
        """Return ``(subplot position, frame, title)`` for train and test."""
        return ((121, self.train_data, 'Train Data'),
                (122, self.test_data, 'Test Data'))

    def _select_model(self, mode):
        """Return the tuned or the baseline model for ``mode``."""
        if mode == 'Optimization':
            return self.clf
        if mode == 'Initial':
            return self.clf_init
        raise ValueError(
            "model_nme must be 'Optimization' or 'Initial', "
            f'got {mode!r}')

    def _bin_factor(self, frame, factor, n_bins):
        """Return the name of the column of ``frame`` to group ``factor`` by.

        Categorical factors are grouped on their own column. Any other
        factor is cut into at most ``n_bins`` quantile bins, stored in a
        ``'<factor>_bins'`` column added to ``frame``.
        """
        if factor in self.cate_list:
            return factor
        bin_col = factor + '_bins'
        # Plain column assignment replaces a bin column left by an earlier
        # call, so the charts can be re-drawn with a different n_bins.
        frame[bin_col] = pd.qcut(frame[factor], n_bins, duplicates='drop')
        return bin_col

    @staticmethod
    def _add_exposure_bars(ax, positions, weights):
        """Draw the exposure bars on a twin y-axis of ``ax``."""
        ax2 = ax.twinx()
        ax2.bar(positions, weights,
                alpha=0.5, color='seagreen',
                label='Earned Exposure')
        # These pyplot calls act on the current axes (the twin axes here).
        plt.yticks(fontsize=6)
        plt.legend(loc='upper right',
                   fontsize=5, frameon=False)

    # Cross-validation objective evaluated by hyperopt.
    def cross_val_xgb(self, params):
        """Score one hyper-parameter candidate by cross-validation.

        Args:
            params: Candidate parameters sampled from ``space_params``.

        Returns:
            The hyperopt result dict: ``loss`` is the mean Tweedie deviance
            over the splits, ``params`` the candidate with the integer
            parameters cast, ``status`` is ``STATUS_OK``.
        """
        params = dict(params)  # do not modify the caller's dict
        # Some parameters are sampled as floats but must be integers.
        for param_name in self.int_p_list:
            params[param_name] = int(params[param_name])
        self.clf.set_params(**params)
        # Deviance power matching the training objective.
        if self.obj == 'reg:tweedie':
            tw_power = params['tweedie_variance_power']
        else:
            tw_power = {'count:poisson': 1, 'reg:gamma': 2}[self.obj]
        acc = cross_val_score(self.clf,
                              self.train_data[self.factor_nmes],
                              self.train_data[self.resp_nme].values,
                              fit_params=self.fit_params,
                              cv=self.cv,
                              scoring=make_scorer(mean_tweedie_deviance,
                                                  power=tw_power,
                                                  greater_is_better=False),
                              error_score='raise',
                              n_jobs=int(1/self.prop_test)).mean()
        # The scorer negates the deviance; flip it back into a loss.
        return {'loss': -acc, 'params': params, 'status': STATUS_OK}

    # Bayesian optimisation followed by the final fits.
    def bayesopt(self, max_evals=100):
        """Search the space, fit both models and store their predictions.

        The best parameters are written to
        ``Results/<model_nme>_bestparams_xgb.csv`` under the working
        directory. Columns ``pred``, ``pred_init``, ``w_pred`` and
        ``w_pred_init`` are added to the train and test frames.

        NOTE(review): hyperopt's ``fmin`` reports ``hp.choice`` parameters
        as option indices; if the search space uses ``hp.choice`` the result
        should be translated with ``hyperopt.space_eval`` - confirm against
        the spaces in use.

        Args:
            max_evals: Number of candidates evaluated by hyperopt.
        """
        self.trials = Trials()
        self.best = fmin(self.cross_val_xgb, self.space_params,
                         algo=tpe.suggest,
                         max_evals=max_evals, trials=self.trials)
        for param_name in self.int_p_list:
            self.best[param_name] = int(self.best[param_name])
        results_dir = self._ensure_dir('Results')
        pd.DataFrame(self.best, index=[0]).to_csv(
            os.path.join(results_dir,
                         self.model_nme + '_bestparams_xgb.csv'))

        train_x = self.train_data[self.factor_nmes]
        train_y = self.train_data[self.resp_nme]
        self.clf.set_params(**self.best)
        self.clf.fit(train_x, train_y, **self.fit_params)
        self.clf_init.fit(train_x, train_y, **self.fit_params)

        for frame in (self.train_data, self.test_data):
            features = frame[self.factor_nmes]
            weights = frame[self.weight_nme]
            frame.loc[:, 'pred'] = self.clf.predict(features)
            frame.loc[:, 'pred_init'] = self.clf_init.predict(features)
            frame.loc[:, 'w_pred'] = frame['pred'] * weights
            # Reuse the stored baseline prediction instead of predicting
            # a second time.
            frame.loc[:, 'w_pred_init'] = frame['pred_init'] * weights

    # Persist a fitted model.
    def output_model(self, model_nme='Optimization'):
        """Dump the tuned or the baseline model with joblib.

        Both modes write to ``Results/<model_nme>_xgb.pkl`` under the
        working directory, so saving one mode replaces a file saved by the
        other.

        Args:
            model_nme: ``'Optimization'`` for the tuned model or
                ``'Initial'`` for the baseline.

        Raises:
            ValueError: If ``model_nme`` is neither of the two modes.
        """
        model = self._select_model(model_nme)
        joblib.dump(model, os.path.join(self._ensure_dir('Results'),
                                        self.model_nme + '_xgb.pkl'))

    def pred(self, data, model_nme='Optimization'):
        """Predict on ``data`` with the tuned or the baseline model.

        Args:
            data: DataFrame containing the ``factor_nmes`` columns.
            model_nme: ``'Optimization'`` for the tuned model or
                ``'Initial'`` for the baseline.

        Returns:
            The predictions of the selected model.

        Raises:
            ValueError: If ``model_nme`` is neither of the two modes.
        """
        return self._select_model(model_nme).predict(data[self.factor_nmes])

    # One-way charts of the actuals per factor.
    def plot_oneway(self, n_bins=10):
        """Save one actual-vs-exposure chart per factor (training data).

        Charts go to ``plot/00_<model_nme>_<column>_oneway.png`` under the
        working directory.

        Args:
            n_bins: Number of quantile bins for non-categorical factors.
        """
        plot_dir = self._ensure_dir('plot')
        for c in self.factor_nmes:
            fig = plt.figure(figsize=(7, 5))
            strs = self._bin_factor(self.train_data, c, n_bins)
            plot_data = self.train_data.groupby(
                [strs], observed=True).sum(numeric_only=True)
            plot_data.reset_index(inplace=True)
            plot_data['act_v'] = plot_data['w_act'] / \
                plot_data[self.weight_nme]
            ax = fig.add_subplot(111)
            ax.plot(plot_data.index, plot_data['act_v'],
                    label='Actual', color='red')
            ax.set_title(
                'Analysis of %s : Train Data' % strs,
                fontsize=8)
            labels = list(plot_data[strs].astype(str))
            # Shrink the tick labels when there are many levels.
            plt.xticks(plot_data.index, labels,
                       rotation=90,
                       fontsize=3 if len(labels) > 50 else 6)
            plt.yticks(fontsize=6)
            ax2 = ax.twinx()
            ax2.bar(plot_data.index,
                    plot_data[self.weight_nme],
                    alpha=0.5, color='seagreen')
            plt.yticks(fontsize=6)
            plt.margins(0.05)
            plt.subplots_adjust(wspace=0.3)
            save_path = os.path.join(
                plot_dir, f'00_{self.model_nme}_{strs}_oneway.png')
            fig.savefig(save_path, dpi=300)
            plt.close(fig)

    # Weight-balanced binning.
    def _split_data(self, data, col_nme, wgt_nme, n_bins=10):
        """Sum ``data`` over ``n_bins`` bins of roughly equal weight.

        ``data`` is sorted by ``col_nme`` in place and gains the columns
        ``cum_weight`` and ``bins``.

        Returns:
            The per-bin sums of the numeric columns, indexed by ``bins``.
        """
        data.sort_values(by=col_nme, ascending=True, inplace=True)
        data['cum_weight'] = data[wgt_nme].cumsum()
        w_sum = data[wgt_nme].sum()
        data.loc[:, 'bins'] = np.floor(
            data['cum_weight']*float(n_bins)/w_sum)
        # The last row reaches the full weight; fold it into the top bin.
        data.loc[(data['bins'] == n_bins), 'bins'] = n_bins-1
        return data.groupby(['bins'], observed=True).sum(numeric_only=True)

    # Data behind the lift chart.
    def _plot_data_lift(self,
                        pred_list, w_pred_list,
                        w_act_list, weight_list, n_bins=10):
        """Bin by prediction and return expected/actual values per bin.

        Returns:
            A frame with one row per bin holding, among the summed columns,
            ``exp_v`` (weighted prediction / weight) and ``act_v``
            (weighted actual / weight).
        """
        lift_data = pd.DataFrame({
            'pred': pred_list,
            'w_pred': w_pred_list,
            'act': w_act_list,
            'weight': weight_list,
        })
        plot_data = self._split_data(
            lift_data, 'pred', 'weight', n_bins)
        plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
        plot_data['act_v'] = plot_data['act'] / plot_data['weight']
        plot_data.reset_index(inplace=True)
        return plot_data

    # Lift chart.
    def plot_lift(self, n_bins=10):
        """Save the lift chart for train and test data.

        Requires the prediction columns created by ``bayesopt``. The chart
        goes to ``plot/01_<model_nme>_lift.png`` under the working directory.

        Args:
            n_bins: Number of weight-balanced bins.
        """
        fig = plt.figure(figsize=(11, 5))
        for figpos, frame, title in self._datasets():
            plot_data = self._plot_data_lift(
                frame['pred'].values,
                frame['w_pred'].values,
                frame['w_act'].values,
                frame[self.weight_nme].values,
                n_bins)
            ax = fig.add_subplot(figpos)
            ax.plot(plot_data.index, plot_data['act_v'],
                    label='Actual', color='red')
            ax.plot(plot_data.index, plot_data['exp_v'],
                    label='Predicted', color='blue')
            ax.set_title(
                'Lift Chart on %s' % title, fontsize=8)
            plt.xticks(plot_data.index,
                       plot_data.index,
                       rotation=90, fontsize=6)
            plt.yticks(fontsize=6)
            plt.legend(loc='upper left',
                       fontsize=5, frameon=False)
            plt.margins(0.05)
            self._add_exposure_bars(ax, plot_data.index, plot_data['weight'])
        plt.subplots_adjust(wspace=0.3)
        save_path = os.path.join(
            self._ensure_dir('plot'), f'01_{self.model_nme}_lift.png')
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

    # Data behind the double lift chart.
    def _plot_data_dlift(self,
                         pred_list_model1, pred_list_model2,
                         w_list, w_act_list, n_bins=10):
        """Bin by the ratio of two predictions and compare both to actuals.

        Returns:
            A frame with one row per bin holding, among the summed columns,
            ``exp_v1`` and ``exp_v2`` (each model's summed prediction
            divided by the summed actuals) and ``act_v`` (actuals divided
            by themselves).
        """
        lift_data = pd.DataFrame({
            'pred1': pred_list_model1,
            'pred2': pred_list_model2,
        })
        lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
        lift_data.loc[:, 'act'] = w_act_list
        lift_data.loc[:, 'weight'] = w_list
        plot_data = self._split_data(lift_data, 'diff_ly', 'weight', n_bins)
        plot_data['exp_v1'] = plot_data['pred1'] / plot_data['act']
        plot_data['exp_v2'] = plot_data['pred2'] / plot_data['act']
        plot_data['act_v'] = plot_data['act'] / plot_data['act']
        plot_data.reset_index(inplace=True)
        return plot_data

    # Double lift chart.
    def plot_dlift(self, n_bins=10):
        """Save the double lift chart (tuned vs. baseline model).

        Requires the prediction columns created by ``bayesopt``. The chart
        goes to ``plot/02_<model_nme>_dlift.png`` under the working
        directory.

        Args:
            n_bins: Number of weight-balanced bins.
        """
        tt1 = 'Modified Model'
        tt2 = 'Initial Model'
        fig = plt.figure(figsize=(11, 5))
        for figpos, frame, title in self._datasets():
            plot_data = self._plot_data_dlift(
                frame['w_pred'].values,
                frame['w_pred_init'].values,
                frame[self.weight_nme].values,
                frame['w_act'].values,
                n_bins)
            ax = fig.add_subplot(figpos)
            ax.plot(plot_data.index, plot_data['act_v'],
                    label='Actual', color='red')
            ax.plot(plot_data.index, plot_data['exp_v1'],
                    label=tt1, color='blue')
            ax.plot(plot_data.index, plot_data['exp_v2'],
                    label=tt2, color='black')
            ax.set_title(
                'Double Lift Chart on %s' % title, fontsize=8)
            plt.xticks(plot_data.index,
                       plot_data.index,
                       rotation=90, fontsize=6)
            plt.xlabel('%s / %s' % (tt1, tt2), fontsize=6)
            plt.yticks(fontsize=6)
            plt.legend(loc='upper left',
                       fontsize=5, frameon=False)
            plt.margins(0.1)
            plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)
            self._add_exposure_bars(ax, plot_data.index, plot_data['weight'])
        plt.subplots_adjust(wspace=0.3)
        save_path = os.path.join(
            self._ensure_dir('plot'), f'02_{self.model_nme}_dlift.png')
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

    # Actual vs. predicted per factor.
    def plot_sim(self, n_bins=10):
        """Save one actual-vs-predicted chart per factor (train and test).

        Requires the prediction columns created by ``bayesopt``. Charts go
        to ``plot/03_<model_nme>_<column>_sim.png`` under the working
        directory.

        Args:
            n_bins: Number of quantile bins for non-categorical factors.
        """
        plot_dir = self._ensure_dir('plot')
        for c in self.factor_nmes:
            fig = plt.figure(figsize=(11, 5))
            for figpos, frame, title in self._datasets():
                strs = self._bin_factor(frame, c, n_bins)
                plot_data = frame.groupby(
                    [strs], observed=True).sum(numeric_only=True)
                plot_data.reset_index(inplace=True)
                plot_data['exp_v'] = plot_data['w_pred'] / \
                    plot_data[self.weight_nme]
                plot_data['act_v'] = plot_data['w_act'] / \
                    plot_data[self.weight_nme]
                ax = fig.add_subplot(figpos)
                ax.plot(plot_data.index, plot_data['act_v'],
                        label='Actual', color='red')
                ax.plot(plot_data.index, plot_data['exp_v'],
                        label='Predicted', color='blue')
                ax.set_title(
                    'Analysis of %s : %s' % (strs, title),
                    fontsize=8)
                plt.xticks(plot_data.index,
                           list(plot_data[strs].astype(str)),
                           rotation=90, fontsize=4)
                plt.legend(loc='upper left',
                           fontsize=5, frameon=False)
                plt.margins(0.05)
                plt.yticks(fontsize=6)
                self._add_exposure_bars(
                    ax, plot_data.index, plot_data[self.weight_nme])
            plt.subplots_adjust(wspace=0.3)
            save_path = os.path.join(
                plot_dir, f'03_{self.model_nme}_{strs}_sim.png')
            fig.savefig(save_path, dpi=300)
            plt.close(fig)

    # SHAP summary plots.
    def plot_shap(self, n_bins=10):
        """Save a SHAP bar summary of the tuned model for train and test.

        The plots go to ``plot/04_<model_nme>_shap_train.png`` and
        ``plot/04_<model_nme>_shap_test.png`` under the working directory,
        so the second plot no longer replaces the first.

        Args:
            n_bins: Unused; kept so the signature matches the other plots.
        """
        plot_dir = self._ensure_dir('plot')
        explainer = shap.TreeExplainer(self.clf)
        tags = {121: 'train', 122: 'test'}
        for figpos, frame, title in self._datasets():
            features = frame[self.factor_nmes]
            shap_values = explainer.shap_values(features)
            # show=False keeps the figure open so it can still be titled
            # and saved below.
            shap.summary_plot(shap_values, features,
                              plot_type='bar', max_display=10,
                              show=False)
            plt.title('SHAP Summary Plot on %s' % title)
            save_path = os.path.join(
                plot_dir, f'04_{self.model_nme}_shap_{tags[figpos]}.png')
            plt.savefig(save_path, dpi=300)
            plt.close()
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
# 定义外部函数
|
|
421
|
+
# 定义分箱函数
|
|
422
|
+
|
|
423
|
+
def split_data(data, col_nme, wgt_nme, n_bins=10):
    """Sum ``data`` over ``n_bins`` bins of roughly equal total weight.

    Rows are ordered by ``col_nme`` and assigned to bins according to their
    share of the running weight total. ``data`` is modified in place: it is
    sorted ascending by ``col_nme`` and gains the columns ``cum_weight`` and
    ``bins``.

    Args:
        data: DataFrame to bin.
        col_nme: Column that defines the row order.
        wgt_nme: Column holding the row weights.
        n_bins: Number of bins.

    Returns:
        The per-bin sums of the numeric columns, indexed by ``bins``.
    """
    data.sort_values(by=col_nme, ascending=True, inplace=True)
    running_weight = data[wgt_nme].cumsum()
    total_weight = data[wgt_nme].sum()
    data['cum_weight'] = running_weight
    raw_bin = np.floor(running_weight * float(n_bins) / total_weight)
    # A row that reaches the full weight lands on n_bins; move it into the
    # last valid bin.
    data['bins'] = raw_bin.where(raw_bin != n_bins, n_bins - 1)
    grouped = data.groupby(['bins'], observed=True)
    return grouped.sum(numeric_only=True)
|
|
430
|
+
|
|
431
|
+
# 定义Lift Chart绘制函数
|
|
432
|
+
|
|
433
|
+
def plot_lift_list(pred_model, w_pred_list, w_act_list,
                   weight_list, tgt_nme, n_bins=10,
                   fig_nme='Lift Chart'):
    """Draw a lift chart from raw prediction lists and save it as a PNG.

    Rows are ordered by prediction and grouped into ``n_bins`` bins of
    roughly equal weight; per bin the weighted prediction and the weighted
    actual are divided by the weight and plotted over the exposure bars.
    The chart is written to ``plot/05_<tgt_nme>_<fig_nme>.png`` under the
    working directory; the ``plot`` directory is created when missing.

    Args:
        pred_model: Predictions used to order the rows.
        w_pred_list: Weighted predictions.
        w_act_list: Weighted actuals.
        weight_list: Row weights.
        tgt_nme: Target name, used in the title and the file name.
        n_bins: Number of bins.
        fig_nme: Figure name, used in the file name.
    """
    lift_data = pd.DataFrame()
    lift_data.loc[:, 'pred'] = pred_model
    lift_data.loc[:, 'w_pred'] = w_pred_list
    lift_data.loc[:, 'act'] = w_act_list
    lift_data.loc[:, 'weight'] = weight_list
    plot_data = split_data(lift_data, 'pred', 'weight', n_bins)
    plot_data['exp_v'] = plot_data['w_pred'] / plot_data['weight']
    plot_data['act_v'] = plot_data['act'] / plot_data['weight']
    plot_data.reset_index(inplace=True)

    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(111)
    ax.plot(plot_data.index, plot_data['act_v'],
            label='Actual', color='red')
    ax.plot(plot_data.index, plot_data['exp_v'],
            label='Predicted', color='blue')
    ax.set_title(
        'Lift Chart of %s' % tgt_nme, fontsize=8)
    plt.xticks(plot_data.index,
               plot_data.index,
               rotation=90, fontsize=6)
    plt.yticks(fontsize=6)
    plt.legend(loc='upper left',
               fontsize=5, frameon=False)
    plt.margins(0.05)

    # Exposure bars on a twin axis; the pyplot calls below act on the
    # current axes (the twin axes here).
    ax2 = ax.twinx()
    ax2.bar(plot_data.index, plot_data['weight'],
            alpha=0.5, color='seagreen',
            label='Earned Exposure')
    plt.yticks(fontsize=6)
    plt.legend(loc='upper right',
               fontsize=5, frameon=False)
    plt.subplots_adjust(wspace=0.3)

    # Create the output directory so savefig does not fail on a fresh
    # working directory.
    plot_dir = os.path.join(os.getcwd(), 'plot')
    os.makedirs(plot_dir, exist_ok=True)
    save_path = os.path.join(plot_dir, f'05_{tgt_nme}_{fig_nme}.png')
    plt.savefig(save_path, dpi=300)
    plt.close(fig)
|
|
472
|
+
|
|
473
|
+
# 定义Double Lift Chart绘制函数
|
|
474
|
+
|
|
475
|
+
def plot_dlift_list(pred_model_1, pred_model_2,
                    model_nme_1, model_nme_2,
                    tgt_nme,
                    w_list, w_act_list, n_bins=10,
                    fig_nme='Double Lift Chart'):
    """Draw a double lift chart for two models and save it as a PNG.

    Rows are ordered by the ratio ``pred_model_1 / pred_model_2`` and
    grouped into ``n_bins`` bins of roughly equal weight; per bin each
    model's weighted prediction is divided by the actuals and plotted over
    the exposure bars. The chart is written to
    ``plot/06_<tgt_nme>_<fig_nme>.png`` under the working directory; the
    ``plot`` directory is created when missing.

    Args:
        pred_model_1: Predictions of the first model.
        pred_model_2: Predictions of the second model.
        model_nme_1: Legend label of the first model.
        model_nme_2: Legend label of the second model.
        tgt_nme: Target name, used in the title and the file name.
        w_list: Row weights.
        w_act_list: Weighted actuals.
        n_bins: Number of bins.
        fig_nme: Figure name, used in the file name.
    """
    lift_data = pd.DataFrame()
    lift_data.loc[:, 'pred1'] = pred_model_1
    lift_data.loc[:, 'pred2'] = pred_model_2
    lift_data.loc[:, 'diff_ly'] = lift_data['pred1'] / lift_data['pred2']
    lift_data.loc[:, 'act'] = w_act_list
    lift_data.loc[:, 'weight'] = w_list
    lift_data.loc[:, 'w_pred1'] = lift_data['pred1'] * lift_data['weight']
    lift_data.loc[:, 'w_pred2'] = lift_data['pred2'] * lift_data['weight']
    plot_data = split_data(lift_data, 'diff_ly', 'weight', n_bins)
    plot_data['exp_v1'] = plot_data['w_pred1'] / plot_data['act']
    plot_data['exp_v2'] = plot_data['w_pred2'] / plot_data['act']
    # The actuals are the reference line: actuals divided by themselves.
    plot_data['act_v'] = plot_data['act']/plot_data['act']
    plot_data.reset_index(inplace=True)

    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(111)
    ax.plot(plot_data.index, plot_data['act_v'],
            label='Actual', color='red')
    ax.plot(plot_data.index, plot_data['exp_v1'],
            label=model_nme_1, color='blue')
    ax.plot(plot_data.index, plot_data['exp_v2'],
            label=model_nme_2, color='black')
    ax.set_title(
        'Double Lift Chart of %s' % tgt_nme, fontsize=8)
    plt.xticks(plot_data.index,
               plot_data.index,
               rotation=90, fontsize=6)
    plt.xlabel('%s / %s' % (model_nme_1, model_nme_2), fontsize=6)
    plt.yticks(fontsize=6)
    plt.legend(loc='upper left',
               fontsize=5, frameon=False)
    plt.margins(0.1)
    plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8)

    # Exposure bars on a twin axis; the pyplot calls below act on the
    # current axes (the twin axes here).
    ax2 = ax.twinx()
    ax2.bar(plot_data.index, plot_data['weight'],
            alpha=0.5, color='seagreen',
            label='Earned Exposure')
    plt.yticks(fontsize=6)
    plt.legend(loc='upper right',
               fontsize=5, frameon=False)
    plt.subplots_adjust(wspace=0.3)

    # Create the output directory so savefig does not fail on a fresh
    # working directory.
    plot_dir = os.path.join(os.getcwd(), 'plot')
    os.makedirs(plot_dir, exist_ok=True)
    save_path = os.path.join(plot_dir, f'06_{tgt_nme}_{fig_nme}.png')
    plt.savefig(save_path, dpi=300)
    plt.close(fig)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# 使 user_packages 成为可导入的 Python 包,便于在 notebook/脚本中统一引用。
|
|
4
|
+
|
|
5
|
+
from .BayesOpt import ( # noqa: F401
|
|
6
|
+
BayesOptConfig,
|
|
7
|
+
BayesOptModel,
|
|
8
|
+
IOUtils,
|
|
9
|
+
TrainingUtils,
|
|
10
|
+
free_cuda,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"BayesOptConfig",
|
|
15
|
+
"BayesOptModel",
|
|
16
|
+
"IOUtils",
|
|
17
|
+
"TrainingUtils",
|
|
18
|
+
"free_cuda",
|
|
19
|
+
]
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Per-model plotting metadata, keyed by the short model key.
# NOTE(review): each value is presumably (display label, prediction column
# name) -- confirm against the plotting callers.
PLOT_MODEL_LABELS: Dict[str, Tuple[str, str]] = {
    "glm": ("GLM", "pred_glm"),
    "xgb": ("Xgboost", "pred_xgb"),
    "resn": ("ResNet", "pred_resn"),
    "ft": ("FTTransformer", "pred_ft"),
    "gnn": ("GNN", "pred_gnn"),
}

# Model keys that the name marks as PyTorch-based trainers.
PYTORCH_TRAINERS = {
    "resn",
    "ft",
    "gnn",
}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def dedupe_preserve_order(items: Iterable[str]) -> List[str]:
    """Return the distinct elements of *items*, keeping first-seen order.

    Args:
        items: Hashable values to de-duplicate.

    Returns:
        A new list holding each value once, at the position of its first
        occurrence.
    """
    # dict keys are unique and remember insertion order, so the key sequence
    # is exactly the first-occurrence ordering.
    return list(dict.fromkeys(items))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def build_model_names(prefixes: Sequence[str], suffixes: Sequence[str]) -> List[str]:
    """Combine every prefix with every suffix as ``"<prefix>_<suffix>"``.

    Args:
        prefixes: Leading name parts.
        suffixes: Trailing name parts.

    Returns:
        The combined names, grouped by suffix: all prefixes for the first
        suffix, then all prefixes for the second suffix, and so on.
    """
    return [
        f"{prefix}_{suffix}"
        for suffix in suffixes
        for prefix in prefixes
    ]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def parse_model_pairs(raw_pairs: Optional[List]) -> List[Tuple[str, str]]:
    """Normalize a loosely-typed list of model pairs into ``(str, str)`` tuples.

    Each entry may be either:

    * a two-element list/tuple -- both elements are converted with ``str``;
    * a string of the form ``"left,right"`` -- split on commas, with
      surrounding whitespace stripped and empty pieces discarded.

    Entries of any other shape (wrong length, wrong type, a string that does
    not yield exactly two parts) are skipped silently.

    Args:
        raw_pairs: The raw entries, or ``None`` (for example an unset or
            ``null`` config value), which is treated as "no pairs".

    Returns:
        The well-formed pairs in their original order.
    """
    # Tolerate a missing value instead of failing with TypeError on iteration.
    if raw_pairs is None:
        return []

    pairs: List[Tuple[str, str]] = []
    for pair in raw_pairs:
        if isinstance(pair, (list, tuple)):
            if len(pair) == 2:
                pairs.append((str(pair[0]), str(pair[1])))
        elif isinstance(pair, str):
            parts = [piece.strip() for piece in pair.split(",") if piece.strip()]
            if len(parts) == 2:
                pairs.append((parts[0], parts[1]))
    return pairs
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def resolve_path(value: Optional[str], base_dir: Path) -> Optional[Path]:
    """Turn a path string into a ``Path``, anchoring relative paths at *base_dir*.

    Args:
        value: The path text. ``None``, non-string values and blank strings
            are all treated as "no path".
        base_dir: Directory that relative paths are joined onto.

    Returns:
        ``None`` when *value* is unusable; the path unchanged when it is
        already absolute; otherwise ``(base_dir / value).resolve()``.
    """
    has_text = isinstance(value, str) and bool(value.strip())
    if not has_text:
        return None

    candidate = Path(value)
    # Absolute paths are returned as given (not resolved); only relative ones
    # are joined to base_dir and resolved.
    return candidate if candidate.is_absolute() else (base_dir / candidate).resolve()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def resolve_config_path(raw: str, script_dir: Path) -> Path:
    """Locate a config file given as *raw* and return its resolved path.

    The path is tried as given first (an absolute path, or one relative to
    the current working directory) and then relative to *script_dir*.

    Args:
        raw: Path text supplied by the caller.
        script_dir: Fallback directory to look in.

    Returns:
        The resolved path of the first candidate that exists.

    Raises:
        FileNotFoundError: If neither candidate exists; the message lists
            both locations that were tried.
    """
    as_given = Path(raw)
    candidate2 = script_dir / raw

    for option in (as_given, candidate2):
        if option.exists():
            return option.resolve()

    raise FileNotFoundError(
        f"Config file not found: {raw}. Tried: {Path(raw).resolve()} and {candidate2.resolve()}"
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def load_config_json(path: Path, required_keys: Sequence[str]) -> Dict[str, Any]:
    """Read a UTF-8 JSON config file and check that required keys are present.

    Args:
        path: Location of the JSON file.
        required_keys: Top-level keys that must exist in the config.

    Returns:
        The parsed config mapping.

    Raises:
        ValueError: If the JSON root is not an object, or if any of
            *required_keys* is missing.
    """
    cfg = json.loads(path.read_text(encoding="utf-8"))

    # A list or scalar root would otherwise either be returned as if it were
    # the config mapping or fail with an opaque TypeError in the key check.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config root in {path} must be a JSON object, got {type(cfg).__name__}"
        )

    missing = [key for key in required_keys if key not in cfg]
    if missing:
        raise ValueError(f"Missing required keys in {path}: {missing}")
    return cfg
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def set_env(env_overrides: Dict[str, Any]) -> None:
    """Populate ``os.environ`` from a mapping without clobbering existing values.

    Keys and values are converted with ``str``. Because ``setdefault`` is
    used, a variable that is already set keeps its current value.

    Args:
        env_overrides: Variable names mapped to values; a falsy value
            (``None`` or an empty mapping) is a no-op.
    """
    if not env_overrides:
        return
    for name, raw_value in env_overrides.items():
        os.environ.setdefault(str(name), str(raw_value))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _looks_like_url(value: str) -> bool:
    """Return True when the text form of *value* contains ``"://"``.

    This is only a substring check for a scheme separator; the value is not
    parsed or validated as a URL.
    """
    return str(value).find("://") != -1
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def normalize_config_paths(cfg: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
    """Resolve relative paths in the config against the config file's directory.

    Fields currently handled:
      - data_dir / output_dir / optuna_storage / gnn_graph_cache
      - best_params_files (dict: model_key -> path)

    Args:
        cfg: The loaded config mapping; it is shallow-copied, not mutated.
        config_path: Path of the config file; its parent directory is the
            anchor for relative paths.

    Returns:
        A copy of *cfg* with the handled path fields rewritten as strings.
    """
    anchor = config_path.parent
    normalized = dict(cfg)

    # Plain directory/file fields: rewritten only when they hold a string
    # that resolve_path accepts.
    for field in ("data_dir", "output_dir", "gnn_graph_cache"):
        current = normalized.get(field)
        if not isinstance(current, str):
            continue
        located = resolve_path(current, anchor)
        if located is not None:
            normalized[field] = str(located)

    # optuna_storage is left untouched when it contains "://"; otherwise it
    # is resolved like the path fields above.
    storage = normalized.get("optuna_storage")
    if isinstance(storage, str) and storage.strip() and not _looks_like_url(storage):
        located = resolve_path(storage, anchor)
        if located is not None:
            normalized["optuna_storage"] = str(located)

    # best_params_files: rebuild the mapping with string keys; entries whose
    # value is not a string are dropped.
    per_model = normalized.get("best_params_files")
    if isinstance(per_model, dict):
        rewritten: Dict[str, str] = {}
        for model_key, entry in per_model.items():
            if isinstance(entry, str):
                located = resolve_path(entry, anchor)
                rewritten[str(model_key)] = str(entry) if located is None else str(located)
        normalized["best_params_files"] = rewritten

    return normalized
|
|
124
|
+
|