staran 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- staran/models/__init__.py +81 -0
- staran/models/bank_configs.py +269 -0
- staran/models/config.py +271 -0
- staran/models/registry.py +281 -0
- staran/models/target.py +321 -0
- {staran-0.4.2.dist-info → staran-0.5.0.dist-info}/METADATA +1 -1
- {staran-0.4.2.dist-info → staran-0.5.0.dist-info}/RECORD +10 -5
- {staran-0.4.2.dist-info → staran-0.5.0.dist-info}/WHEEL +0 -0
- {staran-0.4.2.dist-info → staran-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {staran-0.4.2.dist-info → staran-0.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,81 @@
|
|
1
|
+
"""
|
2
|
+
Staran Models Module - v0.5.0
|
3
|
+
|
4
|
+
专业的机器学习模型配置和管理模块,提供:
|
5
|
+
- 模型配置管理 (ModelConfig)
|
6
|
+
- 目标变量定义 (TargetDefinition)
|
7
|
+
- 银行特定配置支持
|
8
|
+
- SQL驱动的target生成
|
9
|
+
- 模型部署和版本管理
|
10
|
+
|
11
|
+
支持的模型类型:
|
12
|
+
- 分类模型 (Classification)
|
13
|
+
- 回归模型 (Regression)
|
14
|
+
- 聚类模型 (Clustering)
|
15
|
+
- 时间序列模型 (TimeSeries)
|
16
|
+
|
17
|
+
支持的银行:
|
18
|
+
- 工商银行 (ICBC)
|
19
|
+
- 通用配置 (Generic)
|
20
|
+
"""
|
21
|
+
|
22
|
+
from .config import ModelConfig, ModelType, create_model_config
|
23
|
+
from .target import TargetDefinition, TargetType, create_target_definition
|
24
|
+
from .registry import ModelRegistry, register_model, get_model_config, save_model_registry
|
25
|
+
from .bank_configs import BankConfig, get_bank_config, register_bank_config
|
26
|
+
|
27
|
+
# Version information
__version__ = "0.5.0"

# Public API exported by `from staran.models import *`
__all__ = [
    # Model configuration
    'ModelConfig',
    'ModelType',
    'create_model_config',

    # Target definitions
    'TargetDefinition',
    'TargetType',
    'create_target_definition',

    # Model registry
    'ModelRegistry',
    'register_model',
    'get_model_config',
    'save_model_registry',

    # Bank configuration
    'BankConfig',
    'get_bank_config',
    'register_bank_config',

    # Convenience functions defined at the bottom of this module; they are
    # public helpers, so star-imports must expose them as well.
    'create_icbc_model',
    'list_available_models',
    'get_model_summary',
]
|
53
|
+
|
54
|
+
# Convenience functions
def create_icbc_model(model_name: str, model_type: str, target_sql: str, algorithm: str = "random_forest", **kwargs):
    """Create and register an ICBC-specific model configuration.

    Builds a ModelConfig and a SQL-based target definition, both tagged with
    bank_code="icbc", and registers the pair in the model registry.

    Args:
        model_name: Model name. The target definition is named "<model_name>_target".
        model_type: Model type string, forwarded to create_model_config.
        target_sql: SQL query that produces the target variable.
        algorithm: Algorithm name (default "random_forest").
        **kwargs: Extra options forwarded to create_model_config.

    Returns:
        The model id returned by register_model.
    """
    model_config = create_model_config(
        name=model_name,
        model_type=model_type,
        algorithm=algorithm,
        bank_code="icbc",
        **kwargs
    )

    target_config = create_target_definition(
        name=f"{model_name}_target",
        target_type="sql_based",
        sql_query=target_sql,
        bank_code="icbc"
    )

    return register_model(model_config, target_config)
|
74
|
+
|
75
|
+
def list_available_models():
    """Return the summary list of every model currently held by ModelRegistry."""
    registry = ModelRegistry
    return registry.list_models()
|
78
|
+
|
79
|
+
def get_model_summary(model_name: str):
    """Return the registry summary for the latest version of *model_name*, or None if unknown."""
    summary = ModelRegistry.get_model_summary(model_name)
    return summary
|
@@ -0,0 +1,269 @@
|
|
1
|
+
"""
|
2
|
+
银行特定配置模块
|
3
|
+
|
4
|
+
为不同银行提供定制化的配置和业务规则
|
5
|
+
"""
|
6
|
+
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Dict, Any, List, Optional
|
9
|
+
from dataclasses import dataclass, field
|
10
|
+
|
11
|
+
|
12
|
+
class BankCode(Enum):
    """Enumeration of known bank codes.

    NOTE(review): BankConfig.bank_code is a plain str; this enum is not
    referenced by the visible code.
    """
    ICBC = "icbc"  # Industrial and Commercial Bank of China
    CCB = "ccb"  # China Construction Bank
    BOC = "boc"  # Bank of China
    ABC = "abc"  # Agricultural Bank of China
    CMB = "cmb"  # China Merchants Bank
    GENERIC = "generic"  # Generic configuration
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class BankConfig:
    """Per-bank configuration: data layout, business rules and compliance settings.

    The mapping lookups fall back to the standard (unmapped) name when no
    bank-specific override exists.
    """
    # Basic information
    bank_code: str  # bank code, used as the registry key
    bank_name: str  # human-readable bank name
    region: str = "cn"  # region code

    # Database settings
    database_config: Dict[str, Any] = field(default_factory=dict)

    # Standard table name -> bank-specific table name
    table_mappings: Dict[str, str] = field(default_factory=dict)

    # Table -> {standard field name -> bank-specific field name}
    field_mappings: Dict[str, Dict[str, str]] = field(default_factory=dict)

    # Business rules
    business_rules: Dict[str, Any] = field(default_factory=dict)

    # Compliance requirements keyed by operation name
    compliance_rules: Dict[str, Any] = field(default_factory=dict)

    # Data processing rules
    data_processing_rules: Dict[str, Any] = field(default_factory=dict)

    # Model deployment settings
    deployment_config: Dict[str, Any] = field(default_factory=dict)

    # Feature engineering settings
    feature_engineering_config: Dict[str, Any] = field(default_factory=dict)

    def get_table_name(self, standard_table: str) -> str:
        """Return the bank-specific name of *standard_table*, or the name itself if unmapped."""
        if standard_table in self.table_mappings:
            return self.table_mappings[standard_table]
        return standard_table

    def get_field_name(self, table: str, standard_field: str) -> str:
        """Return the bank-specific name of *standard_field* in *table*, or the field itself if unmapped."""
        if table not in self.field_mappings:
            return standard_field
        return self.field_mappings[table].get(standard_field, standard_field)

    def get_business_rule(self, rule_name: str, default=None):
        """Return business rule *rule_name*, or *default* when it is not configured."""
        rules = self.business_rules
        return rules[rule_name] if rule_name in rules else default

    def validate_compliance(self, operation: str) -> bool:
        """Report whether *operation* is allowed by the compliance rules.

        Operations without a rule, or whose rule lacks an 'enabled' flag, are
        allowed. Concrete compliance checks are not implemented yet.
        """
        if operation not in self.compliance_rules:
            return True
        return self.compliance_rules[operation].get('enabled', True)
|
72
|
+
|
73
|
+
|
74
|
+
# Module-level registry of bank configurations, keyed by BankConfig.bank_code.
_BANK_CONFIGS: Dict[str, BankConfig] = {}
|
76
|
+
|
77
|
+
|
78
|
+
def register_bank_config(config: BankConfig):
    """Store *config* under its bank_code and print a confirmation line.

    A previously registered config with the same bank_code is replaced.
    """
    _BANK_CONFIGS.update({config.bank_code: config})
    print(f"✅ 银行配置 {config.bank_code} ({config.bank_name}) 注册成功")
|
82
|
+
|
83
|
+
|
84
|
+
def get_bank_config(bank_code: str) -> Optional[BankConfig]:
    """Return the registered config for *bank_code*, or None if it is unknown."""
    if bank_code in _BANK_CONFIGS:
        return _BANK_CONFIGS[bank_code]
    return None
|
87
|
+
|
88
|
+
|
89
|
+
def list_bank_configs() -> List[Dict[str, str]]:
    """Summarise every registered bank configuration.

    Returns:
        One dict per registered config, in registration order, with the keys
        'bank_code', 'bank_name' and 'region'.
    """
    summaries = []
    for cfg in _BANK_CONFIGS.values():
        summaries.append(
            dict(bank_code=cfg.bank_code, bank_name=cfg.bank_name, region=cfg.region)
        )
    return summaries
|
99
|
+
|
100
|
+
|
101
|
+
# Predefined bank configurations
def create_icbc_config() -> BankConfig:
    """Build the predefined ICBC configuration (bank_code "icbc").

    Every call builds new dict literals, so callers may mutate the returned
    config without affecting later calls.
    """
    database = {
        "default_database": "dwegdata03000",
        "connection_pool_size": 10,
        "query_timeout": 300,
    }
    tables = {
        "behavior_table": "bi_hlwj_dfcw_f1_f4_wy",
        "asset_avg_table": "bi_hlwj_zi_chan_avg_wy",
        "asset_config_table": "bi_hlwj_zi_chang_month_total_zb",
        "monthly_stat_table": "bi_hlwj_realy_month_stat_wy",
    }
    field_map = {
        "behavior_table": {
            "customer_id": "party_id",
            "date_field": "data_dt",
        },
    }
    business = {
        "data_retention_days": 90,
        "min_sample_size": 1000,
        "max_features": 500,
        "risk_threshold": 0.8,
        "aum_threshold": 100000,
        "longtail_definition": {
            "asset_threshold": 50000,
            "activity_threshold": 0.3,
        },
    }
    compliance = {
        "data_export": {
            "enabled": True,
            "approval_required": True,
            "encryption_required": True,
        },
        "model_deployment": {
            "enabled": True,
            "testing_required": True,
            "documentation_required": True,
        },
        "feature_selection": {
            "enabled": True,
            "sensitive_data_allowed": False,
            "audit_trail_required": True,
        },
    }
    processing = {
        "missing_value_strategy": "median",
        "outlier_detection": True,
        "outlier_threshold": 3.0,
        "feature_scaling": "standard",
        "categorical_encoding": "one_hot",
    }
    deployment = {
        "platform": "turing",
        "environment": "production",
        "monitoring_enabled": True,
        "auto_scaling": True,
        "backup_required": True,
    }
    feature_engineering = {
        "time_windows": ["1_month", "3_months", "6_months", "1_year"],
        "aggregation_functions": ["sum", "avg", "max", "min", "std"],
        "interaction_features": True,
        "polynomial_features": False,
        "target_encoding": True,
    }

    return BankConfig(
        bank_code="icbc",
        bank_name="中国工商银行",
        region="cn",
        database_config=database,
        table_mappings=tables,
        field_mappings=field_map,
        business_rules=business,
        compliance_rules=compliance,
        data_processing_rules=processing,
        deployment_config=deployment,
        feature_engineering_config=feature_engineering,
    )
|
183
|
+
|
184
|
+
|
185
|
+
def create_generic_config() -> BankConfig:
    """Build the generic, bank-agnostic configuration (bank_code "generic").

    Field mappings, deployment and feature-engineering settings are left at
    their empty defaults.
    """
    database = {
        "default_database": "default_db",
        "connection_pool_size": 5,
        "query_timeout": 180,
    }
    tables = {
        "behavior_table": "customer_behavior",
        "asset_avg_table": "customer_assets",
        "asset_config_table": "asset_config",
        "monthly_stat_table": "monthly_stats",
    }
    business = {
        "data_retention_days": 30,
        "min_sample_size": 100,
        "max_features": 100,
    }
    compliance = {
        "data_export": {"enabled": True},
        "model_deployment": {"enabled": True},
    }
    processing = {
        "missing_value_strategy": "mean",
        "outlier_detection": False,
        "feature_scaling": "none",
    }

    return BankConfig(
        bank_code="generic",
        bank_name="通用银行配置",
        region="generic",
        database_config=database,
        table_mappings=tables,
        business_rules=business,
        compliance_rules=compliance,
        data_processing_rules=processing,
    )
|
222
|
+
|
223
|
+
|
224
|
+
# Default bank configuration bootstrap
def initialize_default_configs():
    """Register the built-in ICBC and generic bank configurations, in that order."""
    for factory in (create_icbc_config, create_generic_config):
        register_bank_config(factory())


# Register the defaults automatically when the module is imported
initialize_default_configs()
|
236
|
+
|
237
|
+
|
238
|
+
# Xinjiang ICBC specific configuration
def create_xinjiang_icbc_config() -> BankConfig:
    """Build the Xinjiang ICBC configuration from a fresh ICBC config.

    Renames the bank (bank_code "xinjiang_icbc", region "xinjiang") and merges
    region-specific business and data processing rules into the inherited
    dicts, which create_icbc_config builds anew on every call.
    """
    config = create_icbc_config()

    # Identity overrides on top of the base ICBC settings
    config.bank_code, config.bank_name, config.region = (
        "xinjiang_icbc",
        "新疆工商银行",
        "xinjiang",
    )

    regional_business_rules = {
        "regional_compliance": True,
        "minority_customer_support": True,
        "language_support": ["zh", "ug"],  # Chinese and Uyghur
        "timezone": "Asia/Urumqi",
        "currency_support": ["CNY"],
        "cross_border_transaction": True,
    }
    regional_processing_rules = {
        "character_encoding": "utf-8",
        "regional_holidays": True,
        "time_zone_conversion": True,
    }
    config.business_rules.update(regional_business_rules)
    config.data_processing_rules.update(regional_processing_rules)

    return config


# Register the Xinjiang ICBC configuration at import time
register_bank_config(create_xinjiang_icbc_config())
|
staran/models/config.py
ADDED
@@ -0,0 +1,271 @@
|
|
1
|
+
"""
|
2
|
+
模型配置管理模块
|
3
|
+
|
4
|
+
定义模型的核心配置信息,包括模型类型、参数、特征配置等
|
5
|
+
"""
|
6
|
+
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Dict, Any, List, Optional
|
9
|
+
from dataclasses import dataclass, field
|
10
|
+
from datetime import datetime
|
11
|
+
|
12
|
+
|
13
|
+
class ModelType(Enum):
    """Enumeration of supported model types (values are the serialized strings)."""
    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    CLUSTERING = "clustering"
    TIME_SERIES = "time_series"
    ANOMALY_DETECTION = "anomaly_detection"
    RECOMMENDATION = "recommendation"
|
21
|
+
|
22
|
+
|
23
|
+
class ModelAlgorithm(Enum):
    """Enumeration of supported algorithms, grouped by typical task."""
    # Classification algorithms
    LOGISTIC_REGRESSION = "logistic_regression"
    RANDOM_FOREST = "random_forest"
    GRADIENT_BOOSTING = "gradient_boosting"
    SVM = "svm"
    NEURAL_NETWORK = "neural_network"

    # Regression algorithms
    LINEAR_REGRESSION = "linear_regression"
    RIDGE_REGRESSION = "ridge_regression"
    LASSO_REGRESSION = "lasso_regression"

    # Clustering algorithms
    KMEANS = "kmeans"
    DBSCAN = "dbscan"
    HIERARCHICAL = "hierarchical"

    # Time series
    ARIMA = "arima"
    LSTM = "lstm"
    PROPHET = "prophet"
|
46
|
+
|
47
|
+
|
48
|
+
@dataclass
class FeatureConfig:
    """Feature configuration of a model."""
    schema_name: str  # name of the schema to use (e.g. 'aum')
    table_types: List[str]  # table types to use (e.g. ['behavior', 'asset_avg'])
    feature_selection: bool = True  # enable feature selection
    feature_engineering: bool = True  # enable feature engineering
    scaling: bool = True  # enable feature scaling
    encoding: Dict[str, str] = field(default_factory=dict)  # encoding settings
|
57
|
+
|
58
|
+
|
59
|
+
@dataclass
class ModelConfig:
    """Complete configuration of one model.

    Combines identity (name/type/algorithm/version), feature settings,
    hyperparameters, training and evaluation settings, bank context and
    metadata. Round-trips through to_dict()/from_dict().
    """
    # Basic information
    name: str  # model name
    model_type: ModelType  # model type
    algorithm: ModelAlgorithm  # algorithm used
    version: str = "1.0.0"  # model version

    # Feature configuration; None is replaced by a generic default in __post_init__
    feature_config: Optional[FeatureConfig] = None

    # Model hyperparameters
    hyperparameters: Dict[str, Any] = field(default_factory=dict)

    # Training configuration
    training_config: Dict[str, Any] = field(default_factory=lambda: {
        'test_size': 0.2,
        'random_state': 42,
        'cross_validation': True,
        'cv_folds': 5
    })

    # Evaluation metrics; empty means "derive defaults from model_type"
    evaluation_metrics: List[str] = field(default_factory=list)

    # Bank-specific settings
    bank_code: str = "generic"  # bank code
    business_domain: str = "generic"  # business domain

    # Metadata
    description: str = ""  # model description
    created_at: datetime = field(default_factory=datetime.now)
    created_by: str = "system"  # creator
    tags: List[str] = field(default_factory=list)

    # Deployment configuration
    deployment_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in the default feature config and evaluation metrics when not given."""
        if self.feature_config is None:
            self.feature_config = FeatureConfig(
                schema_name="generic",
                table_types=["base"]
            )

        # Default evaluation metrics depend on the model type
        if not self.evaluation_metrics:
            self.evaluation_metrics = self._get_default_metrics()

    def _get_default_metrics(self) -> List[str]:
        """Return the default evaluation metrics for this model's type."""
        if self.model_type == ModelType.CLASSIFICATION:
            return ['accuracy', 'precision', 'recall', 'f1_score', 'auc']
        elif self.model_type == ModelType.REGRESSION:
            return ['mae', 'mse', 'rmse', 'r2_score']
        elif self.model_type == ModelType.CLUSTERING:
            return ['silhouette_score', 'calinski_harabasz_score']
        else:
            return ['custom_metric']

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums as values, created_at as ISO-8601)."""
        return {
            'name': self.name,
            'model_type': self.model_type.value,
            'algorithm': self.algorithm.value,
            'version': self.version,
            'feature_config': {
                'schema_name': self.feature_config.schema_name,
                'table_types': self.feature_config.table_types,
                'feature_selection': self.feature_config.feature_selection,
                'feature_engineering': self.feature_config.feature_engineering,
                'scaling': self.feature_config.scaling,
                'encoding': self.feature_config.encoding
            },
            'hyperparameters': self.hyperparameters,
            'training_config': self.training_config,
            'evaluation_metrics': self.evaluation_metrics,
            'bank_code': self.bank_code,
            'business_domain': self.business_domain,
            'description': self.description,
            'created_at': self.created_at.isoformat(),
            'created_by': self.created_by,
            'tags': self.tags,
            'deployment_config': self.deployment_config
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
        """Rebuild a ModelConfig from a dict such as one produced by to_dict().

        Missing optional keys fall back to the dataclass defaults. An absent
        'training_config' keeps the default training settings, and an absent
        'created_at' uses the current time. A present 'created_at' (ISO-8601
        string or datetime) is restored, so save/load cycles keep it.

        Raises:
            KeyError: if 'name', 'model_type' or 'algorithm' is missing.
            ValueError: if model_type/algorithm is not a known enum value, or
                created_at is not a valid ISO-8601 string.
        """
        feature_config_data = data.get('feature_config') or {}
        feature_config = FeatureConfig(
            schema_name=feature_config_data.get('schema_name', 'generic'),
            table_types=feature_config_data.get('table_types', ['base']),
            feature_selection=feature_config_data.get('feature_selection', True),
            feature_engineering=feature_config_data.get('feature_engineering', True),
            scaling=feature_config_data.get('scaling', True),
            encoding=feature_config_data.get('encoding', {})
        )

        # Forward these only when present so the dataclass default factories
        # (default training settings, "now" timestamp) apply otherwise.
        optional: Dict[str, Any] = {}
        if 'training_config' in data:
            optional['training_config'] = data['training_config']
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            optional['created_at'] = datetime.fromisoformat(created_at)
        elif isinstance(created_at, datetime):
            optional['created_at'] = created_at

        return cls(
            name=data['name'],
            model_type=ModelType(data['model_type']),
            algorithm=ModelAlgorithm(data['algorithm']),
            version=data.get('version', '1.0.0'),
            feature_config=feature_config,
            hyperparameters=data.get('hyperparameters', {}),
            evaluation_metrics=data.get('evaluation_metrics', []),
            bank_code=data.get('bank_code', 'generic'),
            business_domain=data.get('business_domain', 'generic'),
            description=data.get('description', ''),
            created_by=data.get('created_by', 'system'),
            tags=data.get('tags', []),
            deployment_config=data.get('deployment_config', {}),
            **optional
        )
|
177
|
+
|
178
|
+
|
179
|
+
def create_model_config(
    name: str,
    model_type: str,
    algorithm: str,
    schema_name: str = "generic",
    table_types: List[str] = None,
    bank_code: str = "generic",
    **kwargs
) -> ModelConfig:
    """
    Convenience constructor for ModelConfig.

    Args:
        name: Model name.
        model_type: ModelType value string (e.g. "classification").
        algorithm: ModelAlgorithm value string (e.g. "random_forest").
        schema_name: Schema name used by the feature config.
        table_types: Table types used by the feature config; None means ["base"].
        bank_code: Bank code.
        **kwargs: Any other ModelConfig field.

    Returns:
        ModelConfig instance.

    Raises:
        ValueError: if model_type or algorithm is not a valid enum value.
    """
    resolved_tables = ["base"] if table_types is None else table_types
    features = FeatureConfig(schema_name=schema_name, table_types=resolved_tables)

    return ModelConfig(
        name=name,
        model_type=ModelType(model_type),
        algorithm=ModelAlgorithm(algorithm),
        feature_config=features,
        bank_code=bank_code,
        **kwargs
    )
|
219
|
+
|
220
|
+
|
221
|
+
# Predefined model configuration templates used by create_preset_config().
# Each value holds keyword arguments for create_model_config (the model name
# is supplied separately).
PRESET_CONFIGS = {
    "aum_longtail_classification": {
        "model_type": "classification",
        "algorithm": "random_forest",
        "schema_name": "aum",
        "table_types": ["behavior", "asset_avg", "asset_config", "monthly_stat"],
        "hyperparameters": {
            "n_estimators": 100,
            "max_depth": 10,
            "random_state": 42
        },
        "description": "AUM长尾客户分类模型"
    },

    "customer_value_regression": {
        "model_type": "regression",
        "algorithm": "gradient_boosting",
        "schema_name": "aum",
        "table_types": ["behavior", "asset_avg"],
        "hyperparameters": {
            "n_estimators": 150,
            "learning_rate": 0.1,
            "max_depth": 8
        },
        "description": "客户价值预测回归模型"
    }
}
|
249
|
+
|
250
|
+
|
251
|
+
def create_preset_config(preset_name: str, **overrides) -> ModelConfig:
    """
    Create a model configuration from a predefined template.

    Args:
        preset_name: Key in PRESET_CONFIGS.
        **overrides: Keyword arguments that replace template values. A `name`
            override sets the model name (otherwise preset_name is used).

    Returns:
        ModelConfig instance.

    Raises:
        ValueError: if preset_name is not a known preset.
    """
    import copy

    if preset_name not in PRESET_CONFIGS:
        raise ValueError(f"未知的预设配置: {preset_name}")

    # Deep copy so nested dicts/lists (hyperparameters, table_types) of the
    # created ModelConfig are not shared with -- and cannot mutate -- the template.
    config = copy.deepcopy(PRESET_CONFIGS[preset_name])
    config.update(overrides)

    # Let callers override the model name; passing it alongside name= would
    # otherwise raise "multiple values for keyword argument 'name'".
    name = config.pop('name', preset_name)

    return create_model_config(
        name=name,
        **config
    )
|
@@ -0,0 +1,281 @@
|
|
1
|
+
"""
|
2
|
+
模型注册和管理模块
|
3
|
+
|
4
|
+
提供模型配置的注册、查询、版本管理等功能
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import Dict, Any, List, Optional, Tuple
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from datetime import datetime
|
10
|
+
import json
|
11
|
+
import os
|
12
|
+
|
13
|
+
from .config import ModelConfig
|
14
|
+
from .target import TargetDefinition
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class ModelEntry:
    """A registered model: its config, target definition and registry metadata."""
    model_config: ModelConfig
    target_definition: TargetDefinition
    registered_at: datetime
    status: str = "active"  # expected values: active, inactive, deprecated (not enforced)
    performance_metrics: Dict[str, float] = None  # None becomes a fresh {} in __post_init__

    def __post_init__(self):
        # Give every entry its own metrics dict rather than a shared default.
        self.performance_metrics = {} if self.performance_metrics is None else self.performance_metrics
|
29
|
+
|
30
|
+
|
31
|
+
class ModelRegistry:
    """In-memory model registry shared by the whole process.

    All state lives in class attributes, so every caller sees the same
    registry. Entries are keyed by the id "<name>_<version>".
    """

    _models: Dict[str, ModelEntry] = {}
    _version_history: Dict[str, List[str]] = {}  # model name -> versions in registration order

    @staticmethod
    def _version_key(version: str) -> Tuple[Tuple[int, int, str], ...]:
        """Sort key that orders dotted versions numerically ("1.10.0" > "1.9.0").

        Purely numeric components compare as integers. Any other component
        compares as text and sorts after numeric ones, so mixed versions never
        raise TypeError. This is not full SemVer precedence (pre-release tags
        are not special-cased).
        """
        key = []
        for part in str(version).split('.'):
            if part.isdecimal():
                key.append((0, int(part), ''))
            else:
                key.append((1, 0, part))
        return tuple(key)

    @classmethod
    def register(cls, model_config: ModelConfig, target_definition: TargetDefinition) -> str:
        """
        Register a new model.

        Args:
            model_config: Model configuration.
            target_definition: Target variable definition.

        Returns:
            The model id "<name>_<version>".

        Raises:
            ValueError: if a model with the same name and version already exists.
        """
        model_id = f"{model_config.name}_{model_config.version}"

        if model_id in cls._models:
            raise ValueError(f"模型 {model_id} 已存在")

        entry = ModelEntry(
            model_config=model_config,
            target_definition=target_definition,
            registered_at=datetime.now()
        )
        cls._models[model_id] = entry
        cls._version_history.setdefault(model_config.name, []).append(model_config.version)

        print(f"✅ 模型 {model_id} 注册成功")
        return model_id

    @classmethod
    def get_model(cls, model_name: str, version: str = None) -> Optional[ModelEntry]:
        """
        Look up a model entry.

        Args:
            model_name: Model name.
            version: Version string; when omitted, the highest registered
                version (compared numerically per dotted component) is used.

        Returns:
            The model entry, or None if not found.
        """
        if version:
            return cls._models.get(f"{model_name}_{version}")

        versions = cls._version_history.get(model_name, [])
        if not versions:
            return None

        # Plain string max() would rank "1.9.0" above "1.10.0".
        latest_version = max(versions, key=cls._version_key)
        return cls._models.get(f"{model_name}_{latest_version}")

    @classmethod
    def get_model_config(cls, model_name: str, version: str = None) -> Optional[ModelConfig]:
        """Return the model configuration, or None if the model is unknown."""
        entry = cls.get_model(model_name, version)
        return entry.model_config if entry else None

    @classmethod
    def get_target_definition(cls, model_name: str, version: str = None) -> Optional[TargetDefinition]:
        """Return the target variable definition, or None if the model is unknown."""
        entry = cls.get_model(model_name, version)
        return entry.target_definition if entry else None

    @classmethod
    def list_models(cls) -> List[Dict[str, Any]]:
        """Return a short summary dict for every registered model."""
        return [
            {
                'model_id': model_id,
                'name': entry.model_config.name,
                'version': entry.model_config.version,
                'type': entry.model_config.model_type.value,
                'algorithm': entry.model_config.algorithm.value,
                'bank_code': entry.model_config.bank_code,
                'status': entry.status,
                'registered_at': entry.registered_at.isoformat(),
                'description': entry.model_config.description
            }
            for model_id, entry in cls._models.items()
        ]

    @classmethod
    def list_versions(cls, model_name: str) -> List[str]:
        """Return a copy of the model's versions in registration order (empty if unknown)."""
        # Copy so callers cannot mutate the registry's internal history.
        return list(cls._version_history.get(model_name, []))

    @classmethod
    def update_status(cls, model_name: str, status: str, version: str = None):
        """Set the status of a model (latest version when *version* is omitted)."""
        entry = cls.get_model(model_name, version)
        if entry:
            entry.status = status
            print(f"✅ 模型 {model_name} 状态更新为: {status}")
        else:
            print(f"❌ 模型 {model_name} 不存在")

    @classmethod
    def update_performance(cls, model_name: str, metrics: Dict[str, float], version: str = None):
        """Merge *metrics* into a model's performance metrics."""
        entry = cls.get_model(model_name, version)
        if entry:
            entry.performance_metrics.update(metrics)
            print(f"✅ 模型 {model_name} 性能指标已更新")
        else:
            print(f"❌ 模型 {model_name} 不存在")

    @classmethod
    def get_model_summary(cls, model_name: str, version: str = None) -> Optional[Dict[str, Any]]:
        """Return a detailed summary of one model, or None if it is unknown."""
        entry = cls.get_model(model_name, version)
        if not entry:
            return None

        model_config = entry.model_config
        target_def = entry.target_definition

        return {
            'basic_info': {
                'name': model_config.name,
                'version': model_config.version,
                'type': model_config.model_type.value,
                'algorithm': model_config.algorithm.value,
                'description': model_config.description,
                'created_by': model_config.created_by,
                'bank_code': model_config.bank_code
            },
            'feature_config': {
                'schema_name': model_config.feature_config.schema_name,
                'table_types': model_config.feature_config.table_types,
                'feature_selection': model_config.feature_config.feature_selection,
                'feature_engineering': model_config.feature_config.feature_engineering
            },
            'target_config': {
                'name': target_def.name,
                'type': target_def.target_type.value,
                'column': target_def.target_column,
                'description': target_def.description
            },
            'training_config': model_config.training_config,
            'hyperparameters': model_config.hyperparameters,
            'evaluation_metrics': model_config.evaluation_metrics,
            'registry_info': {
                'status': entry.status,
                'registered_at': entry.registered_at.isoformat(),
                'performance_metrics': entry.performance_metrics
            }
        }

    @classmethod
    def save_to_file(cls, filepath: str):
        """Write the registry to *filepath* as UTF-8 JSON, creating parent directories."""
        data = {
            'models': {},
            'version_history': cls._version_history,
            'saved_at': datetime.now().isoformat()
        }

        for model_id, entry in cls._models.items():
            data['models'][model_id] = {
                'model_config': entry.model_config.to_dict(),
                'target_definition': entry.target_definition.to_dict(),
                'registered_at': entry.registered_at.isoformat(),
                'status': entry.status,
                'performance_metrics': entry.performance_metrics
            }

        directory = os.path.dirname(filepath)
        if directory:  # os.makedirs('') raises for a bare file name
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"✅ 模型注册表已保存到: {filepath}")

    @classmethod
    def load_from_file(cls, filepath: str):
        """Replace the in-memory registry with the contents of *filepath*.

        Prints a message and leaves the registry untouched if the file does
        not exist.
        """
        if not os.path.exists(filepath):
            print(f"❌ 文件不存在: {filepath}")
            return

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        cls._version_history = data.get('version_history', {})
        cls._models = {}

        for model_id, entry_data in data.get('models', {}).items():
            model_config = ModelConfig.from_dict(entry_data['model_config'])
            # NOTE(review): this passes the serialized dict straight to the
            # constructor; if TargetDefinition.to_dict() stores enums/datetimes
            # as strings, they are not converted back here -- confirm against
            # target.py.
            target_definition = TargetDefinition(
                **entry_data['target_definition']
            )

            entry = ModelEntry(
                model_config=model_config,
                target_definition=target_definition,
                registered_at=datetime.fromisoformat(entry_data['registered_at']),
                status=entry_data.get('status', 'active'),
                performance_metrics=entry_data.get('performance_metrics', {})
            )

            cls._models[model_id] = entry

        print(f"✅ 从 {filepath} 加载了 {len(cls._models)} 个模型")
