staran 0.6.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- staran/__init__.py +10 -0
- staran/tools/__init__.py +5 -5
- staran-1.0.0.dist-info/METADATA +301 -0
- staran-1.0.0.dist-info/RECORD +8 -0
- staran/banks/__init__.py +0 -30
- staran/banks/xinjiang_icbc/__init__.py +0 -90
- staran/engines/__init__.py +0 -65
- staran/engines/base.py +0 -255
- staran/engines/hive.py +0 -163
- staran/engines/spark.py +0 -252
- staran/engines/turing.py +0 -439
- staran/features/__init__.py +0 -59
- staran/features/engines.py +0 -284
- staran/features/generator.py +0 -603
- staran/features/manager.py +0 -155
- staran/features/schema.py +0 -193
- staran/models/__init__.py +0 -72
- staran/models/config.py +0 -271
- staran/models/daifa_models.py +0 -361
- staran/models/registry.py +0 -281
- staran/models/target.py +0 -321
- staran/schemas/__init__.py +0 -27
- staran/schemas/aum/__init__.py +0 -210
- staran/tools/document_generator.py +0 -350
- staran-0.6.1.dist-info/METADATA +0 -586
- staran-0.6.1.dist-info/RECORD +0 -28
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/WHEEL +0 -0
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {staran-0.6.1.dist-info → staran-1.0.0.dist-info}/top_level.txt +0 -0
staran/features/manager.py
DELETED
@@ -1,155 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
|
4
|
-
"""
|
5
|
-
特征管理器
|
6
|
-
负责特征工程的核心管理功能,基于新的引擎架构
|
7
|
-
"""
|
8
|
-
|
9
|
-
from typing import Optional, Dict, Any, List, Callable
|
10
|
-
from datetime import datetime
|
11
|
-
from ..engines import BaseEngine, create_engine, DatabaseType
|
12
|
-
|
13
|
-
|
14
|
-
class FeatureManager:
    """Feature manager: a thin facade over a database engine.

    The engine is built once in ``__init__`` via ``create_engine``; every
    other public method simply forwards its arguments to that engine.
    """

    def __init__(self, database_name: str, engine_type: str = "spark",
                 sql_executor: Optional[Callable] = None):
        """Initialise the manager and create the engine it delegates to.

        Args:
            database_name: Name of the database.
            engine_type: Engine type ('spark', 'hive', 'turing').
            sql_executor: SQL executor callable (optional, only used for
                non-turing engines).
        """
        self.database_name = database_name
        self.engine_type = engine_type

        # Build the database engine from the same three settings.
        engine_options = {
            "engine_type": engine_type,
            "database_name": database_name,
            "sql_executor": sql_executor,
        }
        self.engine = create_engine(**engine_options)

    # ------------------------------------------------------------------
    # Methods delegated to the engine
    # ------------------------------------------------------------------
    def execute_sql(self, sql: str, description: str = "") -> Any:
        """Execute a SQL statement through the engine."""
        engine = self.engine
        return engine.execute_sql(sql, description)

    def get_full_table_name(self, table_name: str) -> str:
        """Return the full table name (including the database name)."""
        engine = self.engine
        return engine.get_full_table_name(table_name)

    def generate_table_name(self, base_name: str, year: int, month: int,
                            suffix: str = "raw") -> str:
        """Generate a standardised table name.

        Format: {base_name}_{yyyy}_{MM}_{suffix}
        """
        engine = self.engine
        return engine.generate_table_name(base_name, year, month, suffix)

    def create_table(self, table_name: str, select_sql: str,
                     execute: bool = False, **kwargs) -> Dict[str, Any]:
        """Create a table from a SELECT statement via the engine."""
        engine = self.engine
        return engine.create_table(table_name, select_sql, execute, **kwargs)

    def drop_table(self, table_name: str, execute: bool = False) -> Dict[str, Any]:
        """Drop a table via the engine."""
        engine = self.engine
        return engine.drop_table(table_name, execute)

    def download_table_data(self, table_name: str, output_path: str,
                            **kwargs) -> Dict[str, Any]:
        """Download the data of a table via the engine."""
        engine = self.engine
        return engine.download_table_data(table_name, output_path, **kwargs)

    def download_query_result(self, sql: str, output_path: str,
                              **kwargs) -> Dict[str, Any]:
        """Download the result of a query via the engine."""
        engine = self.engine
        return engine.download_query_result(sql, output_path, **kwargs)

    def get_execution_history(self) -> List[Dict]:
        """Return the engine's SQL execution history."""
        engine = self.engine
        return engine.get_execution_history()

    def clear_history(self):
        """Clear the engine's execution history."""
        engine = self.engine
        engine.clear_history()

    def __str__(self):
        return f"FeatureManager(engine={self.engine})"
