staran 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,321 @@
1
+ """
2
+ 目标变量定义模块
3
+
4
+ 提供基于SQL的目标变量定义和生成功能
5
+ """
6
+
7
+ from enum import Enum
8
+ from typing import Dict, Any, List, Optional, Union
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ import re
12
+
13
+
14
+ class TargetType(Enum):
15
+ """目标变量类型"""
16
+ BINARY_CLASSIFICATION = "binary_classification" # 二分类
17
+ MULTI_CLASSIFICATION = "multi_classification" # 多分类
18
+ REGRESSION = "regression" # 回归
19
+ RANKING = "ranking" # 排序
20
+ CLUSTERING = "clustering" # 聚类
21
+ SQL_BASED = "sql_based" # 基于SQL的自定义目标
22
+
23
+
24
+ class TargetEncoding(Enum):
25
+ """目标变量编码方式"""
26
+ NONE = "none" # 不编码
27
+ LABEL_ENCODING = "label" # 标签编码
28
+ ONE_HOT = "one_hot" # 独热编码
29
+ ORDINAL = "ordinal" # 序数编码
30
+ BINARY = "binary" # 二进制编码
31
+
32
+
33
+ @dataclass
34
+ class TargetDefinition:
35
+ """目标变量定义类"""
36
+ # 基本信息
37
+ name: str # 目标变量名称
38
+ target_type: TargetType # 目标类型
39
+ description: str = "" # 描述
40
+
41
+ # SQL定义 (核心功能)
42
+ sql_query: str = "" # 生成目标变量的SQL查询
43
+ target_column: str = "target" # 目标列名
44
+
45
+ # 数据信息
46
+ data_type: str = "float" # 数据类型
47
+ encoding: TargetEncoding = TargetEncoding.NONE # 编码方式
48
+
49
+ # 分类相关
50
+ class_labels: List[str] = field(default_factory=list) # 类别标签
51
+ class_weights: Dict[str, float] = field(default_factory=dict) # 类别权重
52
+
53
+ # 回归相关
54
+ min_value: Optional[float] = None # 最小值
55
+ max_value: Optional[float] = None # 最大值
56
+ normalization: bool = False # 是否标准化
57
+
58
+ # 时间相关
59
+ time_window: str = "" # 时间窗口 (如 "30_days", "3_months")
60
+ prediction_horizon: str = "" # 预测时间范围
61
+
62
+ # 银行特定
63
+ bank_code: str = "generic" # 银行代码
64
+ business_rules: Dict[str, Any] = field(default_factory=dict) # 业务规则
65
+
66
+ # 元数据
67
+ created_at: datetime = field(default_factory=datetime.now)
68
+ created_by: str = "system"
69
+ version: str = "1.0.0"
70
+ tags: List[str] = field(default_factory=list)
71
+
72
+ # 验证配置
73
+ validation_rules: Dict[str, Any] = field(default_factory=dict)
74
+
75
+ def __post_init__(self):
76
+ """初始化后处理"""
77
+ if not self.sql_query and self.target_type != TargetType.SQL_BASED:
78
+ self.sql_query = self._generate_default_sql()
79
+
80
+ # 验证SQL语法
81
+ if self.sql_query:
82
+ self._validate_sql()
83
+
84
+ def _generate_default_sql(self) -> str:
85
+ """根据目标类型生成默认SQL"""
86
+ if self.target_type == TargetType.BINARY_CLASSIFICATION:
87
+ return f"""
88
+ SELECT party_id,
89
+ CASE WHEN condition THEN 1 ELSE 0 END as {self.target_column}
90
+ FROM source_table
91
+ WHERE data_dt = '{{data_dt}}'
92
+ """
93
+ elif self.target_type == TargetType.REGRESSION:
94
+ return f"""
95
+ SELECT party_id,
96
+ target_value as {self.target_column}
97
+ FROM source_table
98
+ WHERE data_dt = '{{data_dt}}'
99
+ """
100
+ else:
101
+ return ""
102
+
103
+ def _validate_sql(self):
104
+ """验证SQL语法基本正确性"""
105
+ sql = self.sql_query.strip().upper()
106
+
107
+ # 基本SQL结构检查 (支持WITH语句)
108
+ if not (sql.startswith('SELECT') or sql.startswith('WITH')):
109
+ raise ValueError("SQL查询必须以SELECT或WITH开始")
110
+
111
+ # 检查是否包含目标列
112
+ if self.target_column.upper() not in sql:
113
+ print(f"警告: SQL中未找到目标列 '{self.target_column}'")
114
+
115
+ # 检查参数占位符
116
+ placeholders = re.findall(r'\{(\w+)\}', self.sql_query)
117
+ if placeholders:
118
+ print(f"发现参数占位符: {placeholders}")
119
+
120
+ def generate_sql(self, **params) -> str:
121
+ """
122
+ 生成最终的SQL查询,替换参数占位符
123
+
124
+ Args:
125
+ **params: SQL参数字典
126
+
127
+ Returns:
128
+ 最终的SQL查询字符串
129
+ """
130
+ sql = self.sql_query
131
+
132
+ # 替换参数占位符
133
+ for key, value in params.items():
134
+ placeholder = f"{{{key}}}"
135
+ sql = sql.replace(placeholder, str(value))
136
+
137
+ return sql
138
+
139
+ def get_sample_sql(self, data_dt: str = "20250728") -> str:
140
+ """获取示例SQL"""
141
+ return self.generate_sql(data_dt=data_dt)
142
+
143
+ def validate_target_values(self, values: List[Any]) -> bool:
144
+ """验证目标值是否符合定义"""
145
+ if self.target_type == TargetType.BINARY_CLASSIFICATION:
146
+ unique_values = set(values)
147
+ return unique_values.issubset({0, 1, 0.0, 1.0})
148
+
149
+ elif self.target_type == TargetType.MULTI_CLASSIFICATION:
150
+ if self.class_labels:
151
+ unique_values = set(values)
152
+ return unique_values.issubset(set(self.class_labels))
153
+
154
+ elif self.target_type == TargetType.REGRESSION:
155
+ if self.min_value is not None and self.max_value is not None:
156
+ return all(self.min_value <= v <= self.max_value for v in values)
157
+
158
+ return True
159
+
160
+ def to_dict(self) -> Dict[str, Any]:
161
+ """转换为字典格式"""
162
+ return {
163
+ 'name': self.name,
164
+ 'target_type': self.target_type.value,
165
+ 'description': self.description,
166
+ 'sql_query': self.sql_query,
167
+ 'target_column': self.target_column,
168
+ 'data_type': self.data_type,
169
+ 'encoding': self.encoding.value,
170
+ 'class_labels': self.class_labels,
171
+ 'class_weights': self.class_weights,
172
+ 'min_value': self.min_value,
173
+ 'max_value': self.max_value,
174
+ 'normalization': self.normalization,
175
+ 'time_window': self.time_window,
176
+ 'prediction_horizon': self.prediction_horizon,
177
+ 'bank_code': self.bank_code,
178
+ 'business_rules': self.business_rules,
179
+ 'created_at': self.created_at.isoformat(),
180
+ 'created_by': self.created_by,
181
+ 'version': self.version,
182
+ 'tags': self.tags,
183
+ 'validation_rules': self.validation_rules
184
+ }
185
+
186
+
187
+ def create_target_definition(
188
+ name: str,
189
+ target_type: str,
190
+ sql_query: str,
191
+ target_column: str = "target",
192
+ bank_code: str = "generic",
193
+ **kwargs
194
+ ) -> TargetDefinition:
195
+ """
196
+ 创建目标变量定义的便捷函数
197
+
198
+ Args:
199
+ name: 目标变量名称
200
+ target_type: 目标类型
201
+ sql_query: SQL查询
202
+ target_column: 目标列名
203
+ bank_code: 银行代码
204
+ **kwargs: 其他配置参数
205
+
206
+ Returns:
207
+ TargetDefinition实例
208
+ """
209
+ return TargetDefinition(
210
+ name=name,
211
+ target_type=TargetType(target_type),
212
+ sql_query=sql_query,
213
+ target_column=target_column,
214
+ bank_code=bank_code,
215
+ **kwargs
216
+ )
217
+
218
+
219
+ # 预定义的目标变量模板
220
+ TARGET_TEMPLATES = {
221
+ "aum_longtail_purchase": {
222
+ "target_type": "binary_classification",
223
+ "description": "AUM长尾客户未来购买预测",
224
+ "sql_query": """
225
+ SELECT
226
+ a.party_id,
227
+ CASE
228
+ WHEN b.purchase_amount > 0 THEN 1
229
+ ELSE 0
230
+ END as target
231
+ FROM
232
+ bi_hlwj_dfcw_f1_f4_wy a
233
+ LEFT JOIN (
234
+ SELECT party_id, SUM(productamount_sum) as purchase_amount
235
+ FROM bi_hlwj_dfcw_f1_f4_wy
236
+ WHERE data_dt BETWEEN '{start_dt}' AND '{end_dt}'
237
+ GROUP BY party_id
238
+ ) b ON a.party_id = b.party_id
239
+ WHERE a.data_dt = '{feature_dt}'
240
+ """,
241
+ "target_column": "target",
242
+ "time_window": "30_days",
243
+ "class_labels": ["no_purchase", "purchase"]
244
+ },
245
+
246
+ "customer_value_prediction": {
247
+ "target_type": "regression",
248
+ "description": "客户价值预测",
249
+ "sql_query": """
250
+ SELECT
251
+ party_id,
252
+ asset_total_bal as target
253
+ FROM
254
+ bi_hlwj_zi_chan_avg_wy
255
+ WHERE
256
+ data_dt = '{target_dt}'
257
+ """,
258
+ "target_column": "target",
259
+ "data_type": "float",
260
+ "normalization": True
261
+ },
262
+
263
+ "risk_level_classification": {
264
+ "target_type": "multi_classification",
265
+ "description": "风险等级分类",
266
+ "sql_query": """
267
+ SELECT
268
+ party_id,
269
+ CASE
270
+ WHEN asset_total_bal < 10000 THEN 'low_risk'
271
+ WHEN asset_total_bal < 100000 THEN 'medium_risk'
272
+ ELSE 'high_risk'
273
+ END as target
274
+ FROM
275
+ bi_hlwj_zi_chan_avg_wy
276
+ WHERE
277
+ data_dt = '{data_dt}'
278
+ """,
279
+ "target_column": "target",
280
+ "class_labels": ["low_risk", "medium_risk", "high_risk"]
281
+ }
282
+ }
283
+
284
+
285
+ def create_preset_target(preset_name: str, **overrides) -> TargetDefinition:
286
+ """
287
+ 基于预设模板创建目标变量定义
288
+
289
+ Args:
290
+ preset_name: 预设模板名称
291
+ **overrides: 覆盖的配置参数
292
+
293
+ Returns:
294
+ TargetDefinition实例
295
+ """
296
+ if preset_name not in TARGET_TEMPLATES:
297
+ raise ValueError(f"未知的目标变量模板: {preset_name}")
298
+
299
+ template = TARGET_TEMPLATES[preset_name].copy()
300
+ template.update(overrides)
301
+
302
+ return create_target_definition(
303
+ name=preset_name,
304
+ **template
305
+ )
306
+
307
+
308
+ def create_icbc_target(name: str, sql_query: str, target_type: str = "binary_classification", **kwargs) -> TargetDefinition:
309
+ """创建工商银行专用目标变量定义"""
310
+ return create_target_definition(
311
+ name=name,
312
+ target_type=target_type,
313
+ sql_query=sql_query,
314
+ bank_code="icbc",
315
+ business_rules={
316
+ "data_retention_days": 90,
317
+ "privacy_compliance": True,
318
+ "audit_required": True
319
+ },
320
+ **kwargs
321
+ )
@@ -1,28 +1,27 @@
1
1
  """
2
- Staran Schemas模块 - 数据表结构定义与文档生成
2
+ Staran Schemas模块 - 新疆工行代发长尾客户表结构定义
3
3
 
4
- 提供标准化的表结构定义、字段管理和文档生成功能。
5
- 支持根据表结构生成Markdown和PDF文档供业务方使用。
4
+ 提供新疆工行代发长尾客户的标准化表结构定义和字段管理功能。
6
5
 
7
6
  主要功能:
8
- - 表结构标准化定义
7
+ - 代发长尾客户表结构定义
9
8
  - 业务字段含义管理
10
- - 文档自动生成 (MD/PDF)
11
- - 多业务领域支持
9
+ - 新疆工行专用配置
10
+ - 表结构文档生成
12
11
  """
13
12
 
14
- from .document_generator import SchemaDocumentGenerator
13
+ from ..tools.document_generator import SchemaDocumentGenerator
15
14
  from .aum import *
16
15
 
17
16
  __all__ = [
18
17
  'SchemaDocumentGenerator',
19
- # AUM业务表
20
- 'AUMBehaviorSchema',
21
- 'AUMAssetAvgSchema',
22
- 'AUMAssetConfigSchema',
23
- 'AUMMonthlyStatSchema',
24
- 'get_aum_schemas',
25
- 'export_aum_docs'
18
+ # 新疆工行代发长尾客户表
19
+ 'XinjiangICBCDaifaLongtailBehaviorSchema',
20
+ 'XinjiangICBCDaifaLongtailAssetAvgSchema',
21
+ 'XinjiangICBCDaifaLongtailAssetConfigSchema',
22
+ 'XinjiangICBCDaifaLongtailMonthlyStatSchema',
23
+ 'get_xinjiang_icbc_daifa_longtail_schemas',
24
+ 'export_xinjiang_icbc_daifa_longtail_docs'
26
25
  ]
27
26
 
28
- __version__ = "0.3.0"
27
+ __version__ = "0.6.0"