mcp-dbutils 0.23.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_dbutils/base.py CHANGED
@@ -3,7 +3,7 @@
3
3
  import json
4
4
  from abc import ABC, abstractmethod
5
5
  from contextlib import asynccontextmanager
6
- from datetime import datetime, timedelta
6
+ from datetime import datetime
7
7
  from importlib.metadata import metadata
8
8
  from typing import Any, AsyncContextManager, Dict
9
9
 
@@ -12,6 +12,7 @@ import mcp.types as types
12
12
  import yaml
13
13
  from mcp.server import Server
14
14
 
15
+ from .audit import format_logs, get_logs, log_write_operation
15
16
  from .log import create_logger
16
17
  from .stats import ResourceStats
17
18
 
@@ -42,6 +43,10 @@ EMPTY_TABLE_NAME_ERROR = "Table name cannot be empty"
42
43
  CONNECTION_NAME_REQUIRED_ERROR = "Connection name must be specified"
43
44
  SELECT_ONLY_ERROR = "Only SELECT queries are supported for security reasons"
44
45
  INVALID_URI_FORMAT_ERROR = "Invalid resource URI format"
46
+ CONNECTION_NOT_WRITABLE_ERROR = "This connection is not configured for write operations. Add 'writable: true' to the connection configuration."
47
+ WRITE_OPERATION_NOT_ALLOWED_ERROR = "No permission to perform {operation} operation on table {table}."
48
+ WRITE_CONFIRMATION_REQUIRED_ERROR = "Operation not confirmed. To execute write operations, you must set confirmation='CONFIRM_WRITE'."
49
+ UNSUPPORTED_WRITE_OPERATION_ERROR = "Unsupported SQL operation: {operation}. Only INSERT, UPDATE, DELETE are supported."
45
50
 
46
51
  # 获取包信息用于日志命名
47
52
  pkg_meta = metadata("mcp-dbutils")
@@ -116,6 +121,11 @@ class ConnectionHandler(ABC):
116
121
  """Internal query execution method to be implemented by subclasses"""
117
122
  pass
118
123
 
124
+ @abstractmethod
125
+ async def _execute_write_query(self, sql: str) -> str:
126
+ """Internal write query execution method to be implemented by subclasses"""
127
+ pass
128
+
119
129
  async def execute_query(self, sql: str) -> str:
120
130
  """Execute SQL query with performance tracking"""
121
131
  start_time = datetime.now()
@@ -139,6 +149,170 @@ class ConnectionHandler(ABC):
139
149
  )
140
150
  raise
141
151
 
152
async def execute_write_query(self, sql: str) -> str:
    """Execute a SQL write statement with performance tracking and auditing.

    The statement type is validated, then the handler-specific
    ``_execute_write_query`` is awaited. Exactly one audit-log entry is
    written per executed call: status SUCCESS once the statement has run, or
    FAILED if running it raised.

    Args:
        sql: SQL write statement (INSERT, UPDATE or DELETE).

    Returns:
        str: The result string returned by ``_execute_write_query``.

    Raises:
        ValueError: If the SQL is not an INSERT, UPDATE or DELETE statement.
        Exception: Any error raised while executing the statement is recorded
            in the stats and the audit log (status FAILED), then re-raised.
    """
    import re

    sql_type = self._get_sql_type(sql)
    if sql_type not in ("INSERT", "UPDATE", "DELETE"):
        raise ValueError(UNSUPPORTED_WRITE_OPERATION_ERROR.format(operation=sql_type))

    table_name = self._extract_table_name(sql)
    start_time = datetime.now()

    # Keep the try limited to starting and running the statement. An error in
    # the bookkeeping after a successful execution must not be audited as a
    # FAILED write: by then the statement has already been executed.
    try:
        self.stats.record_query()
        self.send_log(
            LOG_LEVEL_INFO,
            f"Executing write operation: {sql_type} on table {table_name}",
        )
        result = await self._execute_write_query(sql)
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        self.stats.record_error(e.__class__.__name__)
        # Audit the failed attempt
        log_write_operation(
            connection_name=self.connection,
            table_name=table_name,
            operation_type=sql_type,
            sql=sql,
            affected_rows=0,
            execution_time=duration * 1000,  # milliseconds
            status="FAILED",
            error_message=str(e),
        )
        self.send_log(
            LOG_LEVEL_ERROR,
            f"Write operation error after {duration * 1000:.2f}ms - {str(e)}\nResource stats: {json.dumps(self.stats.to_dict())}",
        )
        raise

    duration = (datetime.now() - start_time).total_seconds()

    # Best-effort affected-row count: when the result mentions affected rows,
    # take the number in front of "row"/"rows"; otherwise report 0.
    affected_rows = 0
    if isinstance(result, str) and "row" in result and "affected" in result:
        match = re.search(r"(\d+) rows?", result)
        if match:
            affected_rows = int(match.group(1))

    # Audit the success before any further bookkeeping so a completed write
    # is always recorded as such.
    log_write_operation(
        connection_name=self.connection,
        table_name=table_name,
        operation_type=sql_type,
        sql=sql,
        affected_rows=affected_rows,
        execution_time=duration * 1000,  # milliseconds
        status="SUCCESS",
        error_message=None,
    )

    self.stats.record_query_duration(sql, duration)
    self.stats.update_memory_usage(result)
    self.send_log(
        LOG_LEVEL_INFO,
        f"Write operation executed in {duration * 1000:.2f}ms. Resource stats: {json.dumps(self.stats.to_dict())}",
    )
    return result
