staran 0.6.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- staran/__init__.py +10 -0
- staran/tools/__init__.py +5 -5
- staran-1.0.0.dist-info/METADATA +301 -0
- staran-1.0.0.dist-info/RECORD +8 -0
- staran/banks/__init__.py +0 -30
- staran/banks/xinjiang_icbc/__init__.py +0 -90
- staran/engines/__init__.py +0 -65
- staran/engines/base.py +0 -255
- staran/engines/hive.py +0 -163
- staran/engines/spark.py +0 -252
- staran/engines/turing.py +0 -439
- staran/features/__init__.py +0 -59
- staran/features/engines.py +0 -284
- staran/features/generator.py +0 -603
- staran/features/manager.py +0 -155
- staran/features/schema.py +0 -193
- staran/models/__init__.py +0 -72
- staran/models/config.py +0 -271
- staran/models/daifa_models.py +0 -361
- staran/models/registry.py +0 -281
- staran/models/target.py +0 -321
- staran/schemas/__init__.py +0 -27
- staran/schemas/aum/__init__.py +0 -210
- staran/tools/document_generator.py +0 -350
- staran-0.6.1.dist-info/METADATA +0 -586
- staran-0.6.1.dist-info/RECORD +0 -28
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/WHEEL +0 -0
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/top_level.txt +0 -0
staran/models/daifa_models.py
DELETED
@@ -1,361 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
新疆工行代发长尾客户专用模型定义
|
3
|
-
|
4
|
-
包含两个核心模型:
|
5
|
-
1. 代发长尾客户提升3k预测模型
|
6
|
-
2. 代发长尾客户防流失1.5k预测模型
|
7
|
-
|
8
|
-
基于新疆工行代发长尾客户数据库和业务规则
|
9
|
-
"""
|
10
|
-
|
11
|
-
from typing import Dict, List
|
12
|
-
from .config import create_model_config
|
13
|
-
from .target import create_target_definition
|
14
|
-
from .registry import ModelRegistry, register_model
|
15
|
-
import os
|
16
|
-
import json
|
17
|
-
from datetime import datetime
|
18
|
-
|
19
|
-
|
20
|
-
def convert_to_serializable(obj):
    """Recursively convert *obj* into JSON-serialisable primitives.

    Rules, checked in this order:

    * ``datetime`` -> ISO-8601 string.
    * ``Enum`` members -> their (converted) ``.value``. This MUST come before
      the generic ``__dict__`` branch: enum members carry an instance
      ``__dict__`` whose ``__objclass__`` points at the enum class, and walking
      that class's ``__dict__`` leads back to the members -> endless recursion.
    * lists / tuples -> lists; dicts -> dicts; elements converted recursively.
    * other objects exposing ``__dict__`` (e.g. config dataclasses) -> dict of
      their converted attributes.
    * objects without ``__dict__`` but with a ``value`` attribute -> ``.value``
      (duck-typed enum-like objects, kept for compatibility).
    * anything else is returned unchanged.
    """
    from enum import Enum

    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, Enum):
        return convert_to_serializable(obj.value)
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if hasattr(obj, '__dict__'):
        return {key: convert_to_serializable(value) for key, value in vars(obj).items()}
    if hasattr(obj, 'value'):
        return obj.value
    return obj


def save_model_registry(output_path: str):
    """Write every model held by ``ModelRegistry`` to *output_path* as JSON.

    The file contains three keys: ``models`` (model id -> serialised config,
    target definition, registration time, status and metrics),
    ``version_history`` and ``saved_at`` (ISO-8601, matching the format used by
    ``ModelRegistry.save_to_file``). The parent directory is created if needed.

    Args:
        output_path: Destination JSON file path.

    Returns:
        *output_path*, unchanged.
    """
    data = {
        "models": {
            model_id: {
                "model_config": convert_to_serializable(entry.model_config),
                "target_definition": convert_to_serializable(entry.target_definition),
                "registered_at": entry.registered_at.isoformat(),
                "status": entry.status,
                "performance_metrics": convert_to_serializable(entry.performance_metrics),
            }
            for model_id, entry in ModelRegistry._models.items()
        },
        # Copy the per-name lists too so later registrations cannot alter the snapshot.
        "version_history": {
            name: list(versions)
            for name, versions in ModelRegistry._version_history.items()
        },
        "saved_at": datetime.now().isoformat(),
    }

    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # a bare filename has no directory component to create
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    print(f"✅ 模型注册信息已保存到: {output_path}")
    return output_path
|
64
|
-
|
65
|
-
|
66
|
-
def create_daifa_longtail_upgrade_model() -> Dict:
    """Build the payroll ("daifa") long-tail customer 3k asset-upgrade model.

    The binary target ``upgrade_3k_target`` is 1 for a long-tail customer
    (baseline total assets within the configured range) when:

    * next-month total assets exceed baseline assets by at least
      ``upgrade_target_amount``,
    * next-month deposits exceed ``deposit_contribution_ratio`` of that growth, and
    * the baseline ``upgrade_potential`` score is at least ``min_upgrade_potential``.

    Every numeric threshold in the label SQL is rendered from
    ``business_rules``. The rules stored on the target definition therefore
    always match the query that produces the labels.

    The SQL keeps the literal placeholders ``{baseline_date}`` and
    ``{next_month_date}``; presumably the caller substitutes them later.

    Returns:
        Dict with ``model_config``, ``target_definition`` and ``model_type``
        (always ``"upgrade_prediction"``).
    """
    business_rules = {
        "min_asset_threshold": 10000,       # long-tail lower asset bound
        "max_asset_threshold": 100000,      # long-tail upper asset bound
        "upgrade_target_amount": 3000,      # required asset increase
        "deposit_contribution_ratio": 0.7,  # share of the increase that must come from deposits
        "min_upgrade_potential": 0.6,       # minimum baseline upgrade-potential score
    }

    model_config = create_model_config(
        name="xinjiang_icbc_daifa_longtail_upgrade_3k",
        model_type="classification",
        algorithm="gradient_boosting",
        version="1.0.0",
        schema_name="daifa_longtail",
        table_types=["daifa_longtail_behavior", "daifa_longtail_asset_avg",
                     "daifa_longtail_asset_config", "daifa_longtail_monthly_stat"],
        hyperparameters={
            "n_estimators": 300,
            "learning_rate": 0.05,
            "max_depth": 12,
            "min_samples_split": 20,
            "min_samples_leaf": 10,
            "subsample": 0.8,
            "random_state": 42
        },
        bank_code="xinjiang_icbc",
        business_domain="代发长尾客户",
        description="新疆工行代发长尾客户下个月资产提升3k预测模型",
        tags=["daifa", "longtail", "upgrade", "3k", "xinjiang_icbc"]
    )

    min_asset = business_rules["min_asset_threshold"]
    max_asset = business_rules["max_asset_threshold"]
    upgrade_amount = business_rules["upgrade_target_amount"]
    deposit_ratio = business_rules["deposit_contribution_ratio"]
    min_potential = business_rules["min_upgrade_potential"]

    # Doubled braces survive the f-string as the literal placeholders
    # `{baseline_date}` / `{next_month_date}`.
    sql_query = f"""
