staran 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
staran/__init__.py CHANGED
@@ -158,9 +158,13 @@ from .features import (
158
158
  )
159
159
  from .engines import SparkEngine, HiveEngine, TuringEngine, BaseEngine, create_engine, create_turing_engine
160
160
 
161
- # 示例模块
161
+ # Examples模块 - 业务示例
162
162
  from .examples import create_aum_example, run_aum_example
163
163
 
164
+ # Schemas模块 - 表结构定义与文档生成
165
+ from .schemas import SchemaDocumentGenerator
166
+ from .schemas.aum import get_aum_schemas, export_aum_docs
167
+
164
168
  # 图灵平台引擎 (可选导入,避免依赖问题)
165
169
  try:
166
170
  from .features import quick_create_and_download
@@ -194,6 +198,10 @@ __all__ = [
194
198
  # 示例模块
195
199
  'create_aum_example',
196
200
  'run_aum_example',
201
+ # 表结构模块
202
+ 'SchemaDocumentGenerator',
203
+ 'get_aum_schemas',
204
+ 'export_aum_docs',
197
205
  # 向后兼容
198
206
  'SQLManager',
199
207
  'SparkSQLGenerator'
@@ -206,9 +214,9 @@ if _TURING_AVAILABLE:
206
214
  ])
207
215
 
208
216
  # 包信息
209
- __version__ = '0.2.4'
217
+ __version__ = '0.3.0'
210
218
  __author__ = 'Staran Team'
211
- __description__ = 'Smart feature engineering toolkit with modular engine architecture and examples'
219
+ __description__ = 'Smart feature engineering toolkit with schema management, document generation and business examples'
212
220
  __license__ = 'MIT'
213
221
 
214
222
  # 便捷函数示例
