staran 0.6.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,284 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- SQL引擎模块
6
- 支持不同数据库引擎的SQL生成
7
- """
8
-
9
- from abc import ABC, abstractmethod
10
- from typing import List, Dict, Optional
11
- from enum import Enum
12
- from .schema import TableSchema, Field, FieldType
13
-
14
-
15
class DatabaseType(Enum):
    """Database engine type enumeration.

    Identifies which SQL dialect a generator targets; the enum value is
    the lowercase engine key used in configuration.
    """
    SPARK = "spark"
    HIVE = "hive"
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
21
-
22
-
23
class BaseSQLGenerator(ABC):
    """Abstract base for engine-specific SQL generators.

    Concrete subclasses implement :meth:`generate` to emit dialect-specific
    SQL text and :meth:`get_engine_name` to report a display name.
    """

    def __init__(self, schema: TableSchema, config):
        # Table description and feature-generation configuration shared by
        # every concrete generator.
        self.schema = schema
        self.config = config

    @abstractmethod
    def generate(self) -> str:
        """Produce the SQL statement text for this engine."""
        ...

    @abstractmethod
    def get_engine_name(self) -> str:
        """Return a human-readable name for the target engine."""
        ...
39
-
40
-
41
class SparkSQLGenerator(BaseSQLGenerator):
    """Spark SQL generator.

    Emits either a basic per-person-per-month aggregation query (when the
    table is not unique per person per month) or a multi-CTE monthly
    feature-engineering query with aggregation, month-over-month (MoM) and
    year-over-year (YoY) delta features.
    """

    def get_engine_name(self) -> str:
        """Return the display name of this engine."""
        return "Spark SQL"

    def generate(self) -> str:
        """Generate Spark SQL appropriate to the table's row cardinality."""
        if not self.schema.is_monthly_unique:
            return self._generate_basic_aggregation_sql()
        return self._generate_monthly_feature_sql()

    def _generate_basic_aggregation_sql(self) -> str:
        """Build a basic aggregation query (data NOT unique per person/month)."""
        base_table = self.schema.table_name
        pk_field = self.schema.primary_key
        date_field = self.schema.date_field

        select_parts = [
            f"    {pk_field}",
            f"    year({date_field}) as year",
            f"    month({date_field}) as month",
        ]

        # Carry over non-aggregatable raw columns via first().
        if self.config.include_raw_copy:
            for field in self.schema.get_non_aggregatable_fields():
                select_parts.append(f"    first({field.name}) as {field.name}")

        # Aggregate statistics for applicable columns.
        if self.config.include_aggregation:
            for field in self.schema.get_aggregatable_fields():
                for agg_type in self.config.aggregation_types:
                    if self._is_agg_applicable(field, agg_type):
                        agg_expr = self._get_spark_agg_expression(field.name, agg_type)
                        select_parts.append(f"    {agg_expr} as {field.name}_{agg_type}")

        # BUG FIX: join outside the f-string — backslashes inside f-string
        # replacement fields are a SyntaxError before Python 3.12 (PEP 701).
        select_clause = ",\n".join(select_parts)

        # {{current_timestamp}} renders as a literal {current_timestamp}
        # placeholder for downstream substitution.
        sql = f"""-- Spark SQL: 基础聚合分析
-- 表: {self.schema.table_name}
-- 生成时间: {{current_timestamp}}

WITH base_data AS (
    SELECT
{select_clause}
    FROM {base_table}
    GROUP BY {pk_field}, year({date_field}), month({date_field})
)

SELECT * FROM base_data
ORDER BY {pk_field}, year, month;"""

        return sql

    def _generate_monthly_feature_sql(self) -> str:
        """Build the monthly feature query (data unique per person/month).

        Assembles a WITH-chain of CTEs (base data, optional aggregation,
        optional MoM/YoY lag features) followed by the final SELECT.
        """
        ctes = [self._build_base_data_cte()]

        if self.config.include_aggregation:
            ctes.append(self._build_aggregation_cte())
        # NOTE(review): the MoM/YoY CTEs read from agg_features, so they
        # presuppose include_aggregation — confirm config validation upstream.
        if self.config.include_mom:
            ctes.append(self._build_mom_cte())
        if self.config.include_yoy:
            ctes.append(self._build_yoy_cte())

        header = f"""-- Spark SQL: 月度特征工程
-- 表: {self.schema.table_name}
-- 每人每月唯一数据特征生成
-- 生成时间: {{current_timestamp}}
"""

        # BUG FIX: the old code comma-joined the final SELECT onto the CTE
        # list and never emitted the WITH keyword, producing invalid SQL.
        # Only the CTEs are comma-separated; the SELECT stands alone.
        with_clause = "WITH " + ",\n\n".join(ctes)
        return header + "\n" + with_clause + "\n\n" + self._build_final_select()

    def _build_base_data_cte(self) -> str:
        """Build the base_data CTE: raw columns plus derived date parts."""
        pk_field = self.schema.primary_key
        date_field = self.schema.date_field

        select_parts = [
            f"    {pk_field}",
            f"    {date_field}",
            f"    year({date_field}) as year",
            f"    month({date_field}) as month",
            f"    date_format({date_field}, 'yyyy-MM') as year_month",
        ]

        # Pass every remaining column through unchanged.
        for field in self.schema.fields.values():
            if not field.is_primary_key and not field.is_date_field:
                select_parts.append(f"    {field.name}")

        select_clause = ",\n".join(select_parts)
        return f"""base_data AS (
    SELECT
{select_clause}
    FROM {self.schema.table_name}
)"""

    def _build_aggregation_cte(self) -> str:
        """Build the agg_features CTE grouped by person and month."""
        pk_field = self.schema.primary_key

        select_parts = [
            f"    {pk_field}",
            "    year_month",
        ]

        # Carry over non-aggregatable raw columns via first().
        if self.config.include_raw_copy:
            for field in self.schema.get_non_aggregatable_fields():
                select_parts.append(f"    first({field.name}) as {field.name}")

        # Aggregate statistics for applicable columns.
        for field in self.schema.get_aggregatable_fields():
            for agg_type in self.config.aggregation_types:
                if self._is_agg_applicable(field, agg_type):
                    agg_expr = self._get_spark_agg_expression(field.name, agg_type)
                    select_parts.append(f"    {agg_expr} as {field.name}_{agg_type}")

        select_clause = ",\n".join(select_parts)
        return f"""agg_features AS (
    SELECT
{select_clause}
    FROM base_data
    GROUP BY {pk_field}, year_month
)"""

    def _build_lag_cte(self, cte_name: str, months_list, suffix: str) -> str:
        """Shared builder for the MoM/YoY lag-difference CTEs.

        The previous code duplicated this logic verbatim for MoM and YoY;
        both now delegate here with their own CTE name, lag offsets and
        column suffix.
        """
        pk_field = self.schema.primary_key

        select_parts = [f"    a.{pk_field}", "    a.year_month"]

        for field in self.schema.get_aggregatable_fields():
            for agg_type in self.config.aggregation_types:
                if self._is_agg_applicable(field, agg_type):
                    col = f"{field.name}_{agg_type}"
                    for months in months_list:
                        # NOTE(review): lag() offsets count rows, not
                        # calendar months — assumes one row per month with
                        # no gaps; confirm upstream guarantees.
                        expr = (
                            f"a.{col} - lag(a.{col}, {months}) "
                            f"OVER (PARTITION BY a.{pk_field} ORDER BY a.year_month)"
                        )
                        select_parts.append(f"    {expr} as {col}_{suffix}_{months}m")

        select_clause = ",\n".join(select_parts)
        return f"""{cte_name} AS (
    SELECT
{select_clause}
    FROM agg_features a
)"""

    def _build_mom_cte(self) -> str:
        """Build the month-over-month delta CTE."""
        return self._build_lag_cte("mom_features", self.config.mom_months, "mom")

    def _build_yoy_cte(self) -> str:
        """Build the year-over-year delta CTE."""
        return self._build_lag_cte("yoy_features", self.config.yoy_months, "yoy")

    def _build_final_select(self) -> str:
        """Build the final SELECT joining aggregation with MoM/YoY features."""
        pk_field = self.schema.primary_key

        joins = []
        if self.config.include_mom:
            joins.append(
                f"LEFT JOIN mom_features m ON a.{pk_field} = m.{pk_field} "
                f"AND a.year_month = m.year_month"
            )
        if self.config.include_yoy:
            joins.append(
                f"LEFT JOIN yoy_features y ON a.{pk_field} = y.{pk_field} "
                f"AND a.year_month = y.year_month"
            )

        select_fields = ["a.*"]
        if self.config.include_mom:
            for field in self.schema.get_aggregatable_fields():
                for agg_type in self.config.aggregation_types:
                    if self._is_agg_applicable(field, agg_type):
                        for months in self.config.mom_months:
                            select_fields.append(f"m.{field.name}_{agg_type}_mom_{months}m")

        if self.config.include_yoy:
            for field in self.schema.get_aggregatable_fields():
                for agg_type in self.config.aggregation_types:
                    if self._is_agg_applicable(field, agg_type):
                        for months in self.config.yoy_months:
                            select_fields.append(f"y.{field.name}_{agg_type}_yoy_{months}m")

        # Join outside the f-string (pre-3.12 f-string backslash restriction).
        field_clause = ",\n    ".join(select_fields)
        join_clause = "\n".join(joins)

        return f"""SELECT
    {field_clause}
FROM agg_features a
{join_clause}
ORDER BY a.{pk_field}, a.year_month"""

    def _get_spark_agg_expression(self, field_name: str, agg_type: str) -> str:
        """Map an aggregation key to its Spark SQL expression.

        Unknown keys fall through to ``agg_type(field)`` so that engine
        functions not listed here still work.
        """
        agg_map = {
            'sum': f'sum({field_name})',
            'avg': f'avg({field_name})',
            'min': f'min({field_name})',
            'max': f'max({field_name})',
            'count': f'count({field_name})',
            'variance': f'variance({field_name})',
            'stddev': f'stddev({field_name})',
        }
        return agg_map.get(agg_type, f'{agg_type}({field_name})')

    def _is_agg_applicable(self, field: Field, agg_type: str) -> bool:
        """Return True if the aggregation type applies to the field's type."""
        # COUNT is valid for every field type.
        if agg_type == 'count':
            return True

        # Numeric aggregations require a numeric field type.
        numeric_aggs = ['sum', 'avg', 'min', 'max', 'variance', 'stddev']
        if agg_type in numeric_aggs:
            return field.field_type in [
                FieldType.INTEGER, FieldType.BIGINT, FieldType.DECIMAL,
                FieldType.DOUBLE, FieldType.FLOAT,
            ]

        # Anything not explicitly numeric-only is allowed by default.
        return True