|
251
|
+
|
252
|
+
|
253
|
+
# Convenience functions
def register_model(model_config: ModelConfig, target_definition: TargetDefinition) -> str:
    """Register a model with ModelRegistry and return its id.

    Raises:
        ValueError: if the same name/version is already registered.
    """
    model_id = ModelRegistry.register(model_config, target_definition)
    return model_id
|
257
|
+
|
258
|
+
|
259
|
+
def get_model_config(model_name: str, version: str = None) -> Optional[ModelConfig]:
    """Return a registered model's configuration (latest version by default), or None."""
    config = ModelRegistry.get_model_config(model_name, version)
    return config
|
262
|
+
|
263
|
+
|
264
|
+
def get_target_definition(model_name: str, version: str = None) -> Optional[TargetDefinition]:
    """Return a registered model's target definition (latest version by default), or None."""
    definition = ModelRegistry.get_target_definition(model_name, version)
    return definition
|
267
|
+
|
268
|
+
|
269
|
+
def list_available_models() -> List[Dict[str, Any]]:
    """Return ModelRegistry's summary list of all registered models."""
    models = ModelRegistry.list_models()
    return models
|
272
|
+
|
273
|
+
|
274
|
+
def save_model_registry(filepath: str = "./models/model_registry.json"):
|
275
|
+
"""保存模型注册表的便捷函数"""
|
276
|
+
ModelRegistry.save_to_file(filepath)
|
277
|
+
|
278
|
+
|
279
|
+
def load_model_registry(filepath: str = "./models/model_registry.json"):
|
280
|
+
"""加载模型注册表的便捷函数"""
|
281
|
+
ModelRegistry.load_from_file(filepath)
|
staran/models/target.py
ADDED
@@ -0,0 +1,321 @@
|
|
1
|
+
"""
|
2
|
+
目标变量定义模块
|
3
|
+
|
4
|
+
提供基于SQL的目标变量定义和生成功能
|
5
|
+
"""
|
6
|
+
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Dict, Any, List, Optional, Union
|
9
|
+
from dataclasses import dataclass, field
|
10
|
+
from datetime import datetime
|
11
|
+
import re
|
12
|
+
|
13
|
+
|
14
|
+
class TargetType(Enum):
    """Kinds of target variable."""
    BINARY_CLASSIFICATION = "binary_classification"  # binary classification
    MULTI_CLASSIFICATION = "multi_classification"  # multi-class classification
    REGRESSION = "regression"  # regression
    RANKING = "ranking"  # ranking
    CLUSTERING = "clustering"  # clustering
    SQL_BASED = "sql_based"  # custom target defined by a SQL query