|
85
|
-
|
86
|
-
|
87
|
-
class FeatureTableManager:
    """Feature table manager.

    Creates and drops feature tables through a ``FeatureManager`` and keeps
    an in-memory list of the tables that were successfully created by this
    instance.
    """

    def __init__(self, feature_manager: "FeatureManager"):
        """Initialise the table manager.

        Args:
            feature_manager: The feature manager instance used to generate
                table names and to create/drop tables.
        """
        self.feature_manager = feature_manager
        # Names of tables created (and not yet dropped) through this object.
        self.created_tables = []

    def create_feature_table(self, base_name: str, year: int, month: int,
                             version: int, sql: str, execute: bool = False,
                             **kwargs) -> str:
        """Create a feature table.

        Args:
            base_name: Base table name.
            year: Year.
            month: Month.
            version: Version number. Accepted for interface compatibility;
                it is not used when building the table name here.
            sql: SQL used to create the table.
            execute: Whether to execute immediately.
            **kwargs: Extra arguments forwarded to ``create_table``.

        Returns:
            The generated table name.
        """
        table_name = self.feature_manager.generate_table_name(base_name, year, month)

        result = self.feature_manager.create_table(table_name, sql, execute, **kwargs)

        created = execute and result.get('status') == 'success'
        # Record each table once: a duplicate entry would survive a single
        # drop and make table_exists() report a dropped table as present.
        if created and table_name not in self.created_tables:
            self.created_tables.append(table_name)

        return table_name

    def drop_feature_table(self, table_name: str, execute: bool = False) -> str:
        """Drop a feature table.

        Args:
            table_name: Table name.
            execute: Whether to execute immediately.

        Returns:
            The value of the ``'sql'`` key of the drop result, or ``''``
            when that key is missing.
        """
        result = self.feature_manager.drop_table(table_name, execute)

        dropped = execute and result.get('status') == 'success'
        if dropped and table_name in self.created_tables:
            self.created_tables.remove(table_name)

        return result.get('sql', '')

    def get_created_tables(self) -> List[str]:
        """Return a copy of the list of created tables."""
        return self.created_tables.copy()

    def table_exists(self, table_name: str) -> bool:
        """Check whether a table is tracked as created.

        This only consults the in-memory list; it does not query the
        database.
        """
        return table_name in self.created_tables
|
staran/features/schema.py
DELETED
@@ -1,193 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
|
4
|
-
"""
|
5
|
-
表结构定义模块
|
6
|
-
定义数据库表的字段、类型和分析特性
|
7
|
-
"""
|
8
|
-
|
9
|
-
from enum import Enum
|
10
|
-
from typing import Dict, List, Optional, Union
|
11
|
-
from dataclasses import dataclass
|
12
|
-
|
13
|
-
|
14
|
-
class FieldType(Enum):
    """Enumeration of supported column types."""
    STRING = "string"
    INTEGER = "int"
    BIGINT = "bigint"
    DECIMAL = "decimal"
    DOUBLE = "double"
    FLOAT = "float"
    DATE = "date"
    TIMESTAMP = "timestamp"
    BOOLEAN = "boolean"