WITH customer_baseline AS (
    -- 获取代发长尾客户基础信息(当月)
    SELECT
        b.party_id,
        b.asset_total_bal as current_asset,
        b.salary_amount as current_salary,
        b.longtail_score,
        b.upgrade_potential,
        CASE
            WHEN b.asset_total_bal BETWEEN {min_asset} AND {max_asset} THEN 1
            ELSE 0
        END as is_daifa_longtail
    FROM xinjiang_icbc_daifa_hlwj_monthly_stat_wy b
    WHERE b.data_dt = '{{baseline_date}}'
),

next_month_performance AS (
    -- 计算下个月的资产变化
    SELECT
        party_id,
        asset_total_bal as next_month_asset,
        salary_amount as next_month_salary,
        monthly_deposit_amount,
        monthly_withdraw_amount
    FROM xinjiang_icbc_daifa_hlwj_monthly_stat_wy
    WHERE data_dt = '{{next_month_date}}'
),

asset_change AS (
    -- 计算资产变化情况
    SELECT
        cb.party_id,
        cb.current_asset,
        nmp.next_month_asset,
        (nmp.next_month_asset - cb.current_asset) as asset_change,
        nmp.monthly_deposit_amount,
        cb.upgrade_potential
    FROM customer_baseline cb
    INNER JOIN next_month_performance nmp ON cb.party_id = nmp.party_id
    WHERE cb.is_daifa_longtail = 1  -- 只关注代发长尾客户
)

