db_client_toolkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- db_client_toolkit-0.0.1.dist-info/METADATA +540 -0
- db_client_toolkit-0.0.1.dist-info/RECORD +22 -0
- db_client_toolkit-0.0.1.dist-info/WHEEL +4 -0
- db_toolkit/__init__.py +151 -0
- db_toolkit/clients/__init__.py +20 -0
- db_toolkit/clients/mongodb.py +251 -0
- db_toolkit/clients/mysql.py +143 -0
- db_toolkit/clients/postgresql.py +152 -0
- db_toolkit/clients/redis.py +321 -0
- db_toolkit/clients/sqlite.py +152 -0
- db_toolkit/clients/supabase.py +230 -0
- db_toolkit/core/__init__.py +9 -0
- db_toolkit/core/base.py +194 -0
- db_toolkit/core/sql_base.py +163 -0
- db_toolkit/exceptions/__init__.py +38 -0
- db_toolkit/mixins/__init__.py +12 -0
- db_toolkit/mixins/batch_ops.py +194 -0
- db_toolkit/mixins/transaction.py +206 -0
- db_toolkit/utils/__init__.py +15 -0
- db_toolkit/utils/config.py +252 -0
- db_toolkit/utils/factory.py +172 -0
- db_toolkit/utils/query_builder.py +316 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
事务混入类
|
|
3
|
+
为支持事务的数据库客户端提供事务管理功能
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
|
|
9
|
+
from ..exceptions import TransactionError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TransactionMixin:
|
|
16
|
+
"""
|
|
17
|
+
事务混入类
|
|
18
|
+
|
|
19
|
+
为数据库客户端添加事务管理功能
|
|
20
|
+
注意:仅适用于支持事务的数据库(如MySQL, PostgreSQL, SQLite)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def begin(self) -> None:
|
|
24
|
+
"""
|
|
25
|
+
开始事务
|
|
26
|
+
|
|
27
|
+
Raises:
|
|
28
|
+
TransactionError: 开始事务失败时
|
|
29
|
+
"""
|
|
30
|
+
if not hasattr(self, 'connection') or not self.connection:
|
|
31
|
+
raise TransactionError("数据库未连接")
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
# 不同数据库的事务开始方式
|
|
35
|
+
if hasattr(self.connection, 'begin'):
|
|
36
|
+
self.connection.begin()
|
|
37
|
+
elif hasattr(self.connection, 'autocommit'):
|
|
38
|
+
self.connection.autocommit = False
|
|
39
|
+
else:
|
|
40
|
+
# 对于不显式支持begin的数据库,通常自动开始事务
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
logger.debug("事务已开始")
|
|
44
|
+
|
|
45
|
+
except Exception as e:
|
|
46
|
+
logger.error(f"开始事务失败: {e}")
|
|
47
|
+
raise TransactionError(f"开始事务失败: {e}")
|
|
48
|
+
|
|
49
|
+
def commit(self) -> None:
|
|
50
|
+
"""
|
|
51
|
+
提交事务
|
|
52
|
+
|
|
53
|
+
Raises:
|
|
54
|
+
TransactionError: 提交事务失败时
|
|
55
|
+
"""
|
|
56
|
+
if not hasattr(self, 'connection') or not self.connection:
|
|
57
|
+
raise TransactionError("数据库未连接")
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
if hasattr(self.connection, 'commit'):
|
|
61
|
+
self.connection.commit()
|
|
62
|
+
logger.debug("事务已提交")
|
|
63
|
+
else:
|
|
64
|
+
logger.warning("数据库不支持事务提交")
|
|
65
|
+
|
|
66
|
+
except Exception as e:
|
|
67
|
+
logger.error(f"提交事务失败: {e}")
|
|
68
|
+
raise TransactionError(f"提交事务失败: {e}")
|
|
69
|
+
|
|
70
|
+
def rollback(self) -> None:
|
|
71
|
+
"""
|
|
72
|
+
回滚事务
|
|
73
|
+
|
|
74
|
+
Raises:
|
|
75
|
+
TransactionError: 回滚事务失败时
|
|
76
|
+
"""
|
|
77
|
+
if not hasattr(self, 'connection') or not self.connection:
|
|
78
|
+
raise TransactionError("数据库未连接")
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
if hasattr(self.connection, 'rollback'):
|
|
82
|
+
self.connection.rollback()
|
|
83
|
+
logger.debug("事务已回滚")
|
|
84
|
+
else:
|
|
85
|
+
logger.warning("数据库不支持事务回滚")
|
|
86
|
+
|
|
87
|
+
except Exception as e:
|
|
88
|
+
logger.error(f"回滚事务失败: {e}")
|
|
89
|
+
raise TransactionError(f"回滚事务失败: {e}")
|
|
90
|
+
|
|
91
|
+
@contextmanager
|
|
92
|
+
def transaction(self):
|
|
93
|
+
"""
|
|
94
|
+
事务上下文管理器
|
|
95
|
+
|
|
96
|
+
Yields:
|
|
97
|
+
self: 数据库客户端实例
|
|
98
|
+
|
|
99
|
+
Examples:
|
|
100
|
+
>>> with client.transaction():
|
|
101
|
+
... client.insert('users', {'name': 'Alice'})
|
|
102
|
+
... client.insert('users', {'name': 'Bob'})
|
|
103
|
+
# 如果没有异常,自动提交;如果有异常,自动回滚
|
|
104
|
+
"""
|
|
105
|
+
try:
|
|
106
|
+
self.begin()
|
|
107
|
+
yield self
|
|
108
|
+
self.commit()
|
|
109
|
+
logger.info("事务成功提交")
|
|
110
|
+
|
|
111
|
+
except Exception as e:
|
|
112
|
+
self.rollback()
|
|
113
|
+
logger.error(f"事务回滚: {e}")
|
|
114
|
+
raise
|
|
115
|
+
|
|
116
|
+
def execute_in_transaction(self, *operations):
|
|
117
|
+
"""
|
|
118
|
+
在事务中执行多个操作
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
*operations: 操作函数列表,每个函数接收client作为参数
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
执行结果列表
|
|
125
|
+
|
|
126
|
+
Examples:
|
|
127
|
+
>>> def op1(client):
|
|
128
|
+
... return client.insert('users', {'name': 'Alice'})
|
|
129
|
+
>>> def op2(client):
|
|
130
|
+
... return client.insert('users', {'name': 'Bob'})
|
|
131
|
+
>>> results = client.execute_in_transaction(op1, op2)
|
|
132
|
+
"""
|
|
133
|
+
results = []
|
|
134
|
+
|
|
135
|
+
try:
|
|
136
|
+
self.begin()
|
|
137
|
+
|
|
138
|
+
for operation in operations:
|
|
139
|
+
if callable(operation):
|
|
140
|
+
result = operation(self)
|
|
141
|
+
results.append(result)
|
|
142
|
+
else:
|
|
143
|
+
logger.warning(f"跳过非可调用对象: {operation}")
|
|
144
|
+
|
|
145
|
+
self.commit()
|
|
146
|
+
logger.info(f"事务成功完成 {len(results)} 个操作")
|
|
147
|
+
return results
|
|
148
|
+
|
|
149
|
+
except Exception as e:
|
|
150
|
+
self.rollback()
|
|
151
|
+
logger.error(f"事务执行失败: {e}")
|
|
152
|
+
raise TransactionError(f"事务执行失败: {e}")
|
|
153
|
+
|
|
154
|
+
def savepoint(self, name: str) -> None:
|
|
155
|
+
"""
|
|
156
|
+
创建保存点
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
name: 保存点名称
|
|
160
|
+
|
|
161
|
+
Note:
|
|
162
|
+
仅PostgreSQL和部分数据库支持保存点
|
|
163
|
+
"""
|
|
164
|
+
try:
|
|
165
|
+
if hasattr(self, 'execute'):
|
|
166
|
+
self.execute(f"SAVEPOINT {name}")
|
|
167
|
+
logger.debug(f"创建保存点: {name}")
|
|
168
|
+
else:
|
|
169
|
+
logger.warning("数据库不支持保存点")
|
|
170
|
+
except Exception as e:
|
|
171
|
+
logger.error(f"创建保存点失败: {e}")
|
|
172
|
+
raise TransactionError(f"创建保存点失败: {e}")
|
|
173
|
+
|
|
174
|
+
def rollback_to_savepoint(self, name: str) -> None:
|
|
175
|
+
"""
|
|
176
|
+
回滚到保存点
|
|
177
|
+
|
|
178
|
+
Args:
|
|
179
|
+
name: 保存点名称
|
|
180
|
+
"""
|
|
181
|
+
try:
|
|
182
|
+
if hasattr(self, 'execute'):
|
|
183
|
+
self.execute(f"ROLLBACK TO SAVEPOINT {name}")
|
|
184
|
+
logger.debug(f"回滚到保存点: {name}")
|
|
185
|
+
else:
|
|
186
|
+
logger.warning("数据库不支持保存点")
|
|
187
|
+
except Exception as e:
|
|
188
|
+
logger.error(f"回滚到保存点失败: {e}")
|
|
189
|
+
raise TransactionError(f"回滚到保存点失败: {e}")
|
|
190
|
+
|
|
191
|
+
def release_savepoint(self, name: str) -> None:
|
|
192
|
+
"""
|
|
193
|
+
释放保存点
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
name: 保存点名称
|
|
197
|
+
"""
|
|
198
|
+
try:
|
|
199
|
+
if hasattr(self, 'execute'):
|
|
200
|
+
self.execute(f"RELEASE SAVEPOINT {name}")
|
|
201
|
+
logger.debug(f"释放保存点: {name}")
|
|
202
|
+
else:
|
|
203
|
+
logger.warning("数据库不支持保存点")
|
|
204
|
+
except Exception as e:
|
|
205
|
+
logger.error(f"释放保存点失败: {e}")
|
|
206
|
+
raise TransactionError(f"释放保存点失败: {e}")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
工具模块
|
|
3
|
+
包含工厂、配置管理器和查询构建器
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .factory import ClientFactory, create_client
|
|
7
|
+
from .config import ConfigManager
|
|
8
|
+
from .query_builder import QueryBuilder
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
'ClientFactory',
|
|
12
|
+
'create_client',
|
|
13
|
+
'ConfigManager',
|
|
14
|
+
'QueryBuilder',
|
|
15
|
+
]
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
配置管理器
|
|
3
|
+
用于管理数据库配置文件
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from typing import Dict, Any, Optional
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
from ..core.base import BaseClient
|
|
12
|
+
from .factory import ClientFactory
|
|
13
|
+
from ..exceptions import ConfigurationError
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfigManager:
|
|
19
|
+
"""
|
|
20
|
+
数据库配置管理器
|
|
21
|
+
|
|
22
|
+
管理JSON格式的数据库配置文件,支持多个数据库配置
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
def __init__(self, config_file: str = 'db_config.json'):
|
|
26
|
+
"""
|
|
27
|
+
初始化配置管理器
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
config_file: 配置文件路径
|
|
31
|
+
"""
|
|
32
|
+
self.config_file = config_file
|
|
33
|
+
self.configs: Dict[str, Any] = self._load_config()
|
|
34
|
+
|
|
35
|
+
def _load_config(self) -> Dict[str, Any]:
|
|
36
|
+
"""
|
|
37
|
+
加载配置文件
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Dict: 配置字典
|
|
41
|
+
"""
|
|
42
|
+
if not os.path.exists(self.config_file):
|
|
43
|
+
logger.warning(f"配置文件不存在: {self.config_file},使用空配置")
|
|
44
|
+
return {'databases': {}, 'default': None}
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
with open(self.config_file, 'r', encoding='utf-8') as f:
|
|
48
|
+
config = json.load(f)
|
|
49
|
+
logger.info(f"加载配置文件成功: {self.config_file}")
|
|
50
|
+
return config
|
|
51
|
+
except json.JSONDecodeError as e:
|
|
52
|
+
logger.error(f"配置文件JSON格式错误: {e}")
|
|
53
|
+
raise ConfigurationError(f"配置文件格式错误: {e}")
|
|
54
|
+
except Exception as e:
|
|
55
|
+
logger.error(f"加载配置文件失败: {e}")
|
|
56
|
+
raise ConfigurationError(f"加载配置文件失败: {e}")
|
|
57
|
+
|
|
58
|
+
def _save_config(self) -> None:
|
|
59
|
+
"""保存配置文件"""
|
|
60
|
+
try:
|
|
61
|
+
with open(self.config_file, 'w', encoding='utf-8') as f:
|
|
62
|
+
json.dump(self.configs, f, indent=2, ensure_ascii=False)
|
|
63
|
+
logger.info(f"保存配置文件成功: {self.config_file}")
|
|
64
|
+
except Exception as e:
|
|
65
|
+
logger.error(f"保存配置文件失败: {e}")
|
|
66
|
+
raise ConfigurationError(f"保存配置文件失败: {e}")
|
|
67
|
+
|
|
68
|
+
def get_client(self, name: Optional[str] = None, auto_connect: bool = True) -> BaseClient:
|
|
69
|
+
"""
|
|
70
|
+
获取数据库客户端
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
name: 数据库配置名称,None则使用默认配置
|
|
74
|
+
auto_connect: 是否自动连接
|
|
75
|
+
|
|
76
|
+
Returns:
|
|
77
|
+
BaseClient: 数据库客户端实例
|
|
78
|
+
|
|
79
|
+
Raises:
|
|
80
|
+
ConfigurationError: 配置不存在或无效时
|
|
81
|
+
|
|
82
|
+
Examples:
|
|
83
|
+
>>> manager = ConfigManager()
|
|
84
|
+
>>> with manager.get_client('mysql_prod') as client:
|
|
85
|
+
... users = client.select('users')
|
|
86
|
+
"""
|
|
87
|
+
if name is None:
|
|
88
|
+
name = self.configs.get('default')
|
|
89
|
+
if not name:
|
|
90
|
+
raise ConfigurationError("未指定数据库名称且没有默认配置")
|
|
91
|
+
|
|
92
|
+
db_config = self.configs.get('databases', {}).get(name)
|
|
93
|
+
if not db_config:
|
|
94
|
+
available = ', '.join(self.configs.get('databases', {}).keys())
|
|
95
|
+
raise ConfigurationError(
|
|
96
|
+
f"数据库配置不存在: '{name}'\n"
|
|
97
|
+
f"可用配置: {available or '(无)'}"
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
db_type = db_config.get('type')
|
|
101
|
+
config = db_config.get('config', {})
|
|
102
|
+
|
|
103
|
+
if not db_type:
|
|
104
|
+
raise ConfigurationError(f"数据库配置'{name}'缺少type字段")
|
|
105
|
+
|
|
106
|
+
client = ClientFactory.create(db_type, config)
|
|
107
|
+
|
|
108
|
+
if auto_connect:
|
|
109
|
+
client.connect()
|
|
110
|
+
|
|
111
|
+
return client
|
|
112
|
+
|
|
113
|
+
def add(self, name: str, db_type: str, config: Dict[str, Any],
|
|
114
|
+
set_as_default: bool = False) -> None:
|
|
115
|
+
"""
|
|
116
|
+
添加数据库配置
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
name: 配置名称
|
|
120
|
+
db_type: 数据库类型
|
|
121
|
+
config: 数据库配置
|
|
122
|
+
set_as_default: 是否设置为默认配置
|
|
123
|
+
|
|
124
|
+
Examples:
|
|
125
|
+
>>> manager = ConfigManager()
|
|
126
|
+
>>> manager.add('prod_db', 'mysql', {
|
|
127
|
+
... 'host': 'localhost',
|
|
128
|
+
... 'user': 'root',
|
|
129
|
+
... 'password': 'pass',
|
|
130
|
+
... 'database': 'prod'
|
|
131
|
+
... })
|
|
132
|
+
"""
|
|
133
|
+
if 'databases' not in self.configs:
|
|
134
|
+
self.configs['databases'] = {}
|
|
135
|
+
|
|
136
|
+
self.configs['databases'][name] = {
|
|
137
|
+
'type': db_type,
|
|
138
|
+
'config': config
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
if set_as_default or not self.configs.get('default'):
|
|
142
|
+
self.configs['default'] = name
|
|
143
|
+
|
|
144
|
+
self._save_config()
|
|
145
|
+
logger.info(f"添加数据库配置: {name} ({db_type})")
|
|
146
|
+
|
|
147
|
+
def remove(self, name: str) -> bool:
|
|
148
|
+
"""
|
|
149
|
+
删除数据库配置
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
name: 配置名称
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
bool: 是否删除成功
|
|
156
|
+
"""
|
|
157
|
+
if name not in self.configs.get('databases', {}):
|
|
158
|
+
logger.warning(f"配置不存在: {name}")
|
|
159
|
+
return False
|
|
160
|
+
|
|
161
|
+
del self.configs['databases'][name]
|
|
162
|
+
|
|
163
|
+
# 如果删除的是默认配置,清除默认设置
|
|
164
|
+
if self.configs.get('default') == name:
|
|
165
|
+
self.configs['default'] = None
|
|
166
|
+
|
|
167
|
+
self._save_config()
|
|
168
|
+
logger.info(f"删除数据库配置: {name}")
|
|
169
|
+
return True
|
|
170
|
+
|
|
171
|
+
def set_default(self, name: str) -> None:
|
|
172
|
+
"""
|
|
173
|
+
设置默认数据库
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
name: 配置名称
|
|
177
|
+
|
|
178
|
+
Raises:
|
|
179
|
+
ConfigurationError: 配置不存在时
|
|
180
|
+
"""
|
|
181
|
+
if name not in self.configs.get('databases', {}):
|
|
182
|
+
raise ConfigurationError(f"配置不存在: {name}")
|
|
183
|
+
|
|
184
|
+
self.configs['default'] = name
|
|
185
|
+
self._save_config()
|
|
186
|
+
logger.info(f"设置默认数据库: {name}")
|
|
187
|
+
|
|
188
|
+
def list(self) -> Dict[str, str]:
|
|
189
|
+
"""
|
|
190
|
+
列出所有配置
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
Dict[str, str]: 配置名称到数据库类型的映射
|
|
194
|
+
"""
|
|
195
|
+
databases = self.configs.get('databases', {})
|
|
196
|
+
return {name: conf['type'] for name, conf in databases.items()}
|
|
197
|
+
|
|
198
|
+
def get_default(self) -> Optional[str]:
|
|
199
|
+
"""
|
|
200
|
+
获取默认配置名称
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
Optional[str]: 默认配置名称
|
|
204
|
+
"""
|
|
205
|
+
return self.configs.get('default')
|
|
206
|
+
|
|
207
|
+
def export_config(self, output_file: str) -> None:
|
|
208
|
+
"""
|
|
209
|
+
导出配置到文件
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
output_file: 输出文件路径
|
|
213
|
+
"""
|
|
214
|
+
try:
|
|
215
|
+
with open(output_file, 'w', encoding='utf-8') as f:
|
|
216
|
+
json.dump(self.configs, f, indent=2, ensure_ascii=False)
|
|
217
|
+
logger.info(f"导出配置到: {output_file}")
|
|
218
|
+
except Exception as e:
|
|
219
|
+
logger.error(f"导出配置失败: {e}")
|
|
220
|
+
raise ConfigurationError(f"导出配置失败: {e}")
|
|
221
|
+
|
|
222
|
+
def import_config(self, input_file: str, merge: bool = True) -> None:
|
|
223
|
+
"""
|
|
224
|
+
从文件导入配置
|
|
225
|
+
|
|
226
|
+
Args:
|
|
227
|
+
input_file: 输入文件路径
|
|
228
|
+
merge: 是否合并到现有配置(False则覆盖)
|
|
229
|
+
"""
|
|
230
|
+
try:
|
|
231
|
+
with open(input_file, 'r', encoding='utf-8') as f:
|
|
232
|
+
imported = json.load(f)
|
|
233
|
+
|
|
234
|
+
if merge:
|
|
235
|
+
# 合并数据库配置
|
|
236
|
+
if 'databases' in imported:
|
|
237
|
+
if 'databases' not in self.configs:
|
|
238
|
+
self.configs['databases'] = {}
|
|
239
|
+
self.configs['databases'].update(imported['databases'])
|
|
240
|
+
|
|
241
|
+
# 如果没有默认配置,使用导入的默认配置
|
|
242
|
+
if 'default' in imported and not self.configs.get('default'):
|
|
243
|
+
self.configs['default'] = imported['default']
|
|
244
|
+
else:
|
|
245
|
+
self.configs = imported
|
|
246
|
+
|
|
247
|
+
self._save_config()
|
|
248
|
+
logger.info(f"导入配置从: {input_file}")
|
|
249
|
+
|
|
250
|
+
except Exception as e:
|
|
251
|
+
logger.error(f"导入配置失败: {e}")
|
|
252
|
+
raise ConfigurationError(f"导入配置失败: {e}")
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
数据库客户端工厂
|
|
3
|
+
用于创建和管理数据库客户端实例
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Any, Type
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from ..core.base import BaseClient
|
|
10
|
+
from ..clients import (
|
|
11
|
+
MySQLClient,
|
|
12
|
+
PostgreSQLClient,
|
|
13
|
+
SQLiteClient,
|
|
14
|
+
MongoDBClient,
|
|
15
|
+
RedisClient,
|
|
16
|
+
SupabaseClient,
|
|
17
|
+
)
|
|
18
|
+
from ..exceptions import ConfigurationError
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ClientFactory:
|
|
25
|
+
"""
|
|
26
|
+
数据库客户端工厂类
|
|
27
|
+
|
|
28
|
+
负责创建和注册数据库客户端
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
# 内置客户端注册表
|
|
32
|
+
_clients: Dict[str, Type[BaseClient]] = {
|
|
33
|
+
'mysql': MySQLClient,
|
|
34
|
+
'postgresql': PostgreSQLClient,
|
|
35
|
+
'postgres': PostgreSQLClient, # 别名
|
|
36
|
+
'sqlite': SQLiteClient,
|
|
37
|
+
'sqlite3': SQLiteClient, # 别名
|
|
38
|
+
'mongodb': MongoDBClient,
|
|
39
|
+
'mongo': MongoDBClient, # 别名
|
|
40
|
+
'redis': RedisClient,
|
|
41
|
+
'supabase': SupabaseClient,
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
@classmethod
|
|
45
|
+
def create(cls, db_type: str, config: Dict[str, Any]) -> BaseClient:
|
|
46
|
+
"""
|
|
47
|
+
创建数据库客户端
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
db_type: 数据库类型 (mysql, postgresql, sqlite, mongodb, redis, supabase)
|
|
51
|
+
config: 数据库配置字典
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
BaseClient: 数据库客户端实例
|
|
55
|
+
|
|
56
|
+
Raises:
|
|
57
|
+
ConfigurationError: 当数据库类型不支持或配置无效时
|
|
58
|
+
|
|
59
|
+
Examples:
|
|
60
|
+
>>> config = {'host': 'localhost', 'user': 'root', 'password': 'pass', 'database': 'test'}
|
|
61
|
+
>>> client = ClientFactory.create('mysql', config)
|
|
62
|
+
"""
|
|
63
|
+
db_type = db_type.lower().strip()
|
|
64
|
+
|
|
65
|
+
if db_type not in cls._clients:
|
|
66
|
+
available = ', '.join(sorted(set(cls._clients.keys())))
|
|
67
|
+
raise ConfigurationError(
|
|
68
|
+
f"不支持的数据库类型: '{db_type}'\n"
|
|
69
|
+
f"支持的类型: {available}"
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
client_class = cls._clients[db_type]
|
|
73
|
+
|
|
74
|
+
try:
|
|
75
|
+
logger.info(f"创建数据库客户端: {db_type}")
|
|
76
|
+
return client_class(config)
|
|
77
|
+
except Exception as e:
|
|
78
|
+
logger.error(f"创建数据库客户端失败: {e}")
|
|
79
|
+
raise ConfigurationError(f"创建数据库客户端失败: {e}")
|
|
80
|
+
|
|
81
|
+
@classmethod
|
|
82
|
+
def register(cls, db_type: str, client_class: Type[BaseClient]) -> None:
|
|
83
|
+
"""
|
|
84
|
+
注册自定义数据库客户端
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
db_type: 数据库类型标识符
|
|
88
|
+
client_class: 客户端类(必须继承BaseClient)
|
|
89
|
+
|
|
90
|
+
Raises:
|
|
91
|
+
ValueError: 当client_class不是BaseClient的子类时
|
|
92
|
+
|
|
93
|
+
Examples:
|
|
94
|
+
>>> class CustomClient(BaseClient):
|
|
95
|
+
... pass
|
|
96
|
+
>>> ClientFactory.register('custom', CustomClient)
|
|
97
|
+
"""
|
|
98
|
+
if not issubclass(client_class, BaseClient):
|
|
99
|
+
raise ValueError(f"{client_class.__name__} 必须继承 BaseClient")
|
|
100
|
+
|
|
101
|
+
db_type = db_type.lower().strip()
|
|
102
|
+
cls._clients[db_type] = client_class
|
|
103
|
+
logger.info(f"注册自定义数据库客户端: {db_type} -> {client_class.__name__}")
|
|
104
|
+
|
|
105
|
+
@classmethod
|
|
106
|
+
def unregister(cls, db_type: str) -> bool:
|
|
107
|
+
"""
|
|
108
|
+
注销数据库客户端
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
db_type: 数据库类型标识符
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
bool: 是否成功注销
|
|
115
|
+
"""
|
|
116
|
+
db_type = db_type.lower().strip()
|
|
117
|
+
if db_type in cls._clients:
|
|
118
|
+
del cls._clients[db_type]
|
|
119
|
+
logger.info(f"注销数据库客户端: {db_type}")
|
|
120
|
+
return True
|
|
121
|
+
return False
|
|
122
|
+
|
|
123
|
+
@classmethod
|
|
124
|
+
def list_available(cls) -> list:
|
|
125
|
+
"""
|
|
126
|
+
列出所有可用的数据库类型
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
list: 数据库类型列表
|
|
130
|
+
"""
|
|
131
|
+
return sorted(set(cls._clients.keys()))
|
|
132
|
+
|
|
133
|
+
@classmethod
|
|
134
|
+
def get_client_class(cls, db_type: str) -> Type[BaseClient]:
|
|
135
|
+
"""
|
|
136
|
+
获取数据库客户端类
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
db_type: 数据库类型
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Type[BaseClient]: 客户端类
|
|
143
|
+
|
|
144
|
+
Raises:
|
|
145
|
+
ConfigurationError: 当数据库类型不存在时
|
|
146
|
+
"""
|
|
147
|
+
db_type = db_type.lower().strip()
|
|
148
|
+
if db_type not in cls._clients:
|
|
149
|
+
raise ConfigurationError(f"数据库类型不存在: {db_type}")
|
|
150
|
+
return cls._clients[db_type]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# 便捷函数
|
|
154
|
+
def create_client(db_type: str, config: Dict[str, Any]) -> BaseClient:
|
|
155
|
+
"""
|
|
156
|
+
创建数据库客户端的便捷函数
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
db_type: 数据库类型
|
|
160
|
+
config: 配置字典
|
|
161
|
+
|
|
162
|
+
Returns:
|
|
163
|
+
BaseClient: 数据库客户端实例
|
|
164
|
+
|
|
165
|
+
Examples:
|
|
166
|
+
>>> from db_toolkit import create_client
|
|
167
|
+
>>> config = {'database': 'test.db'}
|
|
168
|
+
>>> client = create_client('sqlite', config)
|
|
169
|
+
>>> with client:
|
|
170
|
+
... results = client.select('users')
|
|
171
|
+
"""
|
|
172
|
+
return ClientFactory.create(db_type, config)
|