@dataclass
class Field:
    """Definition of a single table column.

    ``aggregatable`` may be left unspecified (``None``); it is then resolved
    in ``__post_init__`` to ``True`` for numeric types and ``False`` for all
    other types. An explicitly passed value is always kept as given.
    """
    name: str
    field_type: FieldType
    is_primary_key: bool = False
    is_date_field: bool = False
    # None means "not specified": the default is derived from field_type.
    aggregatable: Optional[bool] = None
    nullable: bool = True
    comment: str = ""

    def __post_init__(self):
        """Resolve the default for ``aggregatable`` when it was not given."""
        if self.aggregatable is None:
            # Numeric types are aggregatable by default.
            numeric_types = (
                FieldType.INTEGER,
                FieldType.BIGINT,
                FieldType.DECIMAL,
                FieldType.DOUBLE,
                FieldType.FLOAT,
            )
            self.aggregatable = self.field_type in numeric_types

    def set_aggregatable(self, aggregatable: bool):
        """Set whether the column is aggregatable.

        Returns:
            ``self``, so calls can be chained.
        """
        self.aggregatable = aggregatable
        # Marker kept for compatibility with code that inspects it.
        self._aggregatable_set = True
        return self
|
51
|
-
|
52
|
-
|
53
|
-
class TableSchema:
    """Table structure definition: columns, primary key and date column."""

    def __init__(self, table_name: str, comment: str = ""):
        """Initialise an empty table structure.

        Args:
            table_name: Table name.
            comment: Table comment.
        """
        self.table_name = table_name
        self.comment = comment
        self.fields: Dict[str, Field] = {}
        self.primary_key: Optional[str] = None
        self.date_field: Optional[str] = None
        self.is_monthly_unique: bool = False

    @staticmethod
    def _resolve_type(field_type):
        """Convert a type name to a FieldType member; pass members through."""
        if isinstance(field_type, str):
            return FieldType(field_type.lower())
        return field_type

    def add_field(self, name: str, field_type: Union[str, FieldType],
                  aggregatable: bool = None, nullable: bool = True,
                  comment: str = "") -> 'TableSchema':
        """Add a column.

        Args:
            name: Column name.
            field_type: Column type (member or its string value).
            aggregatable: Whether the column is aggregatable (decided
                automatically by ``Field`` when None).
            nullable: Whether the column is nullable.
            comment: Column comment.

        Returns:
            self, to support chained calls.
        """
        column = Field(
            name=name,
            field_type=self._resolve_type(field_type),
            nullable=nullable,
            comment=comment,
        )
        if aggregatable is not None:
            column.set_aggregatable(aggregatable)

        self.fields[name] = column
        return self

    def add_primary_key(self, name: str, field_type: Union[str, FieldType],
                        comment: str = "主键") -> 'TableSchema':
        """Add the primary-key column and record it as the primary key."""
        column = Field(
            name=name,
            field_type=self._resolve_type(field_type),
            is_primary_key=True,
            nullable=False,
            comment=comment,
        )
        column.set_aggregatable(False)

        self.fields[name] = column
        self.primary_key = name
        return self

    def add_date_field(self, name: str, field_type: Union[str, FieldType] = FieldType.DATE,
                       comment: str = "日期字段") -> 'TableSchema':
        """Add the date column and record it as the date field."""
        column = Field(
            name=name,
            field_type=self._resolve_type(field_type),
            is_date_field=True,
            nullable=False,
            comment=comment,
        )
        column.set_aggregatable(False)

        self.fields[name] = column
        self.date_field = name
        return self

    def set_monthly_unique(self, is_unique: bool = True) -> 'TableSchema':
        """Set whether the data is unique per person per month."""
        self.is_monthly_unique = is_unique
        return self

    def get_aggregatable_fields(self) -> List[Field]:
        """Return the list of aggregatable columns."""
        selected = []
        for column in self.fields.values():
            if column.aggregatable:
                selected.append(column)
        return selected

    def get_non_aggregatable_fields(self) -> List[Field]:
        """Return the non-aggregatable columns (used for raw copy).

        The primary-key and date columns are excluded.
        """
        return [
            column for column in self.fields.values()
            if not (column.aggregatable or column.is_primary_key or column.is_date_field)
        ]

    def validate(self) -> bool:
        """Validate the table structure.

        Returns:
            True when a primary key and a date field are set and both
            exist in ``fields``.

        Raises:
            ValueError: If either is unset or missing from ``fields``.
        """
        if not self.primary_key:
            raise ValueError("表必须定义主键")
        if not self.date_field:
            raise ValueError("表必须定义日期字段")

        # Same order as before: primary key is checked ahead of the date field.
        required = (("主键字段", self.primary_key), ("日期字段", self.date_field))
        for label, column_name in required:
            if column_name not in self.fields:
                raise ValueError(f"{label} {column_name} 不存在")

        return True

    def __str__(self) -> str:
        """Return a multi-line, human-readable description of the table."""
        out = [f"Table: {self.table_name}"]
        if self.comment:
            out.append(f"Comment: {self.comment}")

        out.extend([
            f"Primary Key: {self.primary_key}",
            f"Date Field: {self.date_field}",
            f"Monthly Unique: {self.is_monthly_unique}",
            "Fields:",
        ])

        for column in self.fields.values():
            flags = (
                (column.is_primary_key, "[PK]"),
                (column.is_date_field, "[DATE]"),
                (column.aggregatable, "[AGG]"),
            )
            markers = "".join(tag for enabled, tag in flags if enabled)
            out.append(f" {column.name}: {column.field_type.value} {markers}")

        return "\n".join(out)