@@ -0,0 +1,65 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 数据库引擎模块
6
+ 提供统一的数据库引擎接口
7
+ """
8
+
9
+ # 基础组件
10
+ from .base import BaseEngine, DatabaseType
11
+
12
+ # 具体引擎实现
13
+ from .spark import SparkEngine
14
+ from .hive import HiveEngine
15
+
16
+ # 图灵平台引擎 (可选导入)
17
+ try:
18
+ from .turing import TuringEngine, create_turing_engine
19
+ _TURING_AVAILABLE = True
20
+ except ImportError:
21
+ TuringEngine = None
22
+ create_turing_engine = None
23
+ _TURING_AVAILABLE = False
24
+
25
+ # 便捷创建函数
26
def create_engine(engine_type: "str | DatabaseType", database_name: str, **kwargs) -> BaseEngine:
    """
    Build a database engine from its type name.

    Args:
        engine_type: Engine type, case-insensitive and whitespace-tolerant:
            'spark', 'hive' or 'turing'. An enum member such as
            ``DatabaseType.HIVE`` is also accepted and resolved via its ``value``.
        database_name: Name of the database the engine operates on.
        **kwargs: Extra keyword arguments forwarded to the engine constructor.

    Returns:
        A new engine instance.

    Raises:
        ImportError: 'turing' was requested but the optional Turing engine
            failed to import.
        ValueError: ``engine_type`` does not name a supported engine.
    """
    # Accept enum members (e.g. the value of engine.get_engine_type()) as well
    # as plain strings; the original called .lower() directly, which raised
    # AttributeError for enum input.
    normalized = str(getattr(engine_type, 'value', engine_type)).strip().lower()

    if normalized == 'spark':
        return SparkEngine(database_name, **kwargs)
    if normalized == 'hive':
        return HiveEngine(database_name, **kwargs)
    if normalized == 'turing':
        if not _TURING_AVAILABLE:
            raise ImportError("TuringEngine不可用,请确保turingPythonLib已安装")
        return TuringEngine(database_name, **kwargs)
    raise ValueError(f"不支持的引擎类型: {normalized}")
50
+
51
# Public API of the engines package. The Turing helpers are advertised only
# when their optional import succeeded (_TURING_AVAILABLE is True).
__all__ = ['BaseEngine', 'DatabaseType', 'SparkEngine', 'HiveEngine', 'create_engine']
if _TURING_AVAILABLE:
    __all__ += ['TuringEngine', 'create_turing_engine']
staran/engines/base.py ADDED
@@ -0,0 +1,255 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 数据库引擎基类
6
+ 定义统一的SQL生成、执行和数据下载接口
7
+ """
8
+
9
+ from abc import ABC, abstractmethod
10
+ from typing import Dict, Any, Optional, List, Callable
11
+ from enum import Enum
12
+ from datetime import datetime
13
+
14
+
15
class DatabaseType(Enum):
    """
    Closed set of database back ends known to the engine layer.

    Each member's value is the lowercase back-end name, so
    ``DatabaseType("hive")`` looks a member up from its string form.
    """

    # Back ends with engine classes in this package (SparkEngine, HiveEngine).
    SPARK = "spark"
    HIVE = "hive"

    # Declared identifiers; no engine class for these is visible here.
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
21
+
22
+
23
class BaseEngine(ABC):
    """
    Abstract base class for database engines.

    Bundles SQL generation, SQL execution (through an optional caller-supplied
    executor) and data download behind one interface. Subclasses provide the
    dialect-specific SQL generators and the table download implementation.
    """

    def __init__(self, database_name: str, sql_executor: Optional[Callable] = None):
        """
        Initialise the engine.

        Args:
            database_name: Database used to qualify bare table names.
            sql_executor: Optional callable taking one SQL string. When omitted,
                ``execute_sql`` only prints the statement (dry run).
        """
        self.database_name = database_name
        self.sql_executor = sql_executor
        # One entry per executed statement: sql, description, timestamp, result.
        self.execution_history: List[Dict[str, Any]] = []

    @abstractmethod
    def get_engine_type(self) -> "DatabaseType":
        """Return the DatabaseType member this engine targets."""
        pass

    @abstractmethod
    def get_engine_name(self) -> str:
        """Return a human-readable engine name."""
        pass

    # ==================== SQL generation ====================

    @abstractmethod
    def generate_create_table_sql(self, table_name: str, select_sql: str,
                                  if_not_exists: bool = True) -> str:
        """Return SQL that creates ``table_name`` from ``select_sql``."""
        pass

    @abstractmethod
    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
        """Return SQL that inserts the rows of ``select_sql`` into ``table_name``."""
        pass

    @abstractmethod
    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
        """Return SQL that drops ``table_name``."""
        pass

    def generate_aggregation_sql(self, schema, year: int, month: int,
                                 aggregation_types: List[str]) -> str:
        """
        Build a per-key aggregation query for one month (subclasses may override).

        Args:
            schema: Schema object; this method reads ``table_name``,
                ``primary_key``, ``date_field`` and ``fields`` (a mapping whose
                values expose ``name`` and ``aggregatable``).
            year: Year to filter on.
            month: Month number to filter on; also written zero-padded into the
                constant ``feature_month`` column.
            aggregation_types: SQL aggregate names (e.g. ['sum', 'avg']); each
                aggregatable field yields one ``{field}_{agg_type}`` column per type.

        Returns:
            The SELECT statement with surrounding whitespace stripped.
        """
        base_table = self.get_full_table_name(schema.table_name)
        pk_field = schema.primary_key
        date_field = schema.date_field

        agg_fields = [field for field in schema.fields.values() if field.aggregatable]

        select_parts = [pk_field, f"'{year}-{month:02d}-01' as feature_month"]
        select_parts.extend(
            f"{agg_type.upper()}({field.name}) as {field.name}_{agg_type}"
            for field in agg_fields
            for agg_type in aggregation_types
        )

        sql = "\n".join([
            f"SELECT {', '.join(select_parts)}",
            f"FROM {base_table}",
            f"WHERE YEAR({date_field}) = {year}",
            f"  AND MONTH({date_field}) = {month}",
            f"GROUP BY {pk_field}",
        ])
        return sql.strip()

    # ==================== SQL execution ====================

    def execute_sql(self, sql: str, description: str = "") -> Any:
        """
        Execute one SQL statement through ``sql_executor``.

        Args:
            sql: Statement to run.
            description: Free-text label stored in the execution history.

        Returns:
            Whatever the executor returns, or None when no executor is set
            (the statement is then only printed, truncated to 100 characters).
        """
        if self.sql_executor:
            result = self.sql_executor(sql)
            self.execution_history.append({
                'sql': sql,
                'description': description,
                'timestamp': datetime.now(),
                'result': result,
            })
            return result

        print(f"SQL (未执行): {description or 'SQL语句'}")
        print(f" {sql[:100]}...")
        return None

    def _run_table_statement(self, table_name: str, full_table_name: str, sql: str,
                             execute: bool, description: str) -> Dict[str, Any]:
        """Wrap a table DDL statement in a result dict, executing it if requested."""
        result = {
            'table_name': table_name,
            'full_table_name': full_table_name,
            'sql': sql,
            'executed': execute,
        }
        if not execute:
            result['status'] = 'prepared'
            return result

        result['execution_result'] = self.execute_sql(sql, description)
        # An executor may legitimately return None after running a statement,
        # so success depends on whether an executor ran, not on its return value.
        result['status'] = 'success' if self.sql_executor else 'simulated'
        return result

    def create_table(self, table_name: str, select_sql: str,
                     execute: bool = False) -> Dict[str, Any]:
        """
        Create a table from a SELECT statement.

        Args:
            table_name: Table name, bare or database-qualified.
            select_sql: Query whose result populates the table.
            execute: Run the statement now; otherwise only prepare it.

        Returns:
            Dict with table_name, full_table_name, sql, executed and status
            ('prepared', 'success' or 'simulated'), plus execution_result when
            executed.
        """
        full_table_name = self.get_full_table_name(table_name)
        create_sql = self.generate_create_table_sql(full_table_name, select_sql)
        return self._run_table_statement(
            table_name, full_table_name, create_sql, execute, f"创建表 {table_name}"
        )

    # ==================== Data download ====================

    @abstractmethod
    def download_table_data(self, table_name: str, output_path: str,
                            **kwargs) -> Dict[str, Any]:
        """
        Download a table's data (must be implemented by subclasses).

        Args:
            table_name: Table name.
            output_path: Destination path.
            **kwargs: Engine-specific options.

        Returns:
            Download result dict.
        """
        pass

    def download_query_result(self, sql: str, output_path: str,
                              **kwargs) -> Dict[str, Any]:
        """
        Download a query's result by staging it in a temporary table.

        The temporary table is dropped afterwards even when the download fails.
        A failure while dropping is reported under ``cleanup_error`` instead of
        replacing the download outcome.

        Args:
            sql: Query to materialise.
            output_path: Destination path passed to ``download_table_data``.
            **kwargs: Forwarded to ``download_table_data``.

        Returns:
            The ``download_table_data`` result, or an error dict with status,
            message and error when staging or downloading raised.
        """
        import uuid  # local import: keeps this module's top-level imports unchanged

        # create_table() relies on the generator's if_not_exists default (True in
        # the declared signature). Without a random suffix, two calls in the same
        # second could silently share and reuse one temp table.
        temp_table = (
            f"temp_query_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        )

        try:
            self.create_table(temp_table, sql, execute=True)
            result = self.download_table_data(temp_table, output_path, **kwargs)
        except Exception as e:
            result = {
                'status': 'error',
                'message': f"下载查询结果失败: {str(e)}",
                'error': str(e),
            }

        # Always clean up, including after a failed download, so errors don't
        # leak temp tables.
        try:
            self.drop_table(temp_table, execute=True)
        except Exception as e:
            if isinstance(result, dict):
                result['cleanup_error'] = str(e)

        return result

    # ==================== Utilities ====================

    def get_full_table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified with the database unless it already has a '.'."""
        if '.' in table_name:
            return table_name
        return f"{self.database_name}.{table_name}"

    def generate_table_name(self, base_name: str, year: int, month: int,
                            suffix: str = "raw") -> str:
        """
        Build a standardised table name.

        Format: {base_name}_{yyyy}_{MM}_{suffix}
        """
        return f"{base_name}_{year}_{month:02d}_{suffix}"

    def drop_table(self, table_name: str, execute: bool = False) -> Dict[str, Any]:
        """Drop a table; same result-dict shape as ``create_table``."""
        full_table_name = self.get_full_table_name(table_name)
        drop_sql = self.generate_drop_table_sql(full_table_name)
        return self._run_table_statement(
            table_name, full_table_name, drop_sql, execute, f"删除表 {table_name}"
        )

    def get_execution_history(self) -> List[Dict]:
        """Return a shallow copy of the SQL execution history."""
        return self.execution_history.copy()

    def clear_history(self):
        """Clear the SQL execution history."""
        self.execution_history.clear()

    def __str__(self):
        return f"{self.__class__.__name__}(db={self.database_name}, type={self.get_engine_type().value})"
staran/engines/hive.py ADDED
@@ -0,0 +1,163 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Hive数据库引擎
6
+ 实现Hive SQL的生成、执行和数据下载
7
+ """
8
+
9
+ from typing import Dict, Any, Optional, List, Callable
10
+ from .base import BaseEngine, DatabaseType
11
+
12
+
13
class HiveEngine(BaseEngine):
    """Apache Hive engine: generates HiveQL for DDL, aggregation and data export."""

    def __init__(self, database_name: str, sql_executor: Optional[Callable] = None):
        super().__init__(database_name, sql_executor)

    def get_engine_type(self) -> DatabaseType:
        return DatabaseType.HIVE

    def get_engine_name(self) -> str:
        return "Apache Hive"

    # ==================== SQL generation ====================

    def generate_create_table_sql(self, table_name: str, select_sql: str,
                                  if_not_exists: bool = True) -> str:
        """
        Build a Hive CTAS statement that stores the result as Parquet.

        The query follows ``AS`` without parentheses, matching Hive's
        documented ``CREATE TABLE ... AS select_statement`` form.
        """
        if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
        return (
            f"CREATE TABLE {if_not_exists_clause}{table_name}\n"
            "STORED AS PARQUET\n"
            "AS\n"
            f"{select_sql.strip()}"
        )

    def generate_insert_sql(self, table_name: str, select_sql: str) -> str:
        """
        Build a Hive ``INSERT INTO TABLE`` statement.

        The query must not be parenthesised: in Hive, ``(...)`` directly after
        the table name is parsed as a target column list.
        """
        return f"INSERT INTO TABLE {table_name}\n{select_sql.strip()}"

    def generate_drop_table_sql(self, table_name: str, if_exists: bool = True) -> str:
        """Build a Hive DROP TABLE statement."""
        if_exists_clause = "IF EXISTS " if if_exists else ""
        return f"DROP TABLE {if_exists_clause}{table_name}"

    def generate_aggregation_sql(self, schema, year: int, month: int,
                                 aggregation_types: List[str]) -> str:
        """
        Build a Hive per-key aggregation query for one month.

        SUM/AVG/MAX/MIN cast the field to DOUBLE first, COUNT counts non-null
        values, and any other aggregate name is emitted verbatim (upper-cased).
        A ``record_count`` column (COUNT(*)) is always included.

        Args:
            schema: Schema object; reads ``table_name``, ``primary_key``,
                ``date_field`` and ``fields`` (values expose ``name`` and
                ``aggregatable``).
            year: Year to filter on.
            month: Month number to filter on; also used for ``feature_month``.
            aggregation_types: Aggregate names; each yields ``{field}_{agg_type}``.

        Returns:
            The SELECT statement.
        """
        base_table = self.get_full_table_name(schema.table_name)
        pk_field = schema.primary_key
        date_field = schema.date_field

        agg_fields = [field for field in schema.fields.values() if field.aggregatable]

        select_parts = [
            pk_field,
            f"'{year}-{month:02d}-01' as feature_month",
            "COUNT(*) as record_count",
        ]

        for field in agg_fields:
            for agg_type in aggregation_types:
                func = agg_type.lower()
                if func in ('sum', 'avg', 'max', 'min'):
                    expr = f"{func.upper()}(CAST({field.name} AS DOUBLE))"
                elif func == 'count':
                    expr = f"COUNT({field.name})"
                else:
                    expr = f"{agg_type.upper()}({field.name})"
                select_parts.append(f"{expr} as {field.name}_{agg_type}")

        return "\n".join([
            f"SELECT {', '.join(select_parts)}",
            f"FROM {base_table}",
            f"WHERE year({date_field}) = {year}",
            f"  AND month({date_field}) = {month}",
            f"GROUP BY {pk_field}",
        ])

    # ==================== Data download ====================

    def download_table_data(self, table_name: str, output_path: str,
                            format: str = "textfile", delimiter: str = "\t",
                            **kwargs) -> Dict[str, Any]:
        """
        Export a Hive table to a directory with INSERT OVERWRITE DIRECTORY.

        Existing contents of ``output_path`` are overwritten. Non-textfile
        formats use ``STORED AS <format>`` directly, so no helper table is
        created. Previously, a never-dropped ``temp_export_*`` table made a
        repeated export fail.

        Args:
            table_name: Table name, bare or database-qualified.
            output_path: Target directory.
            format: Output format ('textfile', 'parquet', 'orc', ...).
            delimiter: Field delimiter (textfile only).
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Dict with status ('success', 'simulated' or 'error'), a message,
            the export SQL, and either the execution result or the error.
        """
        full_table_name = self.get_full_table_name(table_name)

        if format.lower() == "textfile":
            storage_clause = f"ROW FORMAT DELIMITED\nFIELDS TERMINATED BY '{delimiter}'"
        else:
            storage_clause = f"STORED AS {format.upper()}"

        export_sql = (
            f"INSERT OVERWRITE DIRECTORY '{output_path}'\n"
            f"{storage_clause}\n"
            f"SELECT * FROM {full_table_name}"
        )

        if not self.sql_executor:
            return {
                'status': 'simulated',
                'message': f'模拟导出到: {output_path}',
                'table_name': table_name,
                'output_path': output_path,
                'format': format,
                'export_sql': export_sql,
            }

        try:
            # Go through execute_sql so exports appear in execution_history.
            result = self.execute_sql(export_sql, f"导出表 {table_name}")
        except Exception as e:
            return {
                'status': 'error',
                'message': f"导出失败: {str(e)}",
                'table_name': table_name,
                'error': str(e),
                'export_sql': export_sql,
            }

        return {
            'status': 'success',
            'message': f'数据已导出到: {output_path}',
            'table_name': table_name,
            'output_path': output_path,
            'format': format,
            'export_sql': export_sql,
            'execution_result': result,
        }