SELECT
    party_id,
    CASE
        -- 代发长尾客户资产提升3k的判断标准
        WHEN asset_change >= {upgrade_amount}  -- 资产增长达到3000元
            AND monthly_deposit_amount > asset_change * {deposit_ratio}  -- 主要通过存入实现
            AND upgrade_potential >= {min_potential}  -- 提升潜力评分较高
        THEN 1
        ELSE 0
    END as upgrade_3k_target,

    -- 辅助分析字段
    current_asset,
    next_month_asset,
    asset_change,
    monthly_deposit_amount,
    upgrade_potential

FROM asset_change
"""

    target_definition = create_target_definition(
        name="daifa_longtail_upgrade_3k_target",
        target_type="binary_classification",
        description="新疆工行代发长尾客户下个月资产提升3000元预测目标",
        sql_query=sql_query,
        target_column="upgrade_3k_target",
        class_labels=["no_upgrade", "upgrade_3k"],
        # The positive (upgrade) class is weighted higher.
        class_weights={"no_upgrade": 1.0, "upgrade_3k": 2.5},
        time_window="1_month",
        prediction_horizon="1_month",
        bank_code="xinjiang_icbc",
        business_rules=business_rules
    )

    return {
        "model_config": model_config,
        "target_definition": target_definition,
        "model_type": "upgrade_prediction"
    }
|
181
|
-
|
182
|
-
|
183
|
-
def create_daifa_longtail_churn_model() -> Dict:
    """Build the payroll ("daifa") long-tail customer 1.5k churn-prevention model.

    The binary target ``churn_1_5k_target`` is 1 for a long-tail customer
    (baseline total assets within the configured range) when:

    * total assets fall by at least ``churn_threshold_amount`` next month,
    * next-month withdrawals also reach ``churn_threshold_amount``, and
    * either the baseline ``churn_risk`` is at least ``min_churn_risk``, or the
      customer had login activity in the baseline month and next-month login
      days dropped to ``activity_decline_ratio`` of it or less.

    The activity branch requires ``login_days > 0``. With zero baseline logins
    the bare comparison ``0 <= 0 * ratio`` is always true, which would label
    permanently inactive customers as a "sharp activity decline".

    Every numeric threshold in the label SQL is rendered from
    ``business_rules``, so the stored rules match the label query. The SQL
    keeps the literal placeholders ``{baseline_date}`` and
    ``{next_month_date}``; presumably the caller substitutes them later.

    Returns:
        Dict with ``model_config``, ``target_definition`` and ``model_type``
        (always ``"churn_prevention"``).
    """
    business_rules = {
        "min_asset_threshold": 10000,   # long-tail lower asset bound
        "max_asset_threshold": 100000,  # long-tail upper asset bound
        "churn_threshold_amount": 1500,  # asset decrease / withdrawal amount that counts as churn
        "min_churn_risk": 0.7,          # minimum baseline churn-risk score
        "activity_decline_ratio": 0.5   # next-month logins at or below this share of baseline
    }

    model_config = create_model_config(
        name="xinjiang_icbc_daifa_longtail_churn_1_5k",
        model_type="classification",
        algorithm="random_forest",  # the churn model uses a random forest
        version="1.0.0",
        schema_name="daifa_longtail",
        table_types=["daifa_longtail_behavior", "daifa_longtail_asset_avg",
                     "daifa_longtail_asset_config", "daifa_longtail_monthly_stat"],
        hyperparameters={
            "n_estimators": 200,
            "max_depth": 10,
            "min_samples_split": 15,
            "min_samples_leaf": 8,
            "max_features": "sqrt",
            "random_state": 42,
            "class_weight": "balanced"  # compensate for class imbalance
        },
        bank_code="xinjiang_icbc",
        business_domain="代发长尾客户",
        description="新疆工行代发长尾客户下个月流失1.5k资产风险预测模型",
        tags=["daifa", "longtail", "churn", "1_5k", "risk_prevention"]
    )

    min_asset = business_rules["min_asset_threshold"]
    max_asset = business_rules["max_asset_threshold"]
    churn_amount = business_rules["churn_threshold_amount"]
    min_risk = business_rules["min_churn_risk"]
    activity_ratio = business_rules["activity_decline_ratio"]

    # Doubled braces survive the f-string as the literal placeholders
    # `{baseline_date}` / `{next_month_date}`.
    sql_query = f"""