|
staran/models/__init__.py
DELETED
@@ -1,72 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
staran.models v0.6.0 - 新疆工行代发长尾客户模型管理
|
3
|
-
|
4
|
-
专门针对新疆工行代发长尾客户的两个核心模型:
|
5
|
-
1. 代发长尾客户提升3k预测模型 (daifa_longtail_upgrade_3k)
|
6
|
-
2. 代发长尾客户防流失1.5k预测模型 (daifa_longtail_churn_1_5k)
|
7
|
-
|
8
|
-
主要功能:
|
9
|
-
- 模型配置管理
|
10
|
-
- SQL驱动的目标变量定义
|
11
|
-
- 模型注册和版本控制
|
12
|
-
- 新疆工行特定配置
|
13
|
-
"""
|
14
|
-
|
15
|
-
from .config import ModelConfig, create_model_config
|
16
|
-
from .target import TargetDefinition, create_target_definition
|
17
|
-
from .registry import ModelRegistry, register_model, save_model_registry
|
18
|
-
from .daifa_models import (
|
19
|
-
create_daifa_longtail_upgrade_model,
|
20
|
-
create_daifa_longtail_churn_model,
|
21
|
-
get_available_daifa_models,
|
22
|
-
create_both_daifa_models
|
23
|
-
)
|
24
|
-
|
25
|
-
# 便捷函数
|
26
|
-
def create_xinjiang_icbc_models(output_dir: str = "./xinjiang_models") -> dict:
    """Create the two payroll long-tail customer models for Xinjiang ICBC.

    Convenience alias that forwards ``output_dir`` to
    ``create_both_daifa_models`` and returns its result unchanged.
    """
    models = create_both_daifa_models(output_dir)
    return models
|
29
|
-
|
30
|
-
def list_available_models() -> list:
    """List all available payroll long-tail customer models.

    Convenience alias that returns the result of
    ``get_available_daifa_models`` unchanged.
    """
    available = get_available_daifa_models()
    return available
|
33
|
-
|
34
|
-
def get_model_summary() -> dict:
    """Return a static overview of the models in this package.

    Returns:
        A dict with the keys ``version``, ``bank``, ``business_domain`` and
        ``models``; ``models`` holds one dict per model with its name,
        description, target amount and model type.
    """
    # (name, description, target_amount) for each model.
    model_specs = (
        ("daifa_longtail_upgrade_3k", "预测下个月代发长尾客户资产提升3k的概率", 3000),
        ("daifa_longtail_churn_1_5k", "预测下个月代发长尾客户流失1.5k资产的风险", 1500),
    )

    models = []
    for model_name, model_description, amount in model_specs:
        models.append({
            "name": model_name,
            "description": model_description,
            "target_amount": amount,
            "model_type": "binary_classification",
        })

    summary = {
        "version": "0.6.0",
        "bank": "新疆工行",
        "business_domain": "代发长尾客户",
    }
    summary["models"] = models
    return summary
|
55
|
-
|
56
|
-
__all__ = [
|
57
|
-
# 核心组件
|
58
|
-
'ModelConfig', 'TargetDefinition', 'ModelRegistry',
|
59
|
-
|
60
|
-
# 创建函数
|
61
|
-
'create_model_config', 'create_target_definition', 'register_model',
|
62
|
-
|
63
|
-
# 代发长尾模型
|
64
|
-
'create_daifa_longtail_upgrade_model', 'create_daifa_longtail_churn_model',
|
65
|
-
'create_both_daifa_models', 'get_available_daifa_models',
|
66
|
-
|
67
|
-
# 便捷函数
|
68
|
-
'create_xinjiang_icbc_models', 'list_available_models', 'get_model_summary',
|
69
|
-
'save_model_registry'
|
70
|
-
]
|
71
|
-
|
72
|
-
__version__ = "0.6.0"
|
staran/models/config.py
DELETED
@@ -1,271 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
模型配置管理模块
|
3
|
-
|
4
|
-
定义模型的核心配置信息,包括模型类型、参数、特征配置等
|
5
|
-
"""
|
6
|
-
|
7
|
-
from enum import Enum
|
8
|
-
from typing import Dict, Any, List, Optional
|
9
|
-
from dataclasses import dataclass, field
|
10
|
-
from datetime import datetime
|
11
|
-
|
12
|
-
|
13
|
-
class ModelType(Enum):
    """Enumeration of model types.

    Each member's value is the lowercase string used in serialised
    configurations.
    """

    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    CLUSTERING = "clustering"
    TIME_SERIES = "time_series"
    ANOMALY_DETECTION = "anomaly_detection"
    RECOMMENDATION = "recommendation"
