staran 0.6.1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- staran/__init__.py +10 -0
- staran/tools/__init__.py +6 -6
- staran/tools/date.py +327 -222
- staran/tools/tests/__init__.py +119 -0
- staran/tools/tests/run_tests.py +241 -0
- staran/tools/tests/test_api_compatibility.py +319 -0
- staran/tools/tests/test_date.py +565 -0
- staran/tools/tests/test_logging.py +402 -0
- staran-1.0.1.dist-info/METADATA +37 -0
- staran-1.0.1.dist-info/RECORD +13 -0
- staran/banks/__init__.py +0 -30
- staran/banks/xinjiang_icbc/__init__.py +0 -90
- staran/engines/__init__.py +0 -65
- staran/engines/base.py +0 -255
- staran/engines/hive.py +0 -163
- staran/engines/spark.py +0 -252
- staran/engines/turing.py +0 -439
- staran/features/__init__.py +0 -59
- staran/features/engines.py +0 -284
- staran/features/generator.py +0 -603
- staran/features/manager.py +0 -155
- staran/features/schema.py +0 -193
- staran/models/__init__.py +0 -72
- staran/models/config.py +0 -271
- staran/models/daifa_models.py +0 -361
- staran/models/registry.py +0 -281
- staran/models/target.py +0 -321
- staran/schemas/__init__.py +0 -27
- staran/schemas/aum/__init__.py +0 -210
- staran/tools/document_generator.py +0 -350
- staran-0.6.1.dist-info/METADATA +0 -586
- staran-0.6.1.dist-info/RECORD +0 -28
- {staran-0.6.1.dist-info → staran-1.0.1.dist-info}/WHEEL +0 -0
- {staran-0.6.1.dist-info → staran-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {staran-0.6.1.dist-info → staran-1.0.1.dist-info}/top_level.txt +0 -0
staran/engines/base.py
DELETED
@@ -1,255 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Base class for database engines.
-Defines a unified interface for SQL generation, execution, and data download.
-"""
-
-from abc import ABC, abstractmethod
-from typing import Dict, Any, Optional, List, Callable
-from enum import Enum
-from datetime import datetime
-
-
-class DatabaseType(Enum):
-    """Enumeration of supported database types."""
-    SPARK = "spark"
-    HIVE = "hive"
-    MYSQL = "mysql"
-    POSTGRESQL = "postgresql"
-
-
-class BaseEngine(ABC):
-    """
-    Base class for database engines.
-    Combines SQL generation, execution, and data download functionality.
-    """
-
-    def __init__(self, database_name: str, sql_executor: Optional[Callable] = None):
-        """
-        Initialize the engine.
-
-        Args:
-            database_name: Name of the database.
-            sql_executor: SQL executor function (optional).
-        """
-        self.database_name = database_name
-        self.sql_executor = sql_executor
-        self.execution_history = []
-
-    @abstractmethod
-    def get_engine_type(self) -> DatabaseType:
-        """Return the engine type."""
-        pass
-
-    @abstractmethod
-    def get_engine_name(self) -> str:
-        """Return the engine name."""
-        pass
-
-    # ==================== SQL generation methods ====================
-
-    @abstractmethod
-    def generate_create_table_sql(self, table_name: str, select_sql: str,
-                                  if_not_exists: bool = True) -> str:
-        """Generate CREATE TABLE SQL."""
-        pass
-
-    @abstractmethod
-    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
-        """Generate INSERT SQL."""
-        pass
-
-    @abstractmethod
-    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
-        """Generate DROP TABLE SQL."""
-        pass
-
-    def generate_aggregation_sql(self, schema, year: int, month: int,
-                                 aggregation_types: List[str]) -> str:
-        """Generate aggregate-feature SQL (may be overridden by subclasses)."""
-        base_table = self.get_full_table_name(schema.table_name)
-        pk_field = schema.primary_key
-        date_field = schema.date_field
-
-        # Collect the aggregatable fields
-        agg_fields = [field for field in schema.fields.values() if field.aggregatable]
-
-        # Build the aggregation select list
-        select_parts = [pk_field, f"'{year}-{month:02d}-01' as feature_month"]
-
-        for field in agg_fields:
-            for agg_type in aggregation_types:
-                alias = f"{field.name}_{agg_type}"
-                select_parts.append(f"{agg_type.upper()}({field.name}) as {alias}")
-
-        sql = f"""
-        SELECT {', '.join(select_parts)}
-        FROM {base_table}
-        WHERE YEAR({date_field}) = {year}
-          AND MONTH({date_field}) = {month}
-        GROUP BY {pk_field}
-        """
-
-        return sql.strip()
-
-    # ==================== SQL execution methods ====================
-
-    def execute_sql(self, sql: str, description: str = "") -> Any:
-        """
-        Execute a SQL statement.
-
-        Args:
-            sql: The SQL statement.
-            description: Description of the execution.
-
-        Returns:
-            The execution result.
-        """
-        if self.sql_executor:
-            result = self.sql_executor(sql)
-            self.execution_history.append({
-                'sql': sql,
-                'description': description,
-                'timestamp': datetime.now(),
-                'result': result
-            })
-            return result
-        else:
-            print(f"SQL (not executed): {description or 'SQL statement'}")
-            print(f"  {sql[:100]}...")
-            return None
-
-    def create_table(self, table_name: str, select_sql: str,
-                     execute: bool = False) -> Dict[str, Any]:
-        """
-        Create a table.
-
-        Args:
-            table_name: Table name.
-            select_sql: SELECT statement that provides the data.
-            execute: Whether to execute immediately.
-
-        Returns:
-            The operation result.
-        """
-        full_table_name = self.get_full_table_name(table_name)
-        create_sql = self.generate_create_table_sql(full_table_name, select_sql)
-
-        result = {
-            'table_name': table_name,
-            'full_table_name': full_table_name,
-            'sql': create_sql,
-            'executed': execute
-        }
-
-        if execute:
-            exec_result = self.execute_sql(create_sql, f"Create table {table_name}")
-            result['execution_result'] = exec_result
-            result['status'] = 'success' if exec_result is not None else 'simulated'
-        else:
-            result['status'] = 'prepared'
-
-        return result
-
-    # ==================== Data download methods ====================
-
-    @abstractmethod
-    def download_table_data(self, table_name: str, output_path: str,
-                            **kwargs) -> Dict[str, Any]:
-        """
-        Download table data (must be implemented by subclasses).
-
-        Args:
-            table_name: Table name.
-            output_path: Output path.
-            **kwargs: Additional parameters.
-
-        Returns:
-            The download result.
-        """
-        pass
-
-    def download_query_result(self, sql: str, output_path: str,
-                              **kwargs) -> Dict[str, Any]:
-        """
-        Download a query result (default implementation; subclasses may override).
-
-        Args:
-            sql: The query SQL.
-            output_path: Output path.
-            **kwargs: Additional parameters.
-
-        Returns:
-            The download result.
-        """
-        # Create a temporary table, then download from it
-        temp_table = f"temp_query_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-
-        try:
-            # Create the temporary table
-            self.create_table(temp_table, sql, execute=True)
-
-            # Download the data
-            result = self.download_table_data(temp_table, output_path, **kwargs)
-
-            # Clean up the temporary table
-            self.drop_table(temp_table, execute=True)
-
-            return result
-
-        except Exception as e:
-            return {
-                'status': 'error',
-                'message': f"Failed to download query result: {str(e)}",
-                'error': str(e)
-            }
-
-    # ==================== Utility methods ====================
-
-    def get_full_table_name(self, table_name: str) -> str:
-        """Return the fully qualified table name (including the database name)."""
-        if '.' in table_name:
-            return table_name  # Already includes the database name
-        return f"{self.database_name}.{table_name}"
-
-    def generate_table_name(self, base_name: str, year: int, month: int,
-                            suffix: str = "raw") -> str:
-        """
-        Generate a standardized table name.
-        Format: {base_name}_{yyyy}_{MM}_{suffix}
-        """
-        return f"{base_name}_{year}_{month:02d}_{suffix}"
-
-    def drop_table(self, table_name: str, execute: bool = False) -> Dict[str, Any]:
-        """Drop a table."""
-        full_table_name = self.get_full_table_name(table_name)
-        drop_sql = self.generate_drop_table_sql(full_table_name)
-
-        result = {
-            'table_name': table_name,
-            'full_table_name': full_table_name,
-            'sql': drop_sql,
-            'executed': execute
-        }
-
-        if execute:
-            exec_result = self.execute_sql(drop_sql, f"Drop table {table_name}")
-            result['execution_result'] = exec_result
-            result['status'] = 'success' if exec_result is not None else 'simulated'
-        else:
-            result['status'] = 'prepared'
-
-        return result
-
-    def get_execution_history(self) -> List[Dict]:
-        """Return a copy of the SQL execution history."""
-        return self.execution_history.copy()
-
-    def clear_history(self):
-        """Clear the execution history."""
-        self.execution_history.clear()
-
-    def __str__(self):
-        return f"{self.__class__.__name__}(db={self.database_name}, type={self.get_engine_type().value})"
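
Since the abstract methods above define the whole contract, here is a minimal sketch of how a concrete engine could have satisfied it. This assumes a staran 0.6.1 install (the module is gone in 1.0.1); DemoEngine and everything inside it are illustrative names, not part of the package.

# Minimal sketch against the 0.6.1 BaseEngine interface shown above.
# Assumes staran 0.6.1 is installed; DemoEngine is a hypothetical name.
from staran.engines.base import BaseEngine, DatabaseType

class DemoEngine(BaseEngine):
    def get_engine_type(self) -> DatabaseType:
        return DatabaseType.MYSQL

    def get_engine_name(self) -> str:
        return "Demo MySQL"

    def generate_create_table_sql(self, table_name: str, select_sql: str,
                                  if_not_exists: bool = True) -> str:
        clause = "IF NOT EXISTS " if if_not_exists else ""
        return f"CREATE TABLE {clause}{table_name} AS ({select_sql})"

    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
        return f"INSERT INTO {table_name} {select_sql}"

    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
        clause = "IF EXISTS " if if_exists else ""
        return f"DROP TABLE {clause}{table_name}"

    def download_table_data(self, table_name: str, output_path: str, **kwargs):
        return {'status': 'simulated', 'table_name': table_name,
                'output_path': output_path}

engine = DemoEngine("analytics")  # no sql_executor, so nothing actually runs
prepared = engine.create_table("t_demo", "SELECT 1 AS x")
print(prepared['status'])  # 'prepared'
print(prepared['sql'])     # CREATE TABLE IF NOT EXISTS analytics.t_demo AS (SELECT 1 AS x)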
staran/engines/hive.py
DELETED
@@ -1,163 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Hive database engine.
-Implements Hive SQL generation, execution, and data download.
-"""
-
-from typing import Dict, Any, Optional, List, Callable
-from .base import BaseEngine, DatabaseType
-
-
-class HiveEngine(BaseEngine):
-    """Hive database engine."""
-
-    def __init__(self, database_name: str, sql_executor: Optional[Callable] = None):
-        super().__init__(database_name, sql_executor)
-
-    def get_engine_type(self) -> DatabaseType:
-        return DatabaseType.HIVE
-
-    def get_engine_name(self) -> str:
-        return "Apache Hive"
-
-    # ==================== SQL generation methods ====================
-
-    def generate_create_table_sql(self, table_name: str, select_sql: str,
-                                  if_not_exists: bool = True) -> str:
-        """Generate Hive CREATE TABLE SQL."""
-        if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
-
-        return f"""
-        CREATE TABLE {if_not_exists_clause}{table_name}
-        STORED AS PARQUET
-        AS (
-            {select_sql}
-        )
-        """.strip()
-
-    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
-        """Generate Hive INSERT SQL."""
-        return f"""
-        INSERT INTO TABLE {table_name} (
-            {select_sql}
-        )
-        """.strip()
-
-    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
-        """Generate Hive DROP TABLE SQL."""
-        if_exists_clause = "IF EXISTS " if if_exists else ""
-        return f"DROP TABLE {if_exists_clause}{table_name}"
-
-    def generate_aggregation_sql(self, schema, year: int, month: int,
-                                 aggregation_types: List[str]) -> str:
-        """Generate Hive aggregate-feature SQL."""
-        base_table = self.get_full_table_name(schema.table_name)
-        pk_field = schema.primary_key
-        date_field = schema.date_field
-
-        # Collect the aggregatable fields
-        agg_fields = [field for field in schema.fields.values() if field.aggregatable]
-
-        # Build the aggregation select list
-        select_parts = [
-            pk_field,
-            f"'{year}-{month:02d}-01' as feature_month",
-            f"COUNT(*) as record_count"
-        ]
-
-        for field in agg_fields:
-            for agg_type in aggregation_types:
-                alias = f"{field.name}_{agg_type}"
-                if agg_type.lower() == 'sum':
-                    select_parts.append(f"SUM(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'avg':
-                    select_parts.append(f"AVG(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'count':
-                    select_parts.append(f"COUNT({field.name}) as {alias}")
-                elif agg_type.lower() == 'max':
-                    select_parts.append(f"MAX(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'min':
-                    select_parts.append(f"MIN(CAST({field.name} AS DOUBLE)) as {alias}")
-                else:
-                    select_parts.append(f"{agg_type.upper()}({field.name}) as {alias}")
-
-        sql = f"""
-        SELECT {', '.join(select_parts)}
-        FROM {base_table}
-        WHERE year({date_field}) = {year}
-          AND month({date_field}) = {month}
-        GROUP BY {pk_field}
-        """.strip()
-
-        return sql
-
-    # ==================== Data download methods ====================
-
-    def download_table_data(self, table_name: str, output_path: str,
-                            format: str = "textfile", delimiter: str = "\t",
-                            **kwargs) -> Dict[str, Any]:
-        """
-        Download Hive table data.
-
-        Args:
-            table_name: Table name.
-            output_path: Output path.
-            format: Output format (textfile, parquet, etc.).
-            delimiter: Field delimiter (textfile only).
-            **kwargs: Additional parameters.
-
-        Returns:
-            The download result.
-        """
-        full_table_name = self.get_full_table_name(table_name)
-
-        # Build the Hive export SQL
-        if format.lower() == "textfile":
-            export_sql = f"""
-            INSERT OVERWRITE DIRECTORY '{output_path}'
-            ROW FORMAT DELIMITED
-            FIELDS TERMINATED BY '{delimiter}'
-            SELECT * FROM {full_table_name}
-            """
-        else:
-            # For other formats, use CREATE TABLE AS
-            temp_table = f"temp_export_{table_name.replace('.', '_')}"
-            export_sql = f"""
-            CREATE TABLE {temp_table}
-            STORED AS {format.upper()}
-            LOCATION '{output_path}'
-            AS SELECT * FROM {full_table_name}
-            """
-
-        try:
-            if self.sql_executor:
-                result = self.sql_executor(export_sql)
-                return {
-                    'status': 'success',
-                    'message': f'Data exported to: {output_path}',
-                    'table_name': table_name,
-                    'output_path': output_path,
-                    'format': format,
-                    'export_sql': export_sql,
-                    'execution_result': result
-                }
-            else:
-                return {
-                    'status': 'simulated',
-                    'message': f'Simulated export to: {output_path}',
-                    'table_name': table_name,
-                    'output_path': output_path,
-                    'format': format,
-                    'export_sql': export_sql
-                }
-
-        except Exception as e:
-            return {
-                'status': 'error',
-                'message': f"Export failed: {str(e)}",
-                'table_name': table_name,
-                'error': str(e),
-                'export_sql': export_sql
-            }
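
For context, a hedged usage sketch of the removed HiveEngine: with no sql_executor supplied, every call runs in simulation mode and returns the SQL it would have executed. The table and path names below are made up for illustration.

# Usage sketch for the removed HiveEngine (0.6.1); names are illustrative.
from staran.engines.hive import HiveEngine

hive = HiveEngine("warehouse")  # no executor: simulation mode

print(hive.generate_drop_table_sql("warehouse.tmp_features"))
# DROP TABLE IF EXISTS warehouse.tmp_features

result = hive.download_table_data("features_2024_01_raw", "/tmp/export",
                                  delimiter=",")
print(result['status'])      # 'simulated'
print(result['export_sql'])  # INSERT OVERWRITE DIRECTORY '/tmp/export' ...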
staran/engines/spark.py
DELETED
@@ -1,252 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Spark database engine.
-Implements Spark SQL generation, execution, and data download.
-"""
-
-from typing import Dict, Any, Optional, List, Callable
-from .base import BaseEngine, DatabaseType
-
-
-class SparkEngine(BaseEngine):
-    """Spark database engine."""
-
-    def __init__(self, database_name: str, sql_executor: Optional[Callable] = None):
-        super().__init__(database_name, sql_executor)
-
-    def get_engine_type(self) -> DatabaseType:
-        return DatabaseType.SPARK
-
-    def get_engine_name(self) -> str:
-        return "Apache Spark"
-
-    # ==================== SQL generation methods ====================
-
-    def generate_create_table_sql(self, table_name: str, select_sql: str,
-                                  if_not_exists: bool = True) -> str:
-        """Generate Spark CREATE TABLE SQL."""
-        if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
-
-        return f"""
-        CREATE TABLE {if_not_exists_clause}{table_name}
-        USING DELTA
-        AS (
-            {select_sql}
-        )
-        """.strip()
-
-    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
-        """Generate Spark INSERT SQL."""
-        return f"""
-        INSERT INTO {table_name} (
-            {select_sql}
-        )
-        """.strip()
-
-    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
-        """Generate Spark DROP TABLE SQL."""
-        if_exists_clause = "IF EXISTS " if if_exists else ""
-        return f"DROP TABLE {if_exists_clause}{table_name}"
-
-    def generate_aggregation_sql(self, schema, year: int, month: int,
-                                 aggregation_types: List[str]) -> str:
-        """Generate Spark aggregate-feature SQL."""
-        base_table = self.get_full_table_name(schema.table_name)
-        pk_field = schema.primary_key
-        date_field = schema.date_field
-
-        # Collect the aggregatable fields
-        agg_fields = [field for field in schema.fields.values() if field.aggregatable]
-
-        # Build the aggregation select list
-        select_parts = [
-            pk_field,
-            f"'{year}-{month:02d}-01' as feature_month",
-            f"COUNT(*) as record_count"
-        ]
-
-        for field in agg_fields:
-            for agg_type in aggregation_types:
-                alias = f"{field.name}_{agg_type}"
-                if agg_type.lower() == 'sum':
-                    select_parts.append(f"SUM(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'avg':
-                    select_parts.append(f"AVG(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'count':
-                    select_parts.append(f"COUNT({field.name}) as {alias}")
-                elif agg_type.lower() == 'max':
-                    select_parts.append(f"MAX(CAST({field.name} AS DOUBLE)) as {alias}")
-                elif agg_type.lower() == 'min':
-                    select_parts.append(f"MIN(CAST({field.name} AS DOUBLE)) as {alias}")
-                else:
-                    select_parts.append(f"{agg_type.upper()}({field.name}) as {alias}")
-
-        sql = f"""
-        SELECT {', '.join(select_parts)}
-        FROM {base_table}
-        WHERE year({date_field}) = {year}
-          AND month({date_field}) = {month}
-        GROUP BY {pk_field}
-        """.strip()
-
-        return sql
-
-    def generate_mom_sql(self, schema, year: int, month: int,
-                         periods: List[int] = [1]) -> str:
-        """Generate month-over-month feature SQL."""
-        base_table = self.get_full_table_name(schema.table_name)
-        pk_field = schema.primary_key
-        date_field = schema.date_field
-
-        # Collect the aggregatable fields
-        agg_fields = [f for f in schema.fields if f.aggregatable]
-
-        # Build the month-over-month query
-        select_parts = [
-            f"curr.{pk_field}",
-            f"curr.feature_month"
-        ]
-
-        for field in agg_fields:
-            for period in periods:
-                for agg_type in ['sum', 'avg']:
-                    curr_field = f"curr.{field.name}_{agg_type}"
-                    prev_field = f"prev{period}.{field.name}_{agg_type}"
-
-                    # Month-over-month growth rate
-                    alias = f"{field.name}_{agg_type}_mom_{period}m"
-                    select_parts.append(f"""
-                    CASE
-                        WHEN {prev_field} IS NULL OR {prev_field} = 0 THEN NULL
-                        ELSE ({curr_field} - {prev_field}) / {prev_field}
-                    END as {alias}
-                    """.strip())
-
-                    # Month-over-month difference
-                    diff_alias = f"{field.name}_{agg_type}_diff_{period}m"
-                    select_parts.append(f"({curr_field} - {prev_field}) as {diff_alias}")
-
-        # Build the FROM clause and JOINs
-        from_clause = f"""
-        FROM (
-            SELECT {pk_field}, feature_month, {', '.join([f'{f.name}_sum, {f.name}_avg' for f in agg_fields])}
-            FROM {base_table}_aggregation_{year}_{month:02d}_1
-        ) curr
-        """
-
-        for period in periods:
-            prev_year = year
-            prev_month = month - period
-            if prev_month <= 0:
-                prev_month += 12
-                prev_year -= 1
-
-            from_clause += f"""
-            LEFT JOIN (
-                SELECT {pk_field}, {', '.join([f'{f.name}_sum, {f.name}_avg' for f in agg_fields])}
-                FROM {base_table}_aggregation_{prev_year}_{prev_month:02d}_1
-            ) prev{period} ON curr.{pk_field} = prev{period}.{pk_field}
-            """
-
-        sql = f"SELECT {', '.join(select_parts)} {from_clause}"
-        return sql.strip()
-
-    # ==================== Data download methods ====================
-
-    def download_table_data(self, table_name: str, output_path: str,
-                            format: str = "parquet", mode: str = "overwrite",
-                            **kwargs) -> Dict[str, Any]:
-        """
-        Download Spark table data.
-
-        Args:
-            table_name: Table name.
-            output_path: Output path.
-            format: Output format (parquet, csv, json, etc.).
-            mode: Write mode (overwrite, append).
-            **kwargs: Additional parameters.
-
-        Returns:
-            The download result.
-        """
-        full_table_name = self.get_full_table_name(table_name)
-
-        # Build the Spark download SQL/code
-        spark_code = f"""
-        df = spark.sql("SELECT * FROM {full_table_name}")
-        df.write.mode("{mode}").format("{format}").save("{output_path}")
-        """
-
-        try:
-            if self.sql_executor:
-                # If an executor is configured, try to execute
-                result = self.sql_executor(spark_code)
-                return {
-                    'status': 'success',
-                    'message': f'Data downloaded to: {output_path}',
-                    'table_name': table_name,
-                    'output_path': output_path,
-                    'format': format,
-                    'spark_code': spark_code,
-                    'execution_result': result
-                }
-            else:
-                # Simulation mode
-                return {
-                    'status': 'simulated',
-                    'message': f'Simulated download to: {output_path}',
-                    'table_name': table_name,
-                    'output_path': output_path,
-                    'format': format,
-                    'spark_code': spark_code
-                }
-
-        except Exception as e:
-            return {
-                'status': 'error',
-                'message': f"Download failed: {str(e)}",
-                'table_name': table_name,
-                'error': str(e),
-                'spark_code': spark_code
-            }
-
-    def download_query_result(self, sql: str, output_path: str,
-                              format: str = "parquet", mode: str = "overwrite",
-                              **kwargs) -> Dict[str, Any]:
-        """Download a query result directly, without creating a temporary table."""
-        spark_code = f"""
-        df = spark.sql(\"\"\"
-        {sql}
-        \"\"\")
-        df.write.mode("{mode}").format("{format}").save("{output_path}")
-        """
-
-        try:
-            if self.sql_executor:
-                result = self.sql_executor(spark_code)
-                return {
-                    'status': 'success',
-                    'message': f'Query result downloaded to: {output_path}',
-                    'output_path': output_path,
-                    'format': format,
-                    'spark_code': spark_code,
-                    'execution_result': result
-                }
-            else:
-                return {
-                    'status': 'simulated',
-                    'message': f'Simulated download of query result to: {output_path}',
-                    'output_path': output_path,
-                    'format': format,
-                    'spark_code': spark_code
-                }
-
-        except Exception as e:
-            return {
-                'status': 'error',
-                'message': f"Failed to download query result: {str(e)}",
-                'error': str(e),
-                'spark_code': spark_code
-            }