|
22
|
+
|
23
|
+
|
24
|
+
class TargetEncoding(Enum):
    """Encoding applied to the target variable."""
    NONE = "none"  # no encoding
    LABEL_ENCODING = "label"  # label encoding
    ONE_HOT = "one_hot"  # one-hot encoding
    ORDINAL = "ordinal"  # ordinal encoding
    BINARY = "binary"  # binary encoding
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class TargetDefinition:
    """Definition of a model target variable.

    A target is described primarily by ``sql_query``: a SELECT/WITH query
    expected to expose a column named ``target_column``. ``{name}``-style
    placeholders in the query are filled in by :meth:`generate_sql`.

    ``target_type`` and ``encoding`` accept either enum members or their string
    values (e.g. ``"regression"``, ``"one_hot"``). Strings are converted to the
    matching enum member in ``__post_init__``. ``created_at`` likewise accepts
    an ISO-8601 string such as the one produced by :meth:`to_dict`.
    """

    # Basic information
    name: str  # target variable name
    target_type: TargetType  # target type
    description: str = ""  # free-text description

    # SQL definition (core feature)
    sql_query: str = ""  # SQL query that produces the target variable
    target_column: str = "target"  # name of the target column in the query output

    # Data information
    data_type: str = "float"  # data type of the target
    encoding: TargetEncoding = TargetEncoding.NONE  # target encoding scheme

    # Classification-specific
    class_labels: List[str] = field(default_factory=list)  # class labels
    class_weights: Dict[str, float] = field(default_factory=dict)  # per-class weights

    # Regression-specific
    min_value: Optional[float] = None  # inclusive lower bound; None = unbounded
    max_value: Optional[float] = None  # inclusive upper bound; None = unbounded
    normalization: bool = False  # whether to normalize the target

    # Time-related
    time_window: str = ""  # time window (e.g. "30_days", "3_months")
    prediction_horizon: str = ""  # prediction horizon

    # Bank-specific
    bank_code: str = "generic"  # bank code
    business_rules: Dict[str, Any] = field(default_factory=dict)  # business rules

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    created_by: str = "system"
    version: str = "1.0.0"
    tags: List[str] = field(default_factory=list)

    # Validation configuration
    validation_rules: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalize field types, fill in a default query and sanity-check it.

        Raises:
            ValueError: if ``target_type``/``encoding`` is not a valid enum value,
                or if the query fails :meth:`_validate_sql`.
        """
        # TargetType(x) / TargetEncoding(x) return x unchanged when x is already
        # a member, so this only converts plain string values. Without it, enum
        # comparisons below silently fail and to_dict() crashes on ``.value``.
        self.target_type = TargetType(self.target_type)
        self.encoding = TargetEncoding(self.encoding)
        if isinstance(self.created_at, str):
            self.created_at = datetime.fromisoformat(self.created_at)

        if not self.sql_query and self.target_type != TargetType.SQL_BASED:
            self.sql_query = self._generate_default_sql()

        # Basic SQL sanity check
        if self.sql_query:
            self._validate_sql()

    def _generate_default_sql(self) -> str:
        """Return a skeleton query for the target type, or "" if none exists.

        Only binary classification and regression have skeletons. Their
        ``condition`` / ``source_table`` / ``target_value`` identifiers look
        like placeholders to be replaced by the user. ``{data_dt}`` is left as a
        parameter placeholder for :meth:`generate_sql`.
        """
        if self.target_type == TargetType.BINARY_CLASSIFICATION:
            return f"""
        SELECT party_id,
               CASE WHEN condition THEN 1 ELSE 0 END as {self.target_column}
        FROM source_table
        WHERE data_dt = '{{data_dt}}'
        """
        elif self.target_type == TargetType.REGRESSION:
            return f"""
        SELECT party_id,
               target_value as {self.target_column}
        FROM source_table
        WHERE data_dt = '{{data_dt}}'
        """
        else:
            return ""

    def _validate_sql(self):
        """Perform a lightweight sanity check of ``sql_query``.

        Raises:
            ValueError: if the query does not start with SELECT or WITH.

        Emits a ``UserWarning`` when ``target_column`` does not appear anywhere
        in the query. This is a case-insensitive substring check, not a real
        parse. The parameter placeholders found are logged at DEBUG level.
        """
        import logging
        import warnings

        sql = self.sql_query.strip().upper()

        # Basic structure check (CTEs introduced by WITH are allowed)
        if not sql.startswith(("SELECT", "WITH")):
            raise ValueError("SQL查询必须以SELECT或WITH开始")

        # Library code should not print to stdout; surface this as a warning.
        if self.target_column.upper() not in sql:
            warnings.warn(f"警告: SQL中未找到目标列 '{self.target_column}'")

        placeholders = re.findall(r"\{(\w+)\}", self.sql_query)
        if placeholders:
            logging.getLogger(__name__).debug("发现参数占位符: %s", placeholders)

    def generate_sql(self, **params) -> str:
        """Return ``sql_query`` with ``{name}`` placeholders substituted.

        Args:
            **params: placeholder values. Each is inserted as ``str(value)``.

        Returns:
            The final SQL string. Placeholders without a matching parameter
            are left untouched.

        Note:
            Substitution happens in a single pass, so text inserted for one
            parameter is never re-scanned for other placeholders. Values are
            inserted verbatim, with no quoting or escaping, so never pass
            untrusted input here.
        """
        def _substitute(match):
            key = match.group(1)
            return str(params[key]) if key in params else match.group(0)

        return re.sub(r"\{(\w+)\}", _substitute, self.sql_query)

    def get_sample_sql(self, data_dt: str = "20250728") -> str:
        """Return the query with ``{data_dt}`` filled in, for inspection."""
        return self.generate_sql(data_dt=data_dt)

    def validate_target_values(self, values: List[Any]) -> bool:
        """Check that observed target values are consistent with this definition.

        - Binary classification: every value must equal 0 or 1.
        - Multi-class: every value must be one of ``class_labels``. This check
          is skipped when no labels are configured.
        - Regression: every value must lie within ``[min_value, max_value]``.
          Either bound may be None to leave that side unbounded.
        - Other target types are not checked and always return True.
        """
        if self.target_type == TargetType.BINARY_CLASSIFICATION:
            # {0, 1} also admits 0.0/1.0 (and False/True), which compare equal.
            return set(values).issubset({0, 1})

        if self.target_type == TargetType.MULTI_CLASSIFICATION:
            if self.class_labels:
                return set(values).issubset(self.class_labels)
            return True

        if self.target_type == TargetType.REGRESSION:
            lower, upper = self.min_value, self.max_value
            return all(
                (lower is None or v >= lower) and (upper is None or v <= upper)
                for v in values
            )

        return True

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the definition.

        Enums are serialized by value and ``created_at`` as an ISO-8601 string.
        """
        return {
            'name': self.name,
            'target_type': self.target_type.value,
            'description': self.description,
            'sql_query': self.sql_query,
            'target_column': self.target_column,
            'data_type': self.data_type,
            'encoding': self.encoding.value,
            'class_labels': self.class_labels,
            'class_weights': self.class_weights,
            'min_value': self.min_value,
            'max_value': self.max_value,
            'normalization': self.normalization,
            'time_window': self.time_window,
            'prediction_horizon': self.prediction_horizon,
            'bank_code': self.bank_code,
            'business_rules': self.business_rules,
            'created_at': self.created_at.isoformat(),
            'created_by': self.created_by,
            'version': self.version,
            'tags': self.tags,
            'validation_rules': self.validation_rules
        }