WITH customer_baseline AS (
    -- 获取代发长尾客户基础信息(当月)
    SELECT
        b.party_id,
        b.asset_total_bal as current_asset,
        b.salary_amount as current_salary,
        b.longtail_score,
        b.churn_risk,
        b.login_days,
        CASE
            WHEN b.asset_total_bal BETWEEN {min_asset} AND {max_asset} THEN 1
            ELSE 0
        END as is_daifa_longtail
    FROM xinjiang_icbc_daifa_hlwj_monthly_stat_wy b
    WHERE b.data_dt = '{{baseline_date}}'
),

next_month_performance AS (
    -- 计算下个月的资产变化和行为
    SELECT
        party_id,
        asset_total_bal as next_month_asset,
        monthly_withdraw_amount,
        login_days as next_month_login_days
    FROM xinjiang_icbc_daifa_hlwj_monthly_stat_wy
    WHERE data_dt = '{{next_month_date}}'
),

churn_analysis AS (
    -- 分析流失风险情况
    SELECT
        cb.party_id,
        cb.current_asset,
        nmp.next_month_asset,
        (cb.current_asset - nmp.next_month_asset) as asset_decrease,
        nmp.monthly_withdraw_amount,
        cb.churn_risk,
        cb.login_days,
        nmp.next_month_login_days
    FROM customer_baseline cb
    INNER JOIN next_month_performance nmp ON cb.party_id = nmp.party_id
    WHERE cb.is_daifa_longtail = 1  -- 只关注代发长尾客户
)

SELECT
    party_id,
    CASE
        -- 代发长尾客户流失1.5k的判断标准
        WHEN asset_decrease >= {churn_amount}  -- 资产减少达到1500元
            AND monthly_withdraw_amount >= {churn_amount}  -- 主要通过取出导致
            AND (
                churn_risk >= {min_risk}  -- 流失风险评分高
                OR (login_days > 0 AND next_month_login_days <= login_days * {activity_ratio})  -- 活跃度大幅下降
            )
        THEN 1
        ELSE 0
    END as churn_1_5k_target,

    -- 辅助分析字段
    current_asset,
    next_month_asset,
    asset_decrease,
    monthly_withdraw_amount,
    churn_risk,
    login_days,
    next_month_login_days

FROM churn_analysis
"""

    target_definition = create_target_definition(
        name="daifa_longtail_churn_1_5k_target",
        target_type="binary_classification",
        description="新疆工行代发长尾客户下个月流失1500元资产风险预测目标",
        sql_query=sql_query,
        target_column="churn_1_5k_target",
        class_labels=["no_churn", "churn_1_5k"],
        # The positive (churn) class is weighted higher.
        class_weights={"no_churn": 1.0, "churn_1_5k": 3.0},
        time_window="1_month",
        prediction_horizon="1_month",
        bank_code="xinjiang_icbc",
        business_rules=business_rules
    )

    return {
        "model_config": model_config,
        "target_definition": target_definition,
        "model_type": "churn_prevention"
    }
|
305
|
-
|
306
|
-
|
307
|
-
def create_both_daifa_models(output_dir: str = "./xinjiang_models") -> Dict:
    """Create, register and persist both payroll long-tail models.

    The function runs these steps in order:

    1. Build the 3k upgrade model and register it with ``register_model``.
    2. Build the 1.5k churn-prevention model and register it the same way.
    3. Dump the registry to ``<output_dir>/model_registry.json`` via this
       module's ``save_model_registry``.

    Args:
        output_dir: Directory for the registry file; created if missing.

    Returns:
        Dict with ``upgrade_model`` and ``churn_model`` (each holding
        ``model_id``, ``config`` and ``target``), then ``registry_path`` and
        ``output_dir``.

    Raises:
        ValueError: propagated from ``register_model`` when a model with the
            same name and version is already registered, e.g. on a second call
            within the same process.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (result key, builder) pairs; the tuple order fixes both the registration
    # order and the key order of the returned dict.
    builders = (
        ("upgrade_model", create_daifa_longtail_upgrade_model),
        ("churn_model", create_daifa_longtail_churn_model),
    )

    summary = {}
    for result_key, build in builders:
        spec = build()
        config, target = spec["model_config"], spec["target_definition"]
        summary[result_key] = {
            "model_id": register_model(config, target),
            "config": config,
            "target": target,
        }

    registry_path = os.path.join(output_dir, "model_registry.json")
    save_model_registry(registry_path)

    summary["registry_path"] = registry_path
    summary["output_dir"] = output_dir
    return summary