243
+
244
def _get_sql_type(self, sql: str) -> str:
    """Classify a SQL statement by the keyword it starts with.

    Surrounding whitespace is ignored and the comparison is case-insensitive.
    This is a plain prefix test with no word-boundary or comment handling, so
    for example a statement that begins with a comment is classified UNKNOWN.

    Args:
        sql: SQL statement.

    Returns:
        str: SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, TRUNCATE,
        TRANSACTION_START (BEGIN or START), TRANSACTION_COMMIT,
        TRANSACTION_ROLLBACK, or UNKNOWN when no keyword matches.
    """
    normalized = sql.strip().upper()
    keyword_categories = (
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("CREATE", "CREATE"),
        ("ALTER", "ALTER"),
        ("DROP", "DROP"),
        ("TRUNCATE", "TRUNCATE"),
        ("BEGIN", "TRANSACTION_START"),
        ("START", "TRANSACTION_START"),
        ("COMMIT", "TRANSACTION_COMMIT"),
        ("ROLLBACK", "TRANSACTION_ROLLBACK"),
    )
    return next(
        (category for keyword, category in keyword_categories if normalized.startswith(keyword)),
        "UNKNOWN",
    )
278
+
279
def _extract_table_name(self, sql: str) -> str:
    """Extract the target table name from an INSERT, UPDATE or DELETE statement.

    This is a lightweight, regex-based extraction for simple statements.
    Subclasses may override this method to provide more accurate table name
    extraction.

    The name keeps the letter case used in the statement, so it can be
    compared with configured table names and appears as written in the logs.
    It stops at whitespace, ``(``, ``,`` or ``;``, so ``INSERT INTO logs(a)``
    yields ``logs``. Identifier quoting (`x`, "x", [x]) is removed from each
    dotted part, e.g. ``"public"."users"`` -> ``public.users``.

    Args:
        sql: SQL statement.

    Returns:
        str: Table name, or "unknown_table" when it cannot be determined.
    """
    import re

    # A single identifier: back-quoted, double-quoted, bracketed, or bare.
    ident = r'(?:`[^`]+`|"[^"]+"|\[[^\]]+\]|[^\s(),;.`"\[\]]+)'
    # Optionally schema/database qualified: ident(.ident)*
    qualified = rf"({ident}(?:\s*\.\s*{ident})*)"
    patterns = {
        # INSERT [OR ...|IGNORE ...] INTO table ...
        "INSERT": rf"\bINTO\s+{qualified}",
        # UPDATE [LOW_PRIORITY] [IGNORE] [OR <action>] table ...
        "UPDATE": rf"^\s*UPDATE\s+(?:(?:LOW_PRIORITY|IGNORE|OR\s+\w+)\s+)*{qualified}",
        # DELETE ... FROM table ...
        "DELETE": rf"\bFROM\s+{qualified}",
    }

    pattern = patterns.get(self._get_sql_type(sql))
    if pattern is None:
        return "unknown_table"

    match = re.search(pattern, sql, re.IGNORECASE)
    if not match:
        return "unknown_table"

    parts = re.findall(ident, match.group(1))
    return ".".join(part.strip('`"[]') for part in parts)
