mcp-dbutils 0.23.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_dbutils/audit.py +269 -0
- mcp_dbutils/base.py +505 -3
- mcp_dbutils/config.py +103 -1
- mcp_dbutils/mysql/config.py +57 -40
- mcp_dbutils/mysql/handler.py +60 -0
- mcp_dbutils/postgres/config.py +40 -22
- mcp_dbutils/postgres/handler.py +60 -0
- mcp_dbutils/sqlite/config.py +8 -1
- mcp_dbutils/sqlite/handler.py +53 -0
- {mcp_dbutils-0.23.0.dist-info → mcp_dbutils-1.0.0.dist-info}/METADATA +1 -1
- mcp_dbutils-1.0.0.dist-info/RECORD +23 -0
- mcp_dbutils-0.23.0.dist-info/RECORD +0 -22
- {mcp_dbutils-0.23.0.dist-info → mcp_dbutils-1.0.0.dist-info}/WHEEL +0 -0
- {mcp_dbutils-0.23.0.dist-info → mcp_dbutils-1.0.0.dist-info}/entry_points.txt +0 -0
- {mcp_dbutils-0.23.0.dist-info → mcp_dbutils-1.0.0.dist-info}/licenses/LICENSE +0 -0
mcp_dbutils/config.py
CHANGED
@@ -2,18 +2,111 @@
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from typing import Any, Dict, Literal
|
5
|
+
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
9
9
|
# Supported connection types
ConnectionType = Literal['sqlite', 'postgres', 'mysql']

# Supported write operations: the statement kinds that write_permissions
# can grant on a table
WriteOperationType = Literal['INSERT', 'UPDATE', 'DELETE']