|
345
|
-
|
346
|
-
|
347
|
-
def get_available_daifa_models() -> List[str]:
    """Return the short identifiers of the payroll long-tail models.

    These are short ids, not the names the models are registered under.
    ``create_daifa_longtail_upgrade_model`` and
    ``create_daifa_longtail_churn_model`` register
    ``xinjiang_icbc_daifa_longtail_upgrade_3k`` and
    ``xinjiang_icbc_daifa_longtail_churn_1_5k``, so these strings cannot be
    passed straight to ``ModelRegistry.get_model``.
    """
    upgrade_id = "daifa_longtail_upgrade_3k"  # long-tail 3k upgrade model
    churn_id = "daifa_longtail_churn_1_5k"    # long-tail 1.5k churn-prevention model
    return [upgrade_id, churn_id]
|
353
|
-
|
354
|
-
|
355
|
-
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "create_daifa_longtail_upgrade_model",
    "create_daifa_longtail_churn_model",
    "create_both_daifa_models",
    "get_available_daifa_models",
]
|
staran/models/registry.py
DELETED
@@ -1,281 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
模型注册和管理模块
|
3
|
-
|
4
|
-
提供模型配置的注册、查询、版本管理等功能
|
5
|
-
"""
|
6
|
-
|
7
|
-
from typing import Dict, Any, List, Optional, Tuple
|
8
|
-
from dataclasses import dataclass
|
9
|
-
from datetime import datetime
|
10
|
-
import json
|
11
|
-
import os
|
12
|
-
|
13
|
-
from .config import ModelConfig
|
14
|
-
from .target import TargetDefinition
|
15
|
-
|
16
|
-
|
17
|
-
@dataclass
class ModelEntry:
    """One registry record: a model config, its target definition and metadata."""

    model_config: ModelConfig
    target_definition: TargetDefinition
    registered_at: datetime
    # Lifecycle status; the values noted in this module are "active",
    # "inactive" and "deprecated".
    status: str = "active"
    # Metric name -> value. A None default is swapped for a fresh {} in
    # __post_init__ so entries never share one mutable dict.
    performance_metrics: Dict[str, float] = None

    def __post_init__(self):
        if self.performance_metrics is not None:
            return
        self.performance_metrics = {}
|
29
|
-
|
30
|
-
|
31
|
-
class ModelRegistry:
    """Process-wide, in-memory model registry.

    All state lives in class attributes, so every user of the class shares one
    registry:

    * ``_models`` maps a model id (``"<name>_<version>"``) to its ``ModelEntry``.
    * ``_version_history`` maps a model name to its registered versions, in
      registration order.
    """

    _models: Dict[str, ModelEntry] = {}
    _version_history: Dict[str, List[str]] = {}  # model name -> list of versions

    @staticmethod
    def _version_key(version: str) -> Tuple[Tuple[int, int, str], ...]:
        """Return a sort key that orders version strings numerically.

        Dot-separated numeric components compare as integers, so ``"1.10.0"``
        sorts after ``"1.9.0"``; plain string comparison gets this wrong.
        Non-numeric components sort after numeric ones and compare as text.
        Every position holds an ``(int, int, str)`` triple, so the key is
        totally ordered for arbitrary strings.
        """
        key = []
        for part in str(version).split('.'):
            if part.isdecimal():
                key.append((0, int(part), ''))
            else:
                key.append((1, 0, part))
        return tuple(key)

    @staticmethod
    def _target_from_dict(payload: Dict[str, Any]) -> TargetDefinition:
        """Rebuild a ``TargetDefinition`` from its serialised dict.

        Uses ``TargetDefinition.from_dict`` when that callable exists, mirroring
        how ``ModelConfig.from_dict`` is used. Otherwise it falls back to
        passing the dict as keyword arguments.
        """
        # NOTE(review): target.py is not visible here. Confirm whether
        # TargetDefinition defines from_dict; without it, enum fields such as
        # target_type come back as plain values and `.value` accesses in
        # list_models/get_model_summary would fail after a reload.
        from_dict = getattr(TargetDefinition, 'from_dict', None)
        if callable(from_dict):
            return from_dict(payload)
        return TargetDefinition(**payload)

    @classmethod
    def register(cls, model_config: ModelConfig, target_definition: TargetDefinition) -> str:
        """Register a new model.

        Args:
            model_config: Model configuration; its ``name`` and ``version``
                form the model id.
            target_definition: Target-variable definition for the model.

        Returns:
            The model id, ``"<name>_<version>"``.

        Raises:
            ValueError: if a model with the same id is already registered.
        """
        model_id = f"{model_config.name}_{model_config.version}"

        if model_id in cls._models:
            raise ValueError(f"模型 {model_id} 已存在")

        cls._models[model_id] = ModelEntry(
            model_config=model_config,
            target_definition=target_definition,
            registered_at=datetime.now()
        )
        cls._version_history.setdefault(model_config.name, []).append(model_config.version)

        print(f"✅ 模型 {model_id} 注册成功")
        return model_id

    @classmethod
    def get_model(cls, model_name: str, version: str = None) -> Optional[ModelEntry]:
        """Look up a registry entry.

        Args:
            model_name: Model name.
            version: Exact version. When omitted (or empty), the numerically
                highest registered version is returned.

        Returns:
            The matching ``ModelEntry``, or ``None`` if nothing matches.
        """
        if version:
            return cls._models.get(f"{model_name}_{version}")

        versions = cls._version_history.get(model_name, [])
        if not versions:
            return None

        latest_version = max(versions, key=cls._version_key)
        return cls._models.get(f"{model_name}_{latest_version}")

    @classmethod
    def get_model_config(cls, model_name: str, version: str = None) -> Optional[ModelConfig]:
        """Return the model configuration, or ``None`` if the model is unknown."""
        entry = cls.get_model(model_name, version)
        return entry.model_config if entry else None

    @classmethod
    def get_target_definition(cls, model_name: str, version: str = None) -> Optional[TargetDefinition]:
        """Return the target-variable definition, or ``None`` if the model is unknown."""
        entry = cls.get_model(model_name, version)
        return entry.target_definition if entry else None

    @classmethod
    def list_models(cls) -> List[Dict[str, Any]]:
        """Return a summary dict for every registered model."""
        return [
            {
                'model_id': model_id,
                'name': entry.model_config.name,
                'version': entry.model_config.version,
                'type': entry.model_config.model_type.value,
                'algorithm': entry.model_config.algorithm.value,
                'bank_code': entry.model_config.bank_code,
                'status': entry.status,
                'registered_at': entry.registered_at.isoformat(),
                'description': entry.model_config.description
            }
            for model_id, entry in cls._models.items()
        ]

    @classmethod
    def list_versions(cls, model_name: str) -> List[str]:
        """Return the registered versions of *model_name*, in registration order.

        A copy is returned so callers cannot corrupt the internal history.
        """
        return list(cls._version_history.get(model_name, []))

    @classmethod
    def update_status(cls, model_name: str, status: str, version: str = None):
        """Set the status of a model (latest version unless *version* is given)."""
        entry = cls.get_model(model_name, version)
        if entry:
            entry.status = status
            print(f"✅ 模型 {model_name} 状态更新为: {status}")
        else:
            print(f"❌ 模型 {model_name} 不存在")

    @classmethod
    def update_performance(cls, model_name: str, metrics: Dict[str, float], version: str = None):
        """Merge *metrics* into a model's performance metrics."""
        entry = cls.get_model(model_name, version)
        if entry:
            entry.performance_metrics.update(metrics)
            print(f"✅ 模型 {model_name} 性能指标已更新")
        else:
            print(f"❌ 模型 {model_name} 不存在")

    @classmethod
    def get_model_summary(cls, model_name: str, version: str = None) -> Optional[Dict[str, Any]]:
        """Return a detailed summary of a model, or ``None`` if it is unknown."""
        entry = cls.get_model(model_name, version)
        if not entry:
            return None

        model_config = entry.model_config
        target_def = entry.target_definition

        return {
            'basic_info': {
                'name': model_config.name,
                'version': model_config.version,
                'type': model_config.model_type.value,
                'algorithm': model_config.algorithm.value,
                'description': model_config.description,
                'created_by': model_config.created_by,
                'bank_code': model_config.bank_code
            },
            'feature_config': {
                'schema_name': model_config.feature_config.schema_name,
                'table_types': model_config.feature_config.table_types,
                'feature_selection': model_config.feature_config.feature_selection,
                'feature_engineering': model_config.feature_config.feature_engineering
            },
            'target_config': {
                'name': target_def.name,
                'type': target_def.target_type.value,
                'column': target_def.target_column,
                'description': target_def.description
            },
            'training_config': model_config.training_config,
            'hyperparameters': model_config.hyperparameters,
            'evaluation_metrics': model_config.evaluation_metrics,
            'registry_info': {
                'status': entry.status,
                'registered_at': entry.registered_at.isoformat(),
                'performance_metrics': entry.performance_metrics
            }
        }

    @classmethod
    def save_to_file(cls, filepath: str):
        """Serialise the registry to *filepath* as UTF-8 JSON.

        The parent directory is created when *filepath* has one.
        """
        data = {
            'models': {},
            'version_history': cls._version_history,
            'saved_at': datetime.now().isoformat()
        }

        for model_id, entry in cls._models.items():
            data['models'][model_id] = {
                'model_config': entry.model_config.to_dict(),
                'target_definition': entry.target_definition.to_dict(),
                'registered_at': entry.registered_at.isoformat(),
                'status': entry.status,
                'performance_metrics': entry.performance_metrics
            }

        directory = os.path.dirname(filepath)
        if directory:  # os.makedirs('') raises for a bare filename like "registry.json"
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"✅ 模型注册表已保存到: {filepath}")

    @classmethod
    def load_from_file(cls, filepath: str):
        """Replace the in-memory registry with the contents of *filepath*.

        If the file does not exist, a message is printed and the current
        registry is left unchanged.
        """
        if not os.path.exists(filepath):
            print(f"❌ 文件不存在: {filepath}")
            return

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        cls._version_history = data.get('version_history', {})
        cls._models = {}

        for model_id, entry_data in data.get('models', {}).items():
            cls._models[model_id] = ModelEntry(
                model_config=ModelConfig.from_dict(entry_data['model_config']),
                target_definition=cls._target_from_dict(entry_data['target_definition']),
                registered_at=datetime.fromisoformat(entry_data['registered_at']),
                status=entry_data.get('status', 'active'),
                performance_metrics=entry_data.get('performance_metrics', {})
            )

        print(f"✅ 从 {filepath} 加载了 {len(cls._models)} 个模型")