315
+
142
316
  @abstractmethod
143
317
  async def get_table_description(self, table_name: str) -> str:
144
318
  """Get detailed table description including columns, types, and comments
@@ -356,6 +530,140 @@ class ConnectionServer:
356
530
 
357
531
  return db_config
358
532
 
533
def _get_sql_type(self, sql: str) -> str:
    """Return the statement category of ``sql`` based on its first keyword.

    The statement is stripped of surrounding whitespace and upper-cased, then
    tested against a fixed list of keyword prefixes. No parsing beyond that
    is done, so leading comments or CTEs yield UNKNOWN.

    Args:
        sql: SQL statement.

    Returns:
        str: SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, TRUNCATE,
        TRANSACTION_START (BEGIN or START), TRANSACTION_COMMIT,
        TRANSACTION_ROLLBACK, or UNKNOWN.
    """
    head = sql.strip().upper()
    for keyword, category in (
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("CREATE", "CREATE"),
        ("ALTER", "ALTER"),
        ("DROP", "DROP"),
        ("TRUNCATE", "TRUNCATE"),
        ("BEGIN", "TRANSACTION_START"),
        ("START", "TRANSACTION_START"),
        ("COMMIT", "TRANSACTION_COMMIT"),
        ("ROLLBACK", "TRANSACTION_ROLLBACK"),
    ):
        if head.startswith(keyword):
            return category
    return "UNKNOWN"
567
+
568
def _extract_table_name(self, sql: str) -> str:
    """Extract the target table name from an INSERT, UPDATE or DELETE statement.

    This is a simple, regex-based implementation for basic SQL statements.

    The name keeps the letter case used in the statement, so write-permission
    lookups can match configured table names and audit logs show the name as
    written. It stops at whitespace, ``(``, ``,`` or ``;``, so
    ``INSERT INTO logs(a)`` yields ``logs``. Identifier quoting (`x`, "x",
    [x]) is removed from each dotted part, e.g. ``"public"."users"`` ->
    ``public.users``.

    Args:
        sql: SQL statement.

    Returns:
        str: Table name, or "unknown_table" when it cannot be determined.
    """
    import re

    # A single identifier: back-quoted, double-quoted, bracketed, or bare.
    ident = r'(?:`[^`]+`|"[^"]+"|\[[^\]]+\]|[^\s(),;.`"\[\]]+)'
    # Optionally schema/database qualified: ident(.ident)*
    qualified = rf"({ident}(?:\s*\.\s*{ident})*)"
    patterns = {
        # INSERT [OR ...|IGNORE ...] INTO table ...
        "INSERT": rf"\bINTO\s+{qualified}",
        # UPDATE [LOW_PRIORITY] [IGNORE] [OR <action>] table ...
        "UPDATE": rf"^\s*UPDATE\s+(?:(?:LOW_PRIORITY|IGNORE|OR\s+\w+)\s+)*{qualified}",
        # DELETE ... FROM table ...
        "DELETE": rf"\bFROM\s+{qualified}",
    }

    pattern = patterns.get(self._get_sql_type(sql))
    if pattern is None:
        return "unknown_table"

    match = re.search(pattern, sql, re.IGNORECASE)
    if not match:
        return "unknown_table"

    parts = re.findall(ident, match.group(1))
    return ".".join(part.strip('`"[]') for part in parts)
603
+
604
async def _check_write_permission(self, connection: str, table_name: str, operation_type: str) -> bool:
    """Check that a write operation is allowed on a table of a connection.

    Rules, applied in order:

    1. The connection config must contain ``writable: true``.
    2. A missing or empty ``write_permissions`` section allows every write.
    3. A table listed under ``write_permissions.tables`` is governed by its
       ``operations`` list (default: INSERT, UPDATE, DELETE). An entry with
       no body (``None``) uses that default.
    4. Any other table falls back to ``write_permissions.default_policy``:
       ``"allow_all"`` allows; anything else (default ``"read_only"``) denies.

    Table names are matched case-insensitively (an exact-case entry wins if
    present), and configured operation names are compared case-insensitively.
    For example, ``UPDATE Users`` matches a ``users`` entry, and
    ``operations: [insert]`` permits ``INSERT``.

    Args:
        connection: Database connection name.
        table_name: Target table name.
        operation_type: Operation type (INSERT, UPDATE, DELETE).

    Returns:
        bool: True when the write is allowed. Denial is reported by raising,
        never by returning False.

    Raises:
        ConfigurationError: If the connection is not writable or the
            operation is not permitted on the table. Errors raised by
            ``_get_config_or_raise`` (e.g. for an unknown connection) propagate.
    """
    db_config = self._get_config_or_raise(connection)

    if not db_config.get("writable", False):
        raise ConfigurationError(CONNECTION_NOT_WRITABLE_ERROR)

    write_permissions = db_config.get("write_permissions", {})
    if not write_permissions:
        # No fine-grained permission control: every write is allowed.
        return True

    tables = write_permissions.get("tables") or {}
    if table_name in tables:
        matched = table_name
    else:
        wanted = table_name.lower()
        matched = next((name for name in tables if str(name).lower() == wanted), None)

    if matched is not None:
        table_config = tables[matched] or {}
        # An explicit null operations list permits nothing.
        operations = table_config.get("operations", ["INSERT", "UPDATE", "DELETE"]) or []
        if operation_type.upper() in {str(op).upper() for op in operations}:
            return True
    elif write_permissions.get("default_policy", "read_only") == "allow_all":
        return True

    raise ConfigurationError(
        WRITE_OPERATION_NOT_ALLOWED_ERROR.format(operation=operation_type, table=table_name)
    )
666
+
359
667
  def _create_handler_for_type(
360
668
  self, db_type: str, connection: str
361
669
  ) -> ConnectionHandler:
@@ -457,6 +765,54 @@ class ConnectionServer:
457
765
  "required": [],
458
766
  },
459
767
  ),
768
+ types.Tool(
769
+ name="dbutils-execute-write",
770
+ description="CAUTION: This tool executes data modification operations (INSERT, UPDATE, DELETE) on the specified database. It requires explicit configuration and confirmation. Only available for connections with 'writable: true' in configuration. All operations are logged for audit purposes.",
771
+ inputSchema={
772
+ "type": "object",
773
+ "properties": {
774
+ "connection": {
775
+ "type": "string",
776
+ "description": DATABASE_CONNECTION_NAME,
777
+ },
778
+ "sql": {
779
+ "type": "string",
780
+ "description": "SQL statement (INSERT, UPDATE, DELETE)",
781
+ },
782
+ "confirmation": {
783
+ "type": "string",
784
+ "description": "Type 'CONFIRM_WRITE' to confirm you understand the risks",
785
+ },
786
+ },
787
+ "required": ["connection", "sql", "confirmation"],
788
+ },
789
+ annotations={
790
+ "examples": [
791
+ {
792
+ "input": {
793
+ "connection": "example_db",
794
+ "sql": "INSERT INTO logs (event, timestamp) VALUES ('event1', CURRENT_TIMESTAMP)",
795
+ "confirmation": "CONFIRM_WRITE"
796
+ },
797
+ "output": "Write operation executed successfully. 1 row affected."
798
+ },
799
+ {
800
+ "input": {
801
+ "connection": "example_db",
802
+ "sql": "UPDATE users SET status = 'active' WHERE id = 123",
803
+ "confirmation": "CONFIRM_WRITE"
804
+ },
805
+ "output": "Write operation executed successfully. 1 row affected."
806
+ }
807
+ ],
808
+ "usage_tips": [
809
+ "Always confirm with 'CONFIRM_WRITE' to execute write operations",
810
+ "Connection must have 'writable: true' in configuration",
811
+ "Consider using transactions for multiple related operations",
812
+ "Check audit logs after write operations to verify changes"
813
+ ]
814
+ }
815
+ ),
460
816
  types.Tool(
461
817
  name="dbutils-run-query",
462
818
  description="Executes read-only SQL queries on the specified database connection. For security, only SELECT statements are supported. Returns structured results with column names and data rows. Supports complex queries including JOINs, GROUP BY, ORDER BY, and aggregate functions. Use this tool when you need to analyze data, validate hypotheses, or extract specific information. Query execution is protected by resource limits and timeouts to prevent system resource overuse.",
@@ -665,6 +1021,39 @@ class ConnectionServer:
665
1021
  "required": ["connection", "sql"],
666
1022
  },
667
1023
  ),
1024
+ types.Tool(
1025
+ name="dbutils-get-audit-logs",
1026
+ description="Retrieves audit logs for database write operations. Shows who performed what operations, when, and with what results. Useful for security monitoring, compliance, and troubleshooting.",
1027
+ inputSchema={
1028
+ "type": "object",
1029
+ "properties": {
1030
+ "connection": {
1031
+ "type": "string",
1032
+ "description": "Filter logs by connection name",
1033
+ },
1034
+ "table": {
1035
+ "type": "string",
1036
+ "description": "Filter logs by table name",
1037
+ },
1038
+ "operation_type": {
1039
+ "type": "string",
1040
+ "description": "Filter logs by operation type (INSERT, UPDATE, DELETE)",
1041
+ "enum": ["INSERT", "UPDATE", "DELETE"]
1042
+ },
1043
+ "status": {
1044
+ "type": "string",
1045
+ "description": "Filter logs by operation status (SUCCESS, FAILED)",
1046
+ "enum": ["SUCCESS", "FAILED"]
1047
+ },
1048
+ "limit": {
1049
+ "type": "integer",
1050
+ "description": "Maximum number of logs to return",
1051
+ "default": 100
1052
+ }
1053
+ },
1054
+ "required": [],
1055
+ },
1056
+ ),
668
1057
  ]
669
1058
 
670
1059
  async def _handle_list_connections(
@@ -932,6 +1321,110 @@ class ConnectionServer:
932
1321
 
933
1322
  return [types.TextContent(type="text", text="\n".join(analysis))]
934
1323
 
1324
async def _handle_execute_write(
    self, connection: str, sql: str, confirmation: str
) -> "list[types.TextContent]":
    """Handle a ``dbutils-execute-write`` tool call.

    The checks run in this order:

    1. The SQL is non-empty.
    2. The confirmation string is ``CONFIRM_WRITE``.
    3. The statement type is INSERT, UPDATE or DELETE.
    4. The connection/table write permission passes, via
       ``_check_write_permission``.

    Only then is the statement executed through the connection's handler.

    Args:
        connection: Database connection name.
        sql: SQL write statement.
        confirmation: Confirmation string; must equal ``CONFIRM_WRITE``.

    Returns:
        list[types.TextContent]: A single text item holding the handler's
        result string.

    Raises:
        ConfigurationError: If the SQL is empty, the confirmation string is
            wrong, the statement type is unsupported, the connection is not
            writable, or the table-level permission denies the operation.
            Errors from executing the statement are logged and re-raised.
    """
    if not sql:
        raise ConfigurationError(EMPTY_QUERY_ERROR)

    if confirmation != "CONFIRM_WRITE":
        raise ConfigurationError(WRITE_CONFIRMATION_REQUIRED_ERROR)

    sql_type = self._get_sql_type(sql.strip())
    if sql_type not in ("INSERT", "UPDATE", "DELETE"):
        raise ConfigurationError(UNSUPPORTED_WRITE_OPERATION_ERROR.format(operation=sql_type))

    table_name = self._extract_table_name(sql)

    # _check_write_permission loads the connection config itself, so no
    # separate config lookup is needed here.
    # NOTE(review): only the leading statement is classified and
    # permission-checked. If the driver behind a connection accepts several
    # ';'-separated statements in one call, the later ones bypass these
    # checks. Confirm per driver, or reject multi-statement SQL.
    await self._check_write_permission(connection, table_name, sql_type)

    async with self.get_handler(connection) as handler:
        self.send_log(
            LOG_LEVEL_NOTICE,
            f"Executing write operation: {sql_type} on table {table_name} in connection {connection}",
        )

        try:
            result = await handler.execute_write_query(sql)
        except Exception as e:
            self.send_log(
                LOG_LEVEL_ERROR,
                f"Write operation failed: {str(e)}",
            )
            raise

        self.send_log(
            LOG_LEVEL_INFO,
            f"Write operation executed successfully: {sql_type} on table {table_name}",
        )
        return [types.TextContent(type="text", text=result)]
1378
+
1379
async def _handle_get_audit_logs(
    self,
    connection: str = None,
    table: str = None,
    operation_type: str = None,
    status: str = None,
    limit: int = 100
) -> "list[types.TextContent]":
    """Handle a ``dbutils-get-audit-logs`` tool call.

    Empty-string filters are treated like omitted ones (``None``), because
    the tool dispatcher passes ``""`` for filter arguments that were not
    supplied. The applied filters are listed above the formatted logs.

    Args:
        connection: Database connection name (optional).
        table: Table name (optional).
        operation_type: Operation type (optional; INSERT/UPDATE/DELETE).
        status: Operation status (optional; SUCCESS/FAILED).
        limit: Maximum number of records to return. A numeric string is
            converted to ``int``.

    Returns:
        list[types.TextContent]: A single text item with the audit logs.
    """
    # Normalise "" to None so get_logs sees "no filter" rather than an
    # empty filter value.
    connection = connection or None
    table = table or None
    operation_type = operation_type or None
    status = status or None
    if isinstance(limit, str):
        limit = int(limit.strip())

    logs = get_logs(
        connection_name=connection,
        table_name=table,
        operation_type=operation_type,
        status=status,
        limit=limit,
    )
    formatted_logs = format_logs(logs)

    applied_filters = [
        f"{label}: {value}"
        for label, value in (
            ("Connection", connection),
            ("Table", table),
            ("Operation", operation_type),
            ("Status", status),
        )
        if value
    ]
    if applied_filters:
        formatted_logs = f"Filters applied: {', '.join(applied_filters)}\n\n{formatted_logs}"

    return [types.TextContent(type="text", text=formatted_logs)]
1427
+
935
1428
  def _get_optimization_suggestions(
936
1429
  self, explain_result: str, duration: float
937
1430
  ) -> list[str]:
@@ -1028,6 +1521,16 @@ class ConnectionServer:
1028
1521
  elif name == "dbutils-analyze-query":
1029
1522
  sql = arguments.get("sql", "").strip()
1030
1523
  return await self._handle_analyze_query(connection, sql)
1524
+ elif name == "dbutils-execute-write":
1525
+ sql = arguments.get("sql", "").strip()
1526
+ confirmation = arguments.get("confirmation", "").strip()
1527
+ return await self._handle_execute_write(connection, sql, confirmation)
1528
+ elif name == "dbutils-get-audit-logs":
1529
+ table = arguments.get("table", "").strip()
1530
+ operation_type = arguments.get("operation_type", "").strip()
1531
+ status = arguments.get("status", "").strip()
1532
+ limit = arguments.get("limit", 100)
1533
+ return await self._handle_get_audit_logs(connection, table, operation_type, status, limit)
1031
1534
  else:
1032
1535
  raise ConfigurationError(f"Unknown tool: {name}")
1033
1536
 
@@ -1037,6 +1540,5 @@ class ConnectionServer:
1037
1540
  await self.server.run(
1038
1541
  streams[0],
1039
1542
  streams[1],
1040
- self.server.create_initialization_options(),
1041
- read_timeout_seconds=timedelta(seconds=30) # 设置30秒超时
1543
+ self.server.create_initialization_options()
1042
1544
  )