|
21
|
-
|
22
|
-
|
23
|
-
class ModelAlgorithm(Enum):
    """Enumeration of model algorithms, grouped by task family."""

    # --- Classification algorithms ---
    LOGISTIC_REGRESSION = "logistic_regression"
    RANDOM_FOREST = "random_forest"
    GRADIENT_BOOSTING = "gradient_boosting"
    SVM = "svm"
    NEURAL_NETWORK = "neural_network"

    # --- Regression algorithms ---
    LINEAR_REGRESSION = "linear_regression"
    RIDGE_REGRESSION = "ridge_regression"
    LASSO_REGRESSION = "lasso_regression"

    # --- Clustering algorithms ---
    KMEANS = "kmeans"
    DBSCAN = "dbscan"
    HIERARCHICAL = "hierarchical"

    # --- Time series ---
    ARIMA = "arima"
    LSTM = "lstm"
    PROPHET = "prophet"
|
46
|
-
|
47
|
-
|
48
|
-
@dataclass
class FeatureConfig:
    """Feature configuration of a model."""

    # Name of the schema to use (e.g. 'aum').
    schema_name: str
    # Table types to use (e.g. ['behavior', 'asset_avg']).
    table_types: List[str]
    # Whether feature selection is enabled.
    feature_selection: bool = True
    # Whether feature engineering is enabled.
    feature_engineering: bool = True
    # Whether feature scaling is enabled.
    scaling: bool = True
    # Encoding configuration; a fresh dict per instance.
    encoding: Dict[str, str] = field(default_factory=dict)