|
251
|
-
|
252
|
-
|
253
|
-
# Convenience functions: module-level shortcuts onto ModelRegistry.
def register_model(model_config: ModelConfig, target_definition: TargetDefinition) -> str:
    """Register a model via ``ModelRegistry.register`` and return its id.

    Raises:
        ValueError: if a model with the same name and version already exists.
    """
    model_id = ModelRegistry.register(model_config, target_definition)
    return model_id
|
257
|
-
|
258
|
-
|
259
|
-
def get_model_config(model_name: str, version: str = None) -> Optional[ModelConfig]:
    """Return a model's configuration via ``ModelRegistry`` (``None`` if unknown)."""
    config = ModelRegistry.get_model_config(model_name, version)
    return config
|
262
|
-
|
263
|
-
|
264
|
-
def get_target_definition(model_name: str, version: str = None) -> Optional[TargetDefinition]:
    """Return a model's target definition via ``ModelRegistry`` (``None`` if unknown)."""
    target = ModelRegistry.get_target_definition(model_name, version)
    return target
|
267
|
-
|
268
|
-
|
269
|
-
def list_available_models() -> List[Dict[str, Any]]:
    """Return summaries of all registered models (see ``ModelRegistry.list_models``)."""
    summaries = ModelRegistry.list_models()
    return summaries
|
272
|
-
|
273
|
-
|
274
|
-
def save_model_registry(filepath: str = "./models/model_registry.json"):
    """Persist the registry to *filepath* via ``ModelRegistry.save_to_file``."""
    writer = ModelRegistry.save_to_file
    writer(filepath)
|
277
|
-
|
278
|
-
|
279
|
-
def load_model_registry(filepath: str = "./models/model_registry.json"):
    """Load the registry from *filepath* via ``ModelRegistry.load_from_file``."""
    loader = ModelRegistry.load_from_file
    loader(filepath)
|