# Default policy for tables not explicitly listed in write_permissions:
# 'read_only' allows no write operations on them, 'allow_all' allows every
# write operation
DefaultPolicyType = Literal['read_only', 'allow_all']
|
17
|
+
|
18
|
+
class WritePermissions:
    """Per-table write permissions for a writable connection.

    Built from the ``write_permissions`` mapping of a connection
    configuration, which may contain:

    * ``tables`` -- mapping of table name to an optional mapping with an
      ``operations`` list. Recognised operation names are ``INSERT``,
      ``UPDATE`` and ``DELETE`` (matched case-insensitively). A listed
      table with no (or an empty) ``operations`` list may use all three;
      otherwise only the recognised names in its list are granted.
    * ``default_policy`` -- ``'read_only'`` (default) or ``'allow_all'``,
      applied to tables that are not listed under ``tables``.

    Unrecognised values are ignored and never widen the granted
    permissions.
    """

    # Every write operation this permission model can grant.
    _ALL_OPERATIONS = frozenset(('INSERT', 'UPDATE', 'DELETE'))

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize write permissions

        Args:
            config: Write permissions configuration dictionary. ``None``, an
                empty dict, or a non-dict value yields a read-only policy
                with no tables listed.
        """
        self.tables: Dict[str, Set[WriteOperationType]] = {}
        self.default_policy: DefaultPolicyType = 'read_only'

        if not isinstance(config, dict):
            return

        # Parse table permissions
        tables = config.get('tables')
        if isinstance(tables, dict):
            for table_name, table_config in tables.items():
                self.tables[table_name] = self._parse_operations(table_config)

        # Parse default policy; anything unrecognised keeps 'read_only'
        policy = config.get('default_policy')
        if policy in ('read_only', 'allow_all'):
            self.default_policy = policy  # type: ignore

    @classmethod
    def _parse_operations(cls, table_config: Any) -> Set[WriteOperationType]:
        """Return the operations granted by a single ``tables`` entry.

        Args:
            table_config: The value stored under a table name in ``tables``

        Returns:
            Set of granted operations. A missing, non-list or empty
            ``operations`` value grants every operation. A non-empty list
            grants only the recognised names it contains, so a typo such
            as ``['insret']`` grants nothing instead of everything.
        """
        ops = table_config.get('operations') if isinstance(table_config, dict) else None
        if not isinstance(ops, list) or not ops:
            # No operations specified: allow all
            return set(cls._ALL_OPERATIONS)  # type: ignore
        return {  # type: ignore
            op.upper()
            for op in ops
            if isinstance(op, str) and op.upper() in cls._ALL_OPERATIONS
        }

    def can_write_to_table(self, table_name: str) -> bool:
        """Check if writing to the table is allowed

        Args:
            table_name: Name of the table

        Returns:
            True if at least one write operation is allowed on the table
            (it is listed with granted operations, or the default policy
            is 'allow_all'), False otherwise
        """
        return bool(self.allowed_operations(table_name))

    def allowed_operations(self, table_name: str) -> Set[WriteOperationType]:
        """Get allowed operations for a table

        Args:
            table_name: Name of the table (looked up exactly as given)

        Returns:
            A new set of allowed operations; mutating it does not change
            these permissions
        """
        # An explicitly listed table uses its own operations
        if table_name in self.tables:
            return set(self.tables[table_name])

        # Otherwise, fall back to the default policy
        if self.default_policy == 'allow_all':
            return set(self._ALL_OPERATIONS)  # type: ignore

        # read_only: no operations allowed
        return set()

    def is_operation_allowed(self, table_name: str, operation: WriteOperationType) -> bool:
        """Check if an operation is allowed on a table

        Args:
            table_name: Name of the table
            operation: Operation type (matched case-insensitively)

        Returns:
            True if the operation is allowed, False otherwise
        """
        if isinstance(operation, str):
            operation = operation.upper()  # type: ignore
        return operation in self.allowed_operations(table_name)
|
102
|
+
|
12
103
|
class ConnectionConfig(ABC):
|
13
104
|
"""Base class for connection configuration"""
|
14
105
|
|
15
106
|
debug: bool = False
|
16
107
|
type: ConnectionType # Connection type
|
108
|
+
writable: bool = False # Whether write operations are allowed
|
109
|
+
write_permissions: Optional[WritePermissions] = None # Write permissions configuration
|
17
110
|
|
18
111
|
@abstractmethod
|
19
112
|
def get_connection_params(self) -> Dict[str, Any]:
|
@@ -50,6 +143,15 @@ class ConnectionConfig(ABC):
|
|
50
143
|
if db_type not in ('sqlite', 'postgres', 'mysql'):
|
51
144
|
raise ValueError(f"Invalid type value in database configuration {conn_name}: {db_type}")
|
52
145
|
|
146
|
+
# Validate write permissions if writable is true
|
147
|
+
if db_config.get('writable', False):
|
148
|
+
if not isinstance(db_config.get('writable'), bool):
|
149
|
+
raise ValueError(f"Invalid writable value in database configuration {conn_name}: {db_config['writable']}")
|
150
|
+
|
151
|
+
# Validate write_permissions if present
|
152
|
+
if 'write_permissions' in db_config and not isinstance(db_config['write_permissions'], dict):
|
153
|
+
raise ValueError(f"Invalid write_permissions in database configuration {conn_name}: {db_config['write_permissions']}")
|
154
|
+
|
53
155
|
return connections
|
54
156
|
|
55
157
|
@classmethod
|
mcp_dbutils/mysql/config.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
3
3
|
from typing import Any, Dict, Literal, Optional
|
4
4
|
from urllib.parse import parse_qs, urlparse
|
5
5
|
|
6
|
-
from ..config import ConnectionConfig
|
6
|
+
from ..config import ConnectionConfig, WritePermissions
|
7
7
|
|
8
8
|
|
9
9
|
@dataclass
|
@@ -16,30 +16,30 @@ class SSLConfig:
|
|
16
16
|
|
17
17
|
def parse_url(url: str) -> Dict[str, Any]:
|
18
18
|
"""Parse MySQL URL into connection parameters
|
19
|
-
|
19
|
+
|
20
20
|
Args:
|
21
21
|
url: URL (e.g. mysql://host:port/dbname?ssl-mode=verify_identity)
|
22
|
-
|
22
|
+
|
23
23
|
Returns:
|
24
24
|
Dictionary of connection parameters including SSL settings
|
25
25
|
"""
|
26
26
|
if not url.startswith('mysql://'):
|
27
27
|
raise ValueError("Invalid MySQL URL format")
|
28
|
-
|
28
|
+
|
29
29
|
if '@' in url:
|
30
30
|
raise ValueError("URL should not contain credentials. Please provide username and password separately.")
|
31
|
-
|
31
|
+
|
32
32
|
# Parse URL and query parameters
|
33
33
|
parsed = urlparse(url)
|
34
34
|
query_params = parse_qs(parsed.query)
|
35
|
-
|
35
|
+
|
36
36
|
params = {
|
37
37
|
'host': parsed.hostname or 'localhost',
|
38
38
|
'port': str(parsed.port or 3306),
|
39
39
|
'database': parsed.path.lstrip('/') if parsed.path else '',
|
40
40
|
'charset': query_params.get('charset', ['utf8mb4'])[0]
|
41
41
|
}
|
42
|
-
|
42
|
+
|
43
43
|
if not params['database']:
|
44
44
|
raise ValueError("MySQL database name must be specified in URL")
|
45
45
|
|
@@ -50,17 +50,17 @@ def parse_url(url: str) -> Dict[str, Any]:
|
|
50
50
|
if mode not in ['disabled', 'required', 'verify_ca', 'verify_identity']:
|
51
51
|
raise ValueError(f"Invalid ssl-mode: {mode}")
|
52
52
|
ssl_params['mode'] = mode
|
53
|
-
|
53
|
+
|
54
54
|
if 'ssl-ca' in query_params:
|
55
55
|
ssl_params['ca'] = query_params['ssl-ca'][0]
|
56
56
|
if 'ssl-cert' in query_params:
|
57
57
|
ssl_params['cert'] = query_params['ssl-cert'][0]
|
58
58
|
if 'ssl-key' in query_params:
|
59
59
|
ssl_params['key'] = query_params['ssl-key'][0]
|
60
|
-
|
60
|
+
|
61
61
|
if ssl_params:
|
62
62
|
params['ssl'] = SSLConfig(**ssl_params)
|
63
|
-
|
63
|
+
|
64
64
|
return params
|
65
65
|
|
66
66
|
@dataclass
|
@@ -75,18 +75,20 @@ class MySQLConfig(ConnectionConfig):
|
|
75
75
|
type: Literal['mysql'] = 'mysql'
|
76
76
|
url: Optional[str] = None
|
77
77
|
ssl: Optional[SSLConfig] = None
|
78
|
+
writable: bool = False # Whether write operations are allowed
|
79
|
+
write_permissions: Optional[WritePermissions] = None # Write permissions configuration
|
78
80
|
|
79
81
|
@classmethod
|
80
82
|
def _validate_connection_config(cls, configs: dict, db_name: str) -> dict:
|
81
83
|
"""验证连接配置是否有效
|
82
|
-
|
84
|
+
|
83
85
|
Args:
|
84
86
|
configs: 配置字典
|
85
87
|
db_name: 连接名称
|
86
|
-
|
88
|
+
|
87
89
|
Returns:
|
88
90
|
dict: 数据库配置
|
89
|
-
|
91
|
+
|
90
92
|
Raises:
|
91
93
|
ValueError: 如果配置无效
|
92
94
|
"""
|
@@ -107,17 +109,17 @@ class MySQLConfig(ConnectionConfig):
|
|
107
109
|
raise ValueError("User must be specified in connection configuration")
|
108
110
|
if not db_config.get('password'):
|
109
111
|
raise ValueError("Password must be specified in connection configuration")
|
110
|
-
|
112
|
+
|
111
113
|
return db_config
|
112
|
-
|
114
|
+
|
113
115
|
@classmethod
|
114
116
|
def _create_config_from_url(cls, db_config: dict, local_host: Optional[str] = None) -> 'MySQLConfig':
|
115
117
|
"""从URL创建配置
|
116
|
-
|
118
|
+
|
117
119
|
Args:
|
118
120
|
db_config: 数据库配置
|
119
121
|
local_host: 可选的本地主机地址
|
120
|
-
|
122
|
+
|
121
123
|
Returns:
|
122
124
|
MySQLConfig: 配置对象
|
123
125
|
"""
|
@@ -135,18 +137,18 @@ class MySQLConfig(ConnectionConfig):
|
|
135
137
|
ssl=params.get('ssl')
|
136
138
|
)
|
137
139
|
return config
|
138
|
-
|
140
|
+
|
139
141
|
@classmethod
|
140
142
|
def _create_config_from_params(cls, db_config: dict, local_host: Optional[str] = None) -> 'MySQLConfig':
|
141
143
|
"""从参数创建配置
|
142
|
-
|
144
|
+
|
143
145
|
Args:
|
144
146
|
db_config: 数据库配置
|
145
147
|
local_host: 可选的本地主机地址
|
146
|
-
|
148
|
+
|
147
149
|
Returns:
|
148
150
|
MySQLConfig: 配置对象
|
149
|
-
|
151
|
+
|
150
152
|
Raises:
|
151
153
|
ValueError: 如果缺少必需参数或SSL配置无效
|
152
154
|
"""
|
@@ -156,10 +158,10 @@ class MySQLConfig(ConnectionConfig):
|
|
156
158
|
raise ValueError("Host must be specified in connection configuration")
|
157
159
|
if not db_config.get('port'):
|
158
160
|
raise ValueError("Port must be specified in connection configuration")
|
159
|
-
|
161
|
+
|
160
162
|
# Parse SSL configuration if present
|
161
163
|
ssl_config = cls._parse_ssl_config(db_config)
|
162
|
-
|
164
|
+
|
163
165
|
config = cls(
|
164
166
|
database=db_config['database'],
|
165
167
|
user=db_config['user'],
|
@@ -171,30 +173,30 @@ class MySQLConfig(ConnectionConfig):
|
|
171
173
|
ssl=ssl_config
|
172
174
|
)
|
173
175
|
return config
|
174
|
-
|
176
|
+
|
175
177
|
@classmethod
|
176
178
|
def _parse_ssl_config(cls, db_config: dict) -> Optional[SSLConfig]:
|
177
179
|
"""解析SSL配置
|
178
|
-
|
180
|
+
|
179
181
|
Args:
|
180
182
|
db_config: 数据库配置
|
181
|
-
|
183
|
+
|
182
184
|
Returns:
|
183
185
|
Optional[SSLConfig]: SSL配置或None
|
184
|
-
|
186
|
+
|
185
187
|
Raises:
|
186
188
|
ValueError: 如果SSL配置无效
|
187
189
|
"""
|
188
190
|
if 'ssl' not in db_config:
|
189
191
|
return None
|
190
|
-
|
192
|
+
|
191
193
|
ssl_params = db_config['ssl']
|
192
194
|
if not isinstance(ssl_params, dict):
|
193
195
|
raise ValueError("SSL configuration must be a dictionary")
|
194
|
-
|
196
|
+
|
195
197
|
if ssl_params.get('mode') not in [None, 'disabled', 'required', 'verify_ca', 'verify_identity']:
|
196
198
|
raise ValueError(f"Invalid ssl-mode: {ssl_params.get('mode')}")
|
197
|
-
|
199
|
+
|
198
200
|
return SSLConfig(
|
199
201
|
mode=ssl_params.get('mode', 'disabled'),
|
200
202
|
ca=ssl_params.get('ca'),
|
@@ -212,7 +214,7 @@ class MySQLConfig(ConnectionConfig):
|
|
212
214
|
local_host: Optional local host address
|
213
215
|
"""
|
214
216
|
configs = cls.load_yaml_config(yaml_path)
|
215
|
-
|
217
|
+
|
216
218
|
# Validate connection config
|
217
219
|
db_config = cls._validate_connection_config(configs, db_name)
|
218
220
|
|
@@ -221,26 +223,35 @@ class MySQLConfig(ConnectionConfig):
|
|
221
223
|
config = cls._create_config_from_url(db_config, local_host)
|
222
224
|
else:
|
223
225
|
config = cls._create_config_from_params(db_config, local_host)
|
224
|
-
|
226
|
+
|
227
|
+
# Parse write permissions
|
228
|
+
config.writable = db_config.get('writable', False)
|
229
|
+
if config.writable and 'write_permissions' in db_config:
|
230
|
+
config.write_permissions = WritePermissions(db_config['write_permissions'])
|
231
|
+
|
225
232
|
config.debug = cls.get_debug_mode()
|
226
233
|
return config
|
227
234
|
|
228
235
|
@classmethod
|
229
|
-
def from_url(cls, url: str, user: str, password: str,
|
230
|
-
local_host: Optional[str] = None
|
236
|
+
def from_url(cls, url: str, user: str, password: str,
|
237
|
+
local_host: Optional[str] = None,
|
238
|
+
writable: bool = False,
|
239
|
+
write_permissions: Optional[Dict[str, Any]] = None) -> 'MySQLConfig':
|
231
240
|
"""Create configuration from URL and credentials
|
232
|
-
|
241
|
+
|
233
242
|
Args:
|
234
243
|
url: URL (mysql://host:port/dbname)
|
235
244
|
user: Username for connection
|
236
245
|
password: Password for connection
|
237
246
|
local_host: Optional local host address
|
238
|
-
|
247
|
+
writable: Whether write operations are allowed
|
248
|
+
write_permissions: Write permissions configuration
|
249
|
+
|
239
250
|
Raises:
|
240
251
|
ValueError: If URL format is invalid or required parameters are missing
|
241
252
|
"""
|
242
253
|
params = parse_url(url)
|
243
|
-
|
254
|
+
|
244
255
|
config = cls(
|
245
256
|
database=params['database'],
|
246
257
|
user=user,
|
@@ -250,8 +261,14 @@ class MySQLConfig(ConnectionConfig):
|
|
250
261
|
charset=params['charset'],
|
251
262
|
local_host=local_host,
|
252
263
|
url=url,
|
253
|
-
ssl=params.get('ssl')
|
264
|
+
ssl=params.get('ssl'),
|
265
|
+
writable=writable
|
254
266
|
)
|
267
|
+
|
268
|
+
# Parse write permissions
|
269
|
+
if writable and write_permissions:
|
270
|
+
config.write_permissions = WritePermissions(write_permissions)
|
271
|
+
|
255
272
|
config.debug = cls.get_debug_mode()
|
256
273
|
return config
|
257
274
|
|
@@ -266,7 +283,7 @@ class MySQLConfig(ConnectionConfig):
|
|
266
283
|
'charset': self.charset,
|
267
284
|
'use_unicode': True
|
268
285
|
}
|
269
|
-
|
286
|
+
|
270
287
|
# Add SSL parameters if configured
|
271
288
|
if self.ssl:
|
272
289
|
params['ssl_mode'] = self.ssl.mode
|
@@ -276,7 +293,7 @@ class MySQLConfig(ConnectionConfig):
|
|
276
293
|
params['ssl_cert'] = self.ssl.cert
|
277
294
|
if self.ssl.key:
|
278
295
|
params['ssl_key'] = self.ssl.key
|
279
|
-
|
296
|
+
|
280
297
|
return {k: v for k, v in params.items() if v is not None}
|
281
298
|
|
282
299
|
def get_masked_connection_info(self) -> Dict[str, Any]:
|
mcp_dbutils/mysql/handler.py
CHANGED
@@ -181,6 +181,66 @@ class MySQLHandler(ConnectionHandler):
|
|
181
181
|
if conn:
|
182
182
|
conn.close()
|
183
183
|
|
184
|
+
async def _execute_write_query(self, sql: str) -> str:
    """Execute a single SQL write statement

    Args:
        sql: SQL write query (INSERT, UPDATE, DELETE) or a transaction
            control statement (BEGIN, COMMIT, ROLLBACK, START TRANSACTION)

    Returns:
        str: Execution result message

    Raises:
        ConnectionHandlerError: If the statement is not an allowed write
            statement, or if connecting, executing or committing fails
    """
    # Only the leading keyword is checked; anything else is rejected
    # before a connection is opened.
    sql_upper = sql.strip().upper()
    is_write = sql_upper.startswith(("INSERT", "UPDATE", "DELETE"))
    is_transaction = sql_upper.startswith(("BEGIN", "COMMIT", "ROLLBACK", "START TRANSACTION"))
    if not (is_write or is_transaction):
        raise ConnectionHandlerError("Only INSERT, UPDATE, DELETE, and transaction statements are allowed for write operations")

    # NOTE(review): every call opens and closes its own connection, so a
    # BEGIN / START TRANSACTION issued here cannot span later calls, and
    # COMMIT / ROLLBACK only affect this short-lived connection. Confirm
    # that accepting transaction statements is intended.
    conn = None
    try:
        conn_params = self.config.get_connection_params()
        conn = mysql.connector.connect(**conn_params)
        self.log("debug", f"Executing write operation: {sql}")

        with conn.cursor() as cur:
            try:
                cur.execute(sql)
                affected_rows = cur.rowcount
                # Transaction control statements manage commits themselves
                if not is_transaction:
                    conn.commit()
            except mysql.connector.Error as e:
                if not is_transaction:
                    # A failing rollback must not mask the original error
                    try:
                        conn.rollback()
                    except mysql.connector.Error as rollback_error:
                        self.log("error", f"Rollback failed: {str(rollback_error)}")
                self.log("error", f"Write operation error: {str(e)}")
                raise ConnectionHandlerError(str(e)) from e

        self.log("debug", f"Write operation executed successfully, affected {affected_rows} rows")

        if is_transaction:
            return "Transaction operation executed successfully"
        plural = 's' if affected_rows != 1 else ''
        return f"Write operation executed successfully. {affected_rows} row{plural} affected."
    except mysql.connector.Error as e:
        # Failures outside statement execution (connect, cursor handling)
        error_msg = f"[{self.db_type}] Write operation failed: {str(e)}"
        raise ConnectionHandlerError(error_msg) from e
    finally:
        if conn is not None:
            conn.close()
|
243
|
+
|
184
244
|
async def get_table_description(self, table_name: str) -> str:
|
185
245
|
"""Get detailed table description"""
|
186
246
|
conn = None
|
mcp_dbutils/postgres/config.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
3
3
|
from typing import Any, Dict, Literal, Optional
|
4
4
|
from urllib.parse import parse_qs, urlparse
|
5
5
|
|
6
|
-
from ..config import ConnectionConfig
|
6
|
+
from ..config import ConnectionConfig, WritePermissions
|
7
7
|
|
8
8
|
|
9
9
|
@dataclass
|
@@ -16,29 +16,29 @@ class SSLConfig:
|
|
16
16
|
|
17
17
|
def parse_url(url: str) -> Dict[str, Any]:
|
18
18
|
"""Parse PostgreSQL URL into connection parameters
|
19
|
-
|
19
|
+
|
20
20
|
Args:
|
21
21
|
url: URL (e.g. postgresql://host:port/dbname?sslmode=verify-full)
|
22
|
-
|
22
|
+
|
23
23
|
Returns:
|
24
24
|
Dictionary of connection parameters including SSL settings
|
25
25
|
"""
|
26
26
|
if not url.startswith('postgresql://'):
|
27
27
|
raise ValueError("Invalid PostgreSQL URL format")
|
28
|
-
|
28
|
+
|
29
29
|
if '@' in url:
|
30
30
|
raise ValueError("URL should not contain credentials. Please provide username and password separately.")
|
31
|
-
|
31
|
+
|
32
32
|
# Parse URL and query parameters
|
33
33
|
parsed = urlparse(url)
|
34
34
|
query_params = parse_qs(parsed.query)
|
35
|
-
|
35
|
+
|
36
36
|
params = {
|
37
37
|
'host': parsed.hostname or 'localhost',
|
38
38
|
'port': str(parsed.port or 5432),
|
39
39
|
'dbname': parsed.path.lstrip('/') if parsed.path else '',
|
40
40
|
}
|
41
|
-
|
41
|
+
|
42
42
|
if not params['dbname']:
|
43
43
|
raise ValueError("PostgreSQL database name must be specified in URL")
|
44
44
|
|
@@ -49,17 +49,17 @@ def parse_url(url: str) -> Dict[str, Any]:
|
|
49
49
|
if mode not in ['disable', 'require', 'verify-ca', 'verify-full']:
|
50
50
|
raise ValueError(f"Invalid sslmode: {mode}")
|
51
51
|
ssl_params['mode'] = mode
|
52
|
-
|
52
|
+
|
53
53
|
if 'sslcert' in query_params:
|
54
54
|
ssl_params['cert'] = query_params['sslcert'][0]
|
55
55
|
if 'sslkey' in query_params:
|
56
56
|
ssl_params['key'] = query_params['sslkey'][0]
|
57
57
|
if 'sslrootcert' in query_params:
|
58
58
|
ssl_params['root'] = query_params['sslrootcert'][0]
|
59
|
-
|
59
|
+
|
60
60
|
if ssl_params:
|
61
61
|
params['ssl'] = SSLConfig(**ssl_params)
|
62
|
-
|
62
|
+
|
63
63
|
return params
|
64
64
|
|
65
65
|
@dataclass
|
@@ -73,6 +73,8 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
73
73
|
type: Literal['postgres'] = 'postgres'
|
74
74
|
url: Optional[str] = None
|
75
75
|
ssl: Optional[SSLConfig] = None
|
76
|
+
writable: bool = False # Whether write operations are allowed
|
77
|
+
write_permissions: Optional[WritePermissions] = None # Write permissions configuration
|
76
78
|
|
77
79
|
@classmethod
|
78
80
|
def from_yaml(cls, yaml_path: str, db_name: str, local_host: Optional[str] = None) -> 'PostgreSQLConfig':
|
@@ -123,24 +125,24 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
123
125
|
raise ValueError("Host must be specified in connection configuration")
|
124
126
|
if not db_config.get('port'):
|
125
127
|
raise ValueError("Port must be specified in connection configuration")
|
126
|
-
|
128
|
+
|
127
129
|
# Parse SSL configuration if present
|
128
130
|
ssl_config = None
|
129
131
|
if 'ssl' in db_config:
|
130
132
|
ssl_params = db_config['ssl']
|
131
133
|
if not isinstance(ssl_params, dict):
|
132
134
|
raise ValueError("SSL configuration must be a dictionary")
|
133
|
-
|
135
|
+
|
134
136
|
if ssl_params.get('mode') not in [None, 'disable', 'require', 'verify-ca', 'verify-full']:
|
135
137
|
raise ValueError(f"Invalid sslmode: {ssl_params.get('mode')}")
|
136
|
-
|
138
|
+
|
137
139
|
ssl_config = SSLConfig(
|
138
140
|
mode=ssl_params.get('mode', 'disable'),
|
139
141
|
cert=ssl_params.get('cert'),
|
140
142
|
key=ssl_params.get('key'),
|
141
143
|
root=ssl_params.get('root')
|
142
144
|
)
|
143
|
-
|
145
|
+
|
144
146
|
config = cls(
|
145
147
|
dbname=db_config['dbname'],
|
146
148
|
user=db_config['user'],
|
@@ -150,25 +152,35 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
150
152
|
local_host=local_host,
|
151
153
|
ssl=ssl_config
|
152
154
|
)
|
155
|
+
|
156
|
+
# Parse write permissions
|
157
|
+
config.writable = db_config.get('writable', False)
|
158
|
+
if config.writable and 'write_permissions' in db_config:
|
159
|
+
config.write_permissions = WritePermissions(db_config['write_permissions'])
|
160
|
+
|
153
161
|
config.debug = cls.get_debug_mode()
|
154
162
|
return config
|
155
163
|
|
156
164
|
@classmethod
|
157
|
-
def from_url(cls, url: str, user: str, password: str,
|
158
|
-
local_host: Optional[str] = None
|
165
|
+
def from_url(cls, url: str, user: str, password: str,
|
166
|
+
local_host: Optional[str] = None,
|
167
|
+
writable: bool = False,
|
168
|
+
write_permissions: Optional[Dict[str, Any]] = None) -> 'PostgreSQLConfig':
|
159
169
|
"""Create configuration from URL and credentials
|
160
|
-
|
170
|
+
|
161
171
|
Args:
|
162
172
|
url: URL (postgresql://host:port/dbname)
|
163
173
|
user: Username for connection
|
164
174
|
password: Password for connection
|
165
175
|
local_host: Optional local host address
|
166
|
-
|
176
|
+
writable: Whether write operations are allowed
|
177
|
+
write_permissions: Write permissions configuration
|
178
|
+
|
167
179
|
Raises:
|
168
180
|
ValueError: If URL format is invalid or required parameters are missing
|
169
181
|
"""
|
170
182
|
params = parse_url(url)
|
171
|
-
|
183
|
+
|
172
184
|
config = cls(
|
173
185
|
dbname=params['dbname'],
|
174
186
|
user=user,
|
@@ -177,8 +189,14 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
177
189
|
port=params['port'],
|
178
190
|
local_host=local_host,
|
179
191
|
url=url,
|
180
|
-
ssl=params.get('ssl')
|
192
|
+
ssl=params.get('ssl'),
|
193
|
+
writable=writable
|
181
194
|
)
|
195
|
+
|
196
|
+
# Parse write permissions
|
197
|
+
if writable and write_permissions:
|
198
|
+
config.write_permissions = WritePermissions(write_permissions)
|
199
|
+
|
182
200
|
config.debug = cls.get_debug_mode()
|
183
201
|
return config
|
184
202
|
|
@@ -191,7 +209,7 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
191
209
|
'host': self.local_host or self.host,
|
192
210
|
'port': self.port
|
193
211
|
}
|
194
|
-
|
212
|
+
|
195
213
|
# Add SSL parameters if configured
|
196
214
|
if self.ssl:
|
197
215
|
params['sslmode'] = self.ssl.mode
|
@@ -201,7 +219,7 @@ class PostgreSQLConfig(ConnectionConfig):
|
|
201
219
|
params['sslkey'] = self.ssl.key
|
202
220
|
if self.ssl.root:
|
203
221
|
params['sslrootcert'] = self.ssl.root
|
204
|
-
|
222
|
+
|
205
223
|
return {k: v for k, v in params.items() if v}
|
206
224
|
|
207
225
|
def get_masked_connection_info(self) -> Dict[str, Any]:
|