|
185
|
+
|
186
|
+
|
187
|
+
def create_target_definition(
    name: str,
    target_type: str,
    sql_query: str,
    target_column: str = "target",
    bank_code: str = "generic",
    **kwargs
) -> TargetDefinition:
    """Build a ``TargetDefinition`` from a target-type string.

    Args:
        name: name of the target variable.
        target_type: a ``TargetType`` value such as ``"regression"``. A
            ``TargetType`` member is accepted as well.
        sql_query: SQL that produces the target column.
        target_column: name of the target column in the query output.
        bank_code: bank code the target belongs to.
        **kwargs: any other ``TargetDefinition`` field.

    Returns:
        The new ``TargetDefinition``.

    Raises:
        ValueError: if ``target_type`` is not a valid ``TargetType`` value, or if
            ``TargetDefinition`` rejects the query.
    """
    kind = TargetType(target_type)
    return TargetDefinition(
        name=name,
        target_type=kind,
        sql_query=sql_query,
        target_column=target_column,
        bank_code=bank_code,
        **kwargs,
    )
|
217
|
+
|
218
|
+
|
219
|
+
# Predefined target-variable templates, keyed by preset name.
# Each entry holds keyword arguments for create_target_definition (minus ``name``).
TARGET_TEMPLATES = {
    # Binary: 1 when the customer's summed productamount_sum over
    # ['{start_dt}', '{end_dt}'] is positive, else 0. The population is the
    # rows at '{feature_dt}'.
    "aum_longtail_purchase": dict(
        target_type="binary_classification",
        description="AUM长尾客户未来购买预测",
        sql_query="""
        SELECT
            a.party_id,
            CASE
                WHEN b.purchase_amount > 0 THEN 1
                ELSE 0
            END as target
        FROM
            bi_hlwj_dfcw_f1_f4_wy a
        LEFT JOIN (
            SELECT party_id, SUM(productamount_sum) as purchase_amount
            FROM bi_hlwj_dfcw_f1_f4_wy
            WHERE data_dt BETWEEN '{start_dt}' AND '{end_dt}'
            GROUP BY party_id
        ) b ON a.party_id = b.party_id
        WHERE a.data_dt = '{feature_dt}'
        """,
        target_column="target",
        time_window="30_days",
        class_labels=["no_purchase", "purchase"],
    ),

    # Regression: asset_total_bal at '{target_dt}'.
    "customer_value_prediction": dict(
        target_type="regression",
        description="客户价值预测",
        sql_query="""
        SELECT
            party_id,
            asset_total_bal as target
        FROM
            bi_hlwj_zi_chan_avg_wy
        WHERE
            data_dt = '{target_dt}'
        """,
        target_column="target",
        data_type="float",
        normalization=True,
    ),

    # Multi-class: buckets asset_total_bal at '{data_dt}':
    # < 10000 -> low_risk, < 100000 -> medium_risk, otherwise high_risk.
    "risk_level_classification": dict(
        target_type="multi_classification",
        description="风险等级分类",
        sql_query="""
        SELECT
            party_id,
            CASE
                WHEN asset_total_bal < 10000 THEN 'low_risk'
                WHEN asset_total_bal < 100000 THEN 'medium_risk'
                ELSE 'high_risk'
            END as target
        FROM
            bi_hlwj_zi_chan_avg_wy
        WHERE
            data_dt = '{data_dt}'
        """,
        target_column="target",
        class_labels=["low_risk", "medium_risk", "high_risk"],
    ),
}
|
283
|
+
|
284
|
+
|
285
|
+
def create_preset_target(preset_name: str, **overrides) -> TargetDefinition:
    """Create a target definition from one of the ``TARGET_TEMPLATES`` presets.

    Args:
        preset_name: key into ``TARGET_TEMPLATES``.
        **overrides: fields that replace or extend the template's values. A
            ``name`` override renames the target; by default the preset name
            is used.

    Returns:
        A new ``TargetDefinition``.

    Raises:
        ValueError: if ``preset_name`` is not a known template.
    """
    import copy

    if preset_name not in TARGET_TEMPLATES:
        raise ValueError(f"未知的目标变量模板: {preset_name}")

    # Deep-copy so mutable template values (e.g. the ``class_labels`` list) are
    # not shared with the returned definition. A shallow copy would let callers
    # mutate the global template through it.
    template = copy.deepcopy(TARGET_TEMPLATES[preset_name])
    template.update(overrides)

    # Pop ``name`` so an override does not collide with the explicit keyword below
    # (previously a TypeError: multiple values for keyword argument 'name').
    name = template.pop("name", preset_name)

    return create_target_definition(name=name, **template)