|
57
|
-
|
58
|
-
|
59
|
-
@dataclass
class ModelConfig:
    """Model configuration.

    Holds the identity of a model (name, type, algorithm, version), its
    feature/training/evaluation settings and descriptive metadata, and can
    be converted to and from a plain dict with ``to_dict`` / ``from_dict``.
    """
    # Basic information
    name: str                                    # model name
    model_type: ModelType                        # model type
    algorithm: ModelAlgorithm                    # algorithm in use
    version: str = "1.0.0"                       # model version

    # Feature configuration; replaced by a generic one in __post_init__
    # when left as None.
    feature_config: Optional[FeatureConfig] = None

    # Model parameters
    hyperparameters: Dict[str, Any] = field(default_factory=dict)

    # Training configuration
    training_config: Dict[str, Any] = field(default_factory=lambda: {
        'test_size': 0.2,
        'random_state': 42,
        'cross_validation': True,
        'cv_folds': 5
    })

    # Evaluation configuration; filled from the model type when empty.
    evaluation_metrics: List[str] = field(default_factory=list)

    # Bank-specific configuration
    bank_code: str = "generic"                   # bank code
    business_domain: str = "generic"             # business domain

    # Metadata
    description: str = ""                        # model description
    created_at: datetime = field(default_factory=datetime.now)
    created_by: str = "system"                   # creator
    tags: List[str] = field(default_factory=list)

    # Deployment configuration
    deployment_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in the defaults that depend on other fields."""
        if self.feature_config is None:
            self.feature_config = FeatureConfig(
                schema_name="generic",
                table_types=["base"]
            )

        # Pick default evaluation metrics from the model type.
        if not self.evaluation_metrics:
            self.evaluation_metrics = self._get_default_metrics()

    def _get_default_metrics(self) -> List[str]:
        """Return the default evaluation metrics for ``self.model_type``."""
        if self.model_type == ModelType.CLASSIFICATION:
            return ['accuracy', 'precision', 'recall', 'f1_score', 'auc']
        elif self.model_type == ModelType.REGRESSION:
            return ['mae', 'mse', 'rmse', 'r2_score']
        elif self.model_type == ModelType.CLUSTERING:
            return ['silhouette_score', 'calinski_harabasz_score']
        else:
            return ['custom_metric']

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a plain dict.

        Enum members are written as their values and ``created_at`` as an
        ISO-8601 string.
        """
        return {
            'name': self.name,
            'model_type': self.model_type.value,
            'algorithm': self.algorithm.value,
            'version': self.version,
            'feature_config': {
                'schema_name': self.feature_config.schema_name,
                'table_types': self.feature_config.table_types,
                'feature_selection': self.feature_config.feature_selection,
                'feature_engineering': self.feature_config.feature_engineering,
                'scaling': self.feature_config.scaling,
                'encoding': self.feature_config.encoding
            },
            'hyperparameters': self.hyperparameters,
            'training_config': self.training_config,
            'evaluation_metrics': self.evaluation_metrics,
            'bank_code': self.bank_code,
            'business_domain': self.business_domain,
            'description': self.description,
            'created_at': self.created_at.isoformat(),
            'created_by': self.created_by,
            'tags': self.tags,
            'deployment_config': self.deployment_config
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
        """Create a ModelConfig instance from a dict.

        Args:
            data: Dict in the layout produced by ``to_dict``. ``name``,
                ``model_type`` and ``algorithm`` are required; every other
                key is optional and falls back to the dataclass default.

        Returns:
            The reconstructed ModelConfig.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``model_type`` or ``algorithm`` is not a valid
                enum value.
        """
        # Tolerate a missing or null feature_config entry.
        feature_config_data = data.get('feature_config') or {}
        feature_config = FeatureConfig(
            schema_name=feature_config_data.get('schema_name', 'generic'),
            table_types=feature_config_data.get('table_types', ['base']),
            feature_selection=feature_config_data.get('feature_selection', True),
            feature_engineering=feature_config_data.get('feature_engineering', True),
            scaling=feature_config_data.get('scaling', True),
            encoding=feature_config_data.get('encoding', {})
        )

        # Forward only the keys that are present so absent ones keep the
        # dataclass defaults (in particular the default training_config).
        optional_keys = (
            'version', 'hyperparameters', 'training_config',
            'evaluation_metrics', 'bank_code', 'business_domain',
            'description', 'created_by', 'tags', 'deployment_config',
        )
        extra = {key: data[key] for key in optional_keys if key in data}

        # Restore the creation time written by to_dict so a round trip
        # does not reset it to "now".
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            try:
                created_at = datetime.fromisoformat(created_at)
            except ValueError:
                # Unparseable timestamp: fall back to the default.
                created_at = None
        if isinstance(created_at, datetime):
            extra['created_at'] = created_at

        return cls(
            name=data['name'],
            model_type=ModelType(data['model_type']),
            algorithm=ModelAlgorithm(data['algorithm']),
            feature_config=feature_config,
            **extra
        )
|
177
|
-
|
178
|
-
|
179
|
-
def create_model_config(
    name: str,
    model_type: str,
    algorithm: str,
    schema_name: str = "generic",
    table_types: List[str] = None,
    bank_code: str = "generic",
    **kwargs
) -> ModelConfig:
    """Convenience function for creating a model configuration.

    Args:
        name: Model name.
        model_type: Model type (a ``ModelType`` value).
        algorithm: Algorithm name (a ``ModelAlgorithm`` value).
        schema_name: Name of the schema to use.
        table_types: Table types to use; ``["base"]`` when None.
        bank_code: Bank code.
        **kwargs: Other configuration arguments passed to ``ModelConfig``.

    Returns:
        A ModelConfig instance.
    """
    resolved_tables = ["base"] if table_types is None else table_types

    features = FeatureConfig(
        schema_name=schema_name,
        table_types=resolved_tables,
    )

    # Convert the string identifiers to enum members, type first.
    resolved_type = ModelType(model_type)
    resolved_algorithm = ModelAlgorithm(algorithm)

    return ModelConfig(
        name=name,
        model_type=resolved_type,
        algorithm=resolved_algorithm,
        feature_config=features,
        bank_code=bank_code,
        **kwargs
    )
|
219
|
-
|
220
|
-
|
221
|
-
# 预定义的模型配置模板
|
222
|
-
# Predefined model configuration templates, keyed by preset name. Each entry
# holds keyword arguments for ``create_model_config`` (everything but ``name``).
PRESET_CONFIGS = {
    # Random-forest classifier over four AUM table types.
    "aum_longtail_classification": {
        "model_type": "classification",
        "algorithm": "random_forest",
        "schema_name": "aum",
        "table_types": [
            "behavior",
            "asset_avg",
            "asset_config",
            "monthly_stat",
        ],
        "hyperparameters": {
            "n_estimators": 100,
            "max_depth": 10,
            "random_state": 42,
        },
        "description": "AUM长尾客户分类模型",
    },

    # Gradient-boosting regressor over two AUM table types.
    "customer_value_regression": {
        "model_type": "regression",
        "algorithm": "gradient_boosting",
        "schema_name": "aum",
        "table_types": [
            "behavior",
            "asset_avg",
        ],
        "hyperparameters": {
            "n_estimators": 150,
            "learning_rate": 0.1,
            "max_depth": 8,
        },
        "description": "客户价值预测回归模型",
    },
}
|
249
|
-
|
250
|
-
|
251
|
-
def create_preset_config(preset_name: str, **overrides) -> ModelConfig:
    """Create a model configuration from a preset template.

    Args:
        preset_name: Name of the preset template (a key of
            ``PRESET_CONFIGS``).
        **overrides: Configuration arguments that replace the template's
            values. ``name`` may be overridden too; it defaults to
            ``preset_name``.

    Returns:
        A ModelConfig instance.

    Raises:
        ValueError: If ``preset_name`` is not a known preset.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    if preset_name not in PRESET_CONFIGS:
        raise ValueError(f"未知的预设配置: {preset_name}")

    # Deep copy: a shallow copy would share the nested hyperparameters and
    # table_types objects with the template, so editing the returned
    # configuration would silently change the preset for later callers.
    config = copy.deepcopy(PRESET_CONFIGS[preset_name])
    config.update(overrides)

    # Pull "name" out so an override does not collide with the explicit
    # keyword argument below.
    name = config.pop("name", preset_name)

    return create_model_config(
        name=name,
        **config
    )
|