|
306
|
+
|
307
|
+
|
308
|
+
def create_icbc_target(name: str, sql_query: str, target_type: str = "binary_classification", **kwargs) -> TargetDefinition:
    """Create a target definition preconfigured for ICBC (``bank_code="icbc"``).

    ICBC default business rules (90-day data retention, privacy compliance and
    audit flags) are always applied. A ``business_rules`` dict passed by the
    caller is merged on top of them, and caller values win on key conflicts.
    Previously, passing ``business_rules`` raised a TypeError.

    Args:
        name: name of the target variable.
        sql_query: SQL that produces the target column.
        target_type: a ``TargetType`` value; defaults to binary classification.
        **kwargs: any other ``TargetDefinition`` field except ``bank_code``.

    Returns:
        A new ``TargetDefinition``.
    """
    business_rules = {
        "data_retention_days": 90,
        "privacy_compliance": True,
        "audit_required": True,
    }
    # kwargs is a fresh dict per call, so popping from it is safe.
    business_rules.update(kwargs.pop("business_rules", None) or {})

    return create_target_definition(
        name=name,
        target_type=target_type,
        sql_query=sql_query,
        bank_code="icbc",
        business_rules=business_rules,
        **kwargs,
    )
|
@@ -12,13 +12,18 @@ staran/features/engines.py,sha256=kqdS2xjmCVi0Xz1Oc3WaTMIavgAriX8F7VvUgVcpfqo,10
|
|
12
12
|
staran/features/generator.py,sha256=CI1F_PshOvokQJelsqSaVp-SNQpMc-WVmjMQKzgdeLw,23114
|
13
13
|
staran/features/manager.py,sha256=2-3Hc3qthtyzwiuQy5QTz6RfhKK3szoylconzI3moc4,5201
|
14
14
|
staran/features/schema.py,sha256=FwOfpTcxq4K8zkO3MFNqKPQBp_e8qY-N6gazqm9_lAQ,6067
|
15
|
+
staran/models/__init__.py,sha256=NH4r6GTAz9MeUfq1jAyVkx-nC4bM78XvbWA9TuwMLik,2141
|
16
|
+
staran/models/bank_configs.py,sha256=wN3GA_8cb5wevDC-sWRcJ3lMuaHahZVjC85K_t2aQt0,8177
|
17
|
+
staran/models/config.py,sha256=fTbZtJq4-ZuCSSd1eW7TkIbEdDyZv2agHJCYnwOCJ_s,8886
|
18
|
+
staran/models/registry.py,sha256=Zeey4TtbHtJ40odyZQzOLijyZCmlMBRuniPk_znS2Q8,10223
|
19
|
+
staran/models/target.py,sha256=gKTTatxvOJjmE50qD6G6mhlYLuZL3Cvn3FLNbXl1eeU,10531
|
15
20
|
staran/schemas/__init__.py,sha256=2RkcWCaIkrOHd37zzRCla0-jNg4cPnc6BGmmW5Vha0Y,652
|
16
21
|
staran/schemas/document_generator.py,sha256=Mr7TjmKwspqxXnp9DhzZxsRx0l2Bo7MOI8mOxRtgwxU,13600
|
17
22
|
staran/schemas/aum/__init__.py,sha256=jVkmJdhHGHdGE4rJ605zsRU2zIQMEHWnlgW2ZQk8AdU,13082
|
18
23
|
staran/tools/__init__.py,sha256=KtudrYnxKD9HZEL4H-mrWlKrmsI3rYjJrLeC9YDTpG4,1054
|
19
24
|
staran/tools/date.py,sha256=-QyEMWVx6czMuOIwcV7kR3gBMRVOwb5qevo7GEFSJKE,10488
|
20
|
-
staran-0.
|
21
|
-
staran-0.
|
22
|
-
staran-0.
|
23
|
-
staran-0.
|
24
|
-
staran-0.
|
25
|
+
staran-0.5.0.dist-info/licenses/LICENSE,sha256=2EmsBIyDCono4iVXNpv5_px9qt2b7hfPq1WuyGVMNP4,1361
|
26
|
+
staran-0.5.0.dist-info/METADATA,sha256=1c6403YfhOFEsZV7Ng1pe4B_wlRdp8SZypmhIH_AaVo,18809
|
27
|
+
staran-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
28
|
+
staran-0.5.0.dist-info/top_level.txt,sha256=NOUZtXSh5oSIEjHrC0lQ9WmoKtD010Q00dghWyag-Zs,7
|
29
|
+
staran-0.5.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|