awslabs.postgres-mcp-server 1.0.9__py3-none-any.whl → 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/postgres_mcp_server/__init__.py +8 -1
- awslabs/postgres_mcp_server/connection/__init__.py +0 -1
- awslabs/postgres_mcp_server/connection/cp_api_connection.py +592 -0
- awslabs/postgres_mcp_server/connection/db_connection_map.py +128 -0
- awslabs/postgres_mcp_server/connection/psycopg_pool_connection.py +101 -54
- awslabs/postgres_mcp_server/connection/rds_api_connection.py +5 -1
- awslabs/postgres_mcp_server/server.py +562 -120
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.11.dist-info}/METADATA +48 -79
- awslabs_postgres_mcp_server-1.0.11.dist-info/RECORD +16 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.11.dist-info}/WHEEL +1 -1
- awslabs/postgres_mcp_server/connection/db_connection_singleton.py +0 -117
- awslabs_postgres_mcp_server-1.0.9.dist-info/RECORD +0 -15
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.11.dist-info}/entry_points.txt +0 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.11.dist-info}/licenses/LICENSE +0 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.11.dist-info}/licenses/NOTICE +0 -0
|
@@ -16,25 +16,45 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import asyncio
|
|
19
|
+
import json
|
|
19
20
|
import sys
|
|
20
|
-
|
|
21
|
+
import threading
|
|
22
|
+
import traceback
|
|
23
|
+
from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
|
|
24
|
+
from awslabs.postgres_mcp_server.connection.cp_api_connection import (
|
|
25
|
+
internal_create_serverless_cluster,
|
|
26
|
+
internal_get_cluster_properties,
|
|
27
|
+
internal_get_instance_properties,
|
|
28
|
+
setup_aurora_iam_policy_for_current_user,
|
|
29
|
+
)
|
|
30
|
+
from awslabs.postgres_mcp_server.connection.db_connection_map import (
|
|
31
|
+
ConnectionMethod,
|
|
32
|
+
DatabaseType,
|
|
33
|
+
DBConnectionMap,
|
|
34
|
+
)
|
|
21
35
|
from awslabs.postgres_mcp_server.connection.psycopg_pool_connection import PsycopgPoolConnection
|
|
36
|
+
from awslabs.postgres_mcp_server.connection.rds_api_connection import RDSDataAPIConnection
|
|
22
37
|
from awslabs.postgres_mcp_server.mutable_sql_detector import (
|
|
23
38
|
check_sql_injection_risk,
|
|
24
39
|
detect_mutating_keywords,
|
|
25
40
|
)
|
|
26
|
-
from botocore.exceptions import
|
|
41
|
+
from botocore.exceptions import ClientError
|
|
42
|
+
from datetime import datetime
|
|
27
43
|
from loguru import logger
|
|
28
44
|
from mcp.server.fastmcp import Context, FastMCP
|
|
29
45
|
from pydantic import Field
|
|
30
|
-
from typing import Annotated, Any, Dict, List, Optional
|
|
46
|
+
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
|
31
47
|
|
|
32
48
|
|
|
49
|
+
db_connection_map = DBConnectionMap()
|
|
50
|
+
async_job_status: Dict[str, dict] = {}
|
|
51
|
+
async_job_status_lock = threading.Lock()
|
|
33
52
|
client_error_code_key = 'run_query ClientError code'
|
|
34
53
|
unexpected_error_key = 'run_query unexpected error'
|
|
35
54
|
write_query_prohibited_key = 'Your MCP tool only allows readonly query. If you want to write, change the MCP configuration per README.md'
|
|
36
55
|
query_comment_prohibited_key = 'The comment in query is prohibited because of injection risk'
|
|
37
56
|
query_injection_risk_key = 'Your query contains risky injection patterns'
|
|
57
|
+
readonly_query = True
|
|
38
58
|
|
|
39
59
|
|
|
40
60
|
class DummyCtx:
|
|
@@ -91,7 +111,10 @@ mcp = FastMCP(
|
|
|
91
111
|
async def run_query(
|
|
92
112
|
sql: Annotated[str, Field(description='The SQL query to run')],
|
|
93
113
|
ctx: Context,
|
|
94
|
-
|
|
114
|
+
connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
|
|
115
|
+
cluster_identifier: Annotated[str, Field(description='Cluster identifier')],
|
|
116
|
+
db_endpoint: Annotated[str, Field(description='database endpoint')],
|
|
117
|
+
database: Annotated[str, Field(description='database name')],
|
|
95
118
|
query_parameters: Annotated[
|
|
96
119
|
Optional[List[Dict[str, Any]]], Field(description='Parameters for the SQL query')
|
|
97
120
|
] = None,
|
|
@@ -101,7 +124,10 @@ async def run_query(
|
|
|
101
124
|
Args:
|
|
102
125
|
sql: The sql statement to run
|
|
103
126
|
ctx: MCP context for logging and state management
|
|
104
|
-
|
|
127
|
+
connection_method: connection method
|
|
128
|
+
cluster_identifier: Cluster identifier
|
|
129
|
+
db_endpoint: database endpoint
|
|
130
|
+
database: database name
|
|
105
131
|
query_parameters: Parameters for the SQL query
|
|
106
132
|
|
|
107
133
|
Returns:
|
|
@@ -110,25 +136,38 @@ async def run_query(
|
|
|
110
136
|
global client_error_code_key
|
|
111
137
|
global unexpected_error_key
|
|
112
138
|
global write_query_prohibited_key
|
|
139
|
+
global db_connection_map
|
|
113
140
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
logger.error('No database connection available')
|
|
121
|
-
await ctx.error('No database connection available')
|
|
122
|
-
return [{'error': 'No database connection available'}]
|
|
141
|
+
logger.info(
|
|
142
|
+
f'Entered run_query with '
|
|
143
|
+
f'method:{connection_method}, cluster_identifier:{cluster_identifier}, '
|
|
144
|
+
f'db_endpoint:{db_endpoint}, database:{database}, '
|
|
145
|
+
f'sql:{sql}'
|
|
146
|
+
)
|
|
123
147
|
|
|
124
|
-
|
|
125
|
-
|
|
148
|
+
db_connection = db_connection_map.get(
|
|
149
|
+
method=connection_method,
|
|
150
|
+
cluster_identifier=cluster_identifier,
|
|
151
|
+
db_endpoint=db_endpoint,
|
|
152
|
+
database=database,
|
|
153
|
+
)
|
|
154
|
+
if not db_connection:
|
|
155
|
+
err = (
|
|
156
|
+
f'No database connection available for method:{connection_method}, '
|
|
157
|
+
f'cluster_identifier:{cluster_identifier}, db_endpoint:{db_endpoint}, database:{database}'
|
|
158
|
+
)
|
|
159
|
+
logger.error(err)
|
|
160
|
+
await ctx.error(err)
|
|
161
|
+
return [{'error': err}]
|
|
126
162
|
|
|
127
163
|
if db_connection.readonly_query:
|
|
128
164
|
matches = detect_mutating_keywords(sql)
|
|
129
165
|
if (bool)(matches):
|
|
130
166
|
logger.info(
|
|
131
|
-
|
|
167
|
+
(
|
|
168
|
+
f'query is rejected because current setting only allows readonly query.'
|
|
169
|
+
f'detected keywords: {matches}, SQL query: {sql}'
|
|
170
|
+
)
|
|
132
171
|
)
|
|
133
172
|
await ctx.error(write_query_prohibited_key)
|
|
134
173
|
return [{'error': write_query_prohibited_key}]
|
|
@@ -144,9 +183,15 @@ async def run_query(
|
|
|
144
183
|
return [{'error': query_injection_risk_key}]
|
|
145
184
|
|
|
146
185
|
try:
|
|
147
|
-
logger.info(
|
|
186
|
+
logger.info(
|
|
187
|
+
(
|
|
188
|
+
f'run_query: sql:{sql} method:{connection_method}, '
|
|
189
|
+
f'cluster_identifier:{cluster_identifier} database:{database} '
|
|
190
|
+
f'db_endpoint:{db_endpoint} '
|
|
191
|
+
f'readonly:{db_connection.readonly_query} query_parameters:{query_parameters}'
|
|
192
|
+
)
|
|
193
|
+
)
|
|
148
194
|
|
|
149
|
-
# Execute the query using the abstract connection interface
|
|
150
195
|
response = await db_connection.execute_query(sql, query_parameters)
|
|
151
196
|
|
|
152
197
|
logger.success(f'run_query successfully executed query:{sql}')
|
|
@@ -164,23 +209,34 @@ async def run_query(
|
|
|
164
209
|
return [{'error': unexpected_error_key}]
|
|
165
210
|
|
|
166
211
|
|
|
167
|
-
@mcp.tool(
|
|
168
|
-
name='get_table_schema',
|
|
169
|
-
description='Fetch table columns and comments from Postgres',
|
|
170
|
-
)
|
|
212
|
+
@mcp.tool(name='get_table_schema', description='Fetch table columns and comments from Postgres')
|
|
171
213
|
async def get_table_schema(
|
|
172
|
-
|
|
214
|
+
connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
|
|
215
|
+
cluster_identifier: Annotated[str, Field(description='Cluster identifier')],
|
|
216
|
+
db_endpoint: Annotated[str, Field(description='database endpoint')],
|
|
217
|
+
database: Annotated[str, Field(description='database name')],
|
|
218
|
+
table_name: Annotated[str, Field(description='name of the table')],
|
|
219
|
+
ctx: Context,
|
|
173
220
|
) -> list[dict]:
|
|
174
221
|
"""Get a table's schema information given the table name.
|
|
175
222
|
|
|
176
223
|
Args:
|
|
224
|
+
connection_method: connection method
|
|
225
|
+
cluster_identifier: Cluster identifier
|
|
226
|
+
db_endpoint: database endpoint
|
|
227
|
+
database: database name
|
|
177
228
|
table_name: name of the table
|
|
178
229
|
ctx: MCP context for logging and state management
|
|
179
230
|
|
|
180
231
|
Returns:
|
|
181
232
|
List of dictionary that contains query response rows
|
|
182
233
|
"""
|
|
183
|
-
logger.info(
|
|
234
|
+
logger.info(
|
|
235
|
+
(
|
|
236
|
+
f'Entered get_table_schema: table_name:{table_name} connection_method:{connection_method}, '
|
|
237
|
+
f'cluster_identifier:{cluster_identifier}, db_endpoint:{db_endpoint}, database:{database}'
|
|
238
|
+
)
|
|
239
|
+
)
|
|
184
240
|
|
|
185
241
|
sql = """
|
|
186
242
|
SELECT
|
|
@@ -190,7 +246,7 @@ async def get_table_schema(
|
|
|
190
246
|
FROM
|
|
191
247
|
pg_attribute a
|
|
192
248
|
WHERE
|
|
193
|
-
a.attrelid = to_regclass(
|
|
249
|
+
a.attrelid = to_regclass(%(table_name)s)
|
|
194
250
|
AND a.attnum > 0
|
|
195
251
|
AND NOT a.attisdropped
|
|
196
252
|
ORDER BY a.attnum
|
|
@@ -198,125 +254,511 @@ async def get_table_schema(
|
|
|
198
254
|
|
|
199
255
|
params = [{'name': 'table_name', 'value': {'stringValue': table_name}}]
|
|
200
256
|
|
|
201
|
-
return await run_query(
|
|
257
|
+
return await run_query(
|
|
258
|
+
sql=sql,
|
|
259
|
+
ctx=ctx,
|
|
260
|
+
connection_method=connection_method,
|
|
261
|
+
cluster_identifier=cluster_identifier,
|
|
262
|
+
db_endpoint=db_endpoint,
|
|
263
|
+
database=database,
|
|
264
|
+
query_parameters=params,
|
|
265
|
+
)
|
|
266
|
+
|
|
202
267
|
|
|
268
|
+
@mcp.tool(
|
|
269
|
+
name='connect_to_database',
|
|
270
|
+
description='Connect to a specific database and save the connection internally',
|
|
271
|
+
)
|
|
272
|
+
def connect_to_database(
|
|
273
|
+
region: Annotated[str, Field(description='region')],
|
|
274
|
+
database_type: Annotated[DatabaseType, Field(description='database type')],
|
|
275
|
+
connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
|
|
276
|
+
cluster_identifier: Annotated[str, Field(description='cluster identifier')],
|
|
277
|
+
db_endpoint: Annotated[str, Field(description='database endpoint')],
|
|
278
|
+
port: Annotated[int, Field(description='Postgres port')],
|
|
279
|
+
database: Annotated[str, Field(description='database name')],
|
|
280
|
+
) -> str:
|
|
281
|
+
"""Connect to a specific database save the connection internally.
|
|
203
282
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
283
|
+
Args:
|
|
284
|
+
region: region of the database. Required parametere.
|
|
285
|
+
database_type: Either APG for Aurora Postgres or RPG for RDS Postgres cluster. Required parameter
|
|
286
|
+
connection_method: Either RDS_API, PG_WIRE_PROTOCOL, or PG_WIRE_IAM_PROTOCOL. Required parameter
|
|
287
|
+
cluster_identifier: Either Aurora Postgres cluster identifier or RDS Postgres cluster identifier
|
|
288
|
+
db_endpoint: database endpoint
|
|
289
|
+
port: database port
|
|
290
|
+
database: database name. Required parameter
|
|
291
|
+
|
|
292
|
+
Supported scenario:
|
|
293
|
+
1. Aurora Postgres database with RDS_API + Credential Manager:
|
|
294
|
+
cluster_identifier must be set
|
|
295
|
+
db_endpoint and port will be ignored
|
|
296
|
+
2. Aurora Postgres database with direct connection + IAM:
|
|
297
|
+
cluster_identifier must be set
|
|
298
|
+
db_endpoint must be set
|
|
299
|
+
3. Aurora Postgres database with direct connection + PG_AUTH (Credential Manager):
|
|
300
|
+
cluster_identifier must be set
|
|
301
|
+
db_endpoint must be set
|
|
302
|
+
4. RDS Postgres database with direct connection + PG_AUTH (Credential Manager):
|
|
303
|
+
credential manager setting is either on instance or cluster
|
|
304
|
+
db_endpoint must be set
|
|
305
|
+
"""
|
|
306
|
+
try:
|
|
307
|
+
db_connection, llm_response = internal_connect_to_database(
|
|
308
|
+
region=region,
|
|
309
|
+
database_type=database_type,
|
|
310
|
+
connection_method=connection_method,
|
|
311
|
+
cluster_identifier=cluster_identifier,
|
|
312
|
+
db_endpoint=db_endpoint,
|
|
313
|
+
port=port,
|
|
314
|
+
database=database,
|
|
315
|
+
)
|
|
207
316
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
317
|
+
return str(llm_response)
|
|
318
|
+
|
|
319
|
+
except Exception as e:
|
|
320
|
+
logger.error(f'connect_to_database failed with error: {str(e)}')
|
|
321
|
+
trace_msg = traceback.format_exc()
|
|
322
|
+
logger.error(f'Trace:{trace_msg}')
|
|
323
|
+
llm_response = {'status': 'Failed', 'error': str(e)}
|
|
324
|
+
return json.dumps(llm_response, indent=2)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
@mcp.tool(name='is_database_connected', description='Check if a connection has been established')
|
|
328
|
+
def is_database_connected(
|
|
329
|
+
cluster_identifier: Annotated[str, Field(description='cluster identifier')],
|
|
330
|
+
db_endpoint: Annotated[str, Field(description='database endpoint')] = '',
|
|
331
|
+
database: Annotated[str, Field(description='database name')] = 'postgres',
|
|
332
|
+
) -> bool:
|
|
333
|
+
"""Check if a connection has been established.
|
|
334
|
+
|
|
335
|
+
Args:
|
|
336
|
+
cluster_identifier: cluster identifier
|
|
337
|
+
db_endpoint: database endpoint
|
|
338
|
+
database: database name
|
|
339
|
+
|
|
340
|
+
Returns:
|
|
341
|
+
result in boolean
|
|
342
|
+
"""
|
|
343
|
+
global db_connection_map
|
|
344
|
+
if db_connection_map.get(ConnectionMethod.RDS_API, cluster_identifier, db_endpoint, database):
|
|
345
|
+
return True
|
|
346
|
+
|
|
347
|
+
if db_connection_map.get(
|
|
348
|
+
ConnectionMethod.PG_WIRE_PROTOCOL, cluster_identifier, db_endpoint, database
|
|
349
|
+
):
|
|
350
|
+
return True
|
|
351
|
+
|
|
352
|
+
if db_connection_map.get(
|
|
353
|
+
ConnectionMethod.PG_WIRE_IAM_PROTOCOL, cluster_identifier, db_endpoint, database
|
|
354
|
+
):
|
|
355
|
+
return True
|
|
356
|
+
|
|
357
|
+
return False
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
@mcp.tool(
|
|
361
|
+
name='get_database_connection_info',
|
|
362
|
+
description='Get all cached database connection information',
|
|
363
|
+
)
|
|
364
|
+
def get_database_connection_info() -> str:
|
|
365
|
+
"""Get all cached database connection information.
|
|
366
|
+
|
|
367
|
+
Return:
|
|
368
|
+
A list of cached connection information.
|
|
369
|
+
"""
|
|
370
|
+
global db_connection_map
|
|
371
|
+
return db_connection_map.get_keys_json()
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@mcp.tool(name='create_cluster', description='Create an Aurora Postgres cluster')
|
|
375
|
+
def create_cluster(
|
|
376
|
+
region: Annotated[str, Field(description='region')],
|
|
377
|
+
cluster_identifier: Annotated[str, Field(description='cluster identifier')],
|
|
378
|
+
database: Annotated[str, Field(description='default database name')] = 'postgres',
|
|
379
|
+
engine_version: Annotated[str, Field(description='engine version')] = '17.5',
|
|
380
|
+
) -> str:
|
|
381
|
+
"""Create an RDS/Aurora cluster.
|
|
382
|
+
|
|
383
|
+
Args:
|
|
384
|
+
region: region
|
|
385
|
+
cluster_identifier: cluster identifier
|
|
386
|
+
database: database name
|
|
387
|
+
engine_version: engine version
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
result
|
|
391
|
+
"""
|
|
392
|
+
logger.info(
|
|
393
|
+
f'Entered create_cluster with region:{region}, '
|
|
394
|
+
f'cluster_identifier:{cluster_identifier} '
|
|
395
|
+
f'database:{database} '
|
|
396
|
+
f'engine_version:{engine_version}'
|
|
211
397
|
)
|
|
212
398
|
|
|
213
|
-
|
|
214
|
-
|
|
399
|
+
database_type = DatabaseType.APG
|
|
400
|
+
connection_method = ConnectionMethod.RDS_API
|
|
215
401
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
402
|
+
job_id = (
|
|
403
|
+
f'create-cluster-{cluster_identifier}-{datetime.now().isoformat(timespec="milliseconds")}'
|
|
404
|
+
)
|
|
219
405
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
'
|
|
223
|
-
|
|
224
|
-
|
|
406
|
+
try:
|
|
407
|
+
async_job_status_lock.acquire()
|
|
408
|
+
async_job_status[job_id] = {'state': 'pending', 'result': None}
|
|
409
|
+
finally:
|
|
410
|
+
async_job_status_lock.release()
|
|
411
|
+
|
|
412
|
+
t = threading.Thread(
|
|
413
|
+
target=create_cluster_worker,
|
|
414
|
+
args=(
|
|
415
|
+
job_id,
|
|
416
|
+
region,
|
|
417
|
+
database_type,
|
|
418
|
+
connection_method,
|
|
419
|
+
cluster_identifier,
|
|
420
|
+
engine_version,
|
|
421
|
+
database,
|
|
422
|
+
),
|
|
423
|
+
daemon=False,
|
|
225
424
|
)
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
425
|
+
t.start()
|
|
426
|
+
|
|
427
|
+
logger.info(
|
|
428
|
+
f'start_create_cluster_job return with job_id:{job_id}'
|
|
429
|
+
f'region:{region} cluster_identifier:{cluster_identifier} database:{database} '
|
|
430
|
+
f'engine_version:{engine_version}'
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
result = {
|
|
434
|
+
'status': 'Pending',
|
|
435
|
+
'message': 'cluster creation started',
|
|
436
|
+
'job_id': job_id,
|
|
437
|
+
'cluster_identifier': cluster_identifier,
|
|
438
|
+
'check_status_tool': 'get_job_status',
|
|
439
|
+
'next_action': f"Use get_job_status(job_id='{job_id}') to get results",
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
return json.dumps(result, indent=2)
|
|
229
443
|
|
|
230
|
-
args = parser.parse_args()
|
|
231
444
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
445
|
+
@mcp.tool(name='get_job_status', description='get background job status')
|
|
446
|
+
def get_job_status(job_id: str) -> dict:
|
|
447
|
+
"""Get background job status.
|
|
448
|
+
|
|
449
|
+
Args:
|
|
450
|
+
job_id: job id
|
|
451
|
+
Returns:
|
|
452
|
+
job status
|
|
453
|
+
"""
|
|
454
|
+
global async_job_status
|
|
455
|
+
global async_job_status_lock
|
|
456
|
+
|
|
457
|
+
try:
|
|
458
|
+
async_job_status_lock.acquire()
|
|
459
|
+
return async_job_status.get(job_id, {'state': 'not_found'})
|
|
460
|
+
finally:
|
|
461
|
+
async_job_status_lock.release()
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def create_cluster_worker(
|
|
465
|
+
job_id: str,
|
|
466
|
+
region: str,
|
|
467
|
+
database_type: DatabaseType,
|
|
468
|
+
connection_method: ConnectionMethod,
|
|
469
|
+
cluster_identifier: str,
|
|
470
|
+
engine_version: str,
|
|
471
|
+
database: str,
|
|
472
|
+
):
|
|
473
|
+
"""Background worker to create a cluster asynchronously."""
|
|
474
|
+
global db_connection_map
|
|
475
|
+
global async_job_status
|
|
476
|
+
global async_job_status_lock
|
|
477
|
+
global readonly_query
|
|
478
|
+
|
|
479
|
+
try:
|
|
480
|
+
cluster_result = internal_create_serverless_cluster(
|
|
481
|
+
region=region,
|
|
482
|
+
cluster_identifier=cluster_identifier,
|
|
483
|
+
engine_version=engine_version,
|
|
484
|
+
database_name=database,
|
|
237
485
|
)
|
|
238
486
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
'
|
|
487
|
+
setup_aurora_iam_policy_for_current_user(
|
|
488
|
+
db_user=cluster_result['MasterUsername'],
|
|
489
|
+
cluster_resource_id=cluster_result['DbClusterResourceId'],
|
|
490
|
+
cluster_region=region,
|
|
242
491
|
)
|
|
243
492
|
|
|
244
|
-
|
|
245
|
-
|
|
493
|
+
internal_connect_to_database(
|
|
494
|
+
region=region,
|
|
495
|
+
database_type=database_type,
|
|
496
|
+
connection_method=connection_method,
|
|
497
|
+
cluster_identifier=cluster_identifier,
|
|
498
|
+
db_endpoint=cluster_result['Endpoint'],
|
|
499
|
+
port=5432,
|
|
500
|
+
database=database,
|
|
501
|
+
)
|
|
246
502
|
|
|
247
|
-
|
|
248
|
-
|
|
503
|
+
try:
|
|
504
|
+
async_job_status_lock.acquire()
|
|
505
|
+
async_job_status[job_id]['state'] = 'succeeded'
|
|
506
|
+
finally:
|
|
507
|
+
async_job_status_lock.release()
|
|
508
|
+
except Exception as e:
|
|
509
|
+
logger.error(f'create_cluster_worker failed with {e}')
|
|
510
|
+
try:
|
|
511
|
+
async_job_status_lock.acquire()
|
|
512
|
+
async_job_status[job_id]['state'] = 'failed'
|
|
513
|
+
async_job_status[job_id]['result'] = str(e)
|
|
514
|
+
finally:
|
|
515
|
+
async_job_status_lock.release()
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def internal_connect_to_database(
|
|
519
|
+
region: Annotated[str, Field(description='region')],
|
|
520
|
+
database_type: Annotated[DatabaseType, Field(description='database type')],
|
|
521
|
+
connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
|
|
522
|
+
cluster_identifier: Annotated[str, Field(description='cluster identifier')],
|
|
523
|
+
db_endpoint: Annotated[str, Field(description='database endpoint')],
|
|
524
|
+
port: Annotated[int, Field(description='Postgres port')],
|
|
525
|
+
database: Annotated[str, Field(description='database name')] = 'postgres',
|
|
526
|
+
) -> Tuple:
|
|
527
|
+
"""Connect to a specific database save the connection internally.
|
|
249
528
|
|
|
250
|
-
|
|
251
|
-
|
|
529
|
+
Args:
|
|
530
|
+
region: region
|
|
531
|
+
database_type: database type (APG or RPG)
|
|
532
|
+
connection_method: connection method (RDS_API, PG_WIRE_PROTOCOL, or PG_WIRE_IAM_PROTOCOL)
|
|
533
|
+
cluster_identifier: cluster identifier
|
|
534
|
+
db_endpoint: database endpoint
|
|
535
|
+
port: database port
|
|
536
|
+
database: database name
|
|
537
|
+
"""
|
|
538
|
+
global db_connection_map
|
|
539
|
+
global readonly_query
|
|
540
|
+
|
|
541
|
+
logger.info(
|
|
542
|
+
f'Enter internal_connect_to_database\n'
|
|
543
|
+
f'region:{region}\n'
|
|
544
|
+
f'database_type:{database_type}\n'
|
|
545
|
+
f'connection_method:{connection_method}\n'
|
|
546
|
+
f'cluster_identifier:{cluster_identifier}\n'
|
|
547
|
+
f'db_endpoint:{db_endpoint}\n'
|
|
548
|
+
f'database:{database}\n'
|
|
549
|
+
f'readonly_query:{readonly_query}'
|
|
550
|
+
)
|
|
252
551
|
|
|
253
|
-
if
|
|
254
|
-
|
|
255
|
-
|
|
552
|
+
if not region:
|
|
553
|
+
raise ValueError("region can't be none or empty")
|
|
554
|
+
|
|
555
|
+
if not connection_method:
|
|
556
|
+
raise ValueError("connection_method can't be none or empty")
|
|
557
|
+
|
|
558
|
+
if not database_type:
|
|
559
|
+
raise ValueError("database_type can't be none or empty")
|
|
560
|
+
|
|
561
|
+
if database_type == DatabaseType.APG and not cluster_identifier:
|
|
562
|
+
raise ValueError("cluster_identifier can't be none or empty for Aurora Postgres Database")
|
|
563
|
+
|
|
564
|
+
existing_conn = db_connection_map.get(
|
|
565
|
+
connection_method, cluster_identifier, db_endpoint, database, port
|
|
566
|
+
)
|
|
567
|
+
if existing_conn:
|
|
568
|
+
llm_response = json.dumps(
|
|
569
|
+
{
|
|
570
|
+
'connection_method': connection_method,
|
|
571
|
+
'cluster_identifier': cluster_identifier,
|
|
572
|
+
'db_endpoint': db_endpoint,
|
|
573
|
+
'database': database,
|
|
574
|
+
'port': port,
|
|
575
|
+
},
|
|
576
|
+
indent=2,
|
|
577
|
+
default=str,
|
|
256
578
|
)
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
579
|
+
return (existing_conn, llm_response)
|
|
580
|
+
|
|
581
|
+
enable_data_api: bool = False
|
|
582
|
+
masteruser: str = ''
|
|
583
|
+
cluster_arn: str = ''
|
|
584
|
+
secret_arn: str = ''
|
|
585
|
+
|
|
586
|
+
if cluster_identifier:
|
|
587
|
+
# Can be either APG (APG always requires cluster) or RPG multi-AZ cluster deployment case
|
|
588
|
+
cluster_properties = internal_get_cluster_properties(
|
|
589
|
+
cluster_identifier=cluster_identifier, region=region
|
|
260
590
|
)
|
|
261
591
|
|
|
262
|
-
|
|
592
|
+
enable_data_api = cluster_properties.get('HttpEndpointEnabled', False)
|
|
593
|
+
masteruser = cluster_properties.get('MasterUsername', '')
|
|
594
|
+
cluster_arn = cluster_properties.get('DBClusterArn', '')
|
|
595
|
+
secret_arn = cluster_properties.get('MasterUserSecret', {}).get('SecretArn')
|
|
596
|
+
|
|
597
|
+
if not db_endpoint:
|
|
598
|
+
# if db_endpoint not set, we will use cluster's endpoint
|
|
599
|
+
db_endpoint = cluster_properties.get('Endpoint', '')
|
|
600
|
+
port = int(cluster_properties.get('Port', ''))
|
|
601
|
+
else:
|
|
602
|
+
# Must be RPG instance only deployment case (i.e. without cluster)
|
|
603
|
+
instance_properties = internal_get_instance_properties(db_endpoint, region)
|
|
604
|
+
masteruser = instance_properties.get('MasterUsername', '')
|
|
605
|
+
secret_arn = instance_properties.get('MasterUserSecret', {}).get('SecretArn')
|
|
606
|
+
port = int(instance_properties.get('Endpoint', {}).get('Port'))
|
|
607
|
+
|
|
608
|
+
logger.info(
|
|
609
|
+
f'About to create internal DB connections with:'
|
|
610
|
+
f'enable_data_api:{enable_data_api}\n'
|
|
611
|
+
f'masteruser:{masteruser}\n'
|
|
612
|
+
f'cluster_arn:{cluster_arn}\n'
|
|
613
|
+
f'secret_arn:{secret_arn}\n'
|
|
614
|
+
f'db_endpoint:{db_endpoint}\n'
|
|
615
|
+
f'port:{port}\n'
|
|
616
|
+
f'region:{region}\n'
|
|
617
|
+
f'readonly:{readonly_query}'
|
|
618
|
+
)
|
|
619
|
+
|
|
263
620
|
db_connection = None
|
|
621
|
+
if connection_method == ConnectionMethod.PG_WIRE_IAM_PROTOCOL:
|
|
622
|
+
db_connection = PsycopgPoolConnection(
|
|
623
|
+
host=db_endpoint,
|
|
624
|
+
port=port,
|
|
625
|
+
database=database,
|
|
626
|
+
readonly=readonly_query,
|
|
627
|
+
secret_arn='',
|
|
628
|
+
db_user=masteruser,
|
|
629
|
+
region=region,
|
|
630
|
+
is_iam_auth=True,
|
|
631
|
+
)
|
|
264
632
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
633
|
+
elif connection_method == ConnectionMethod.RDS_API:
|
|
634
|
+
db_connection = RDSDataAPIConnection(
|
|
635
|
+
cluster_arn=cluster_arn,
|
|
636
|
+
secret_arn=str(secret_arn),
|
|
637
|
+
database=database,
|
|
638
|
+
region=region,
|
|
639
|
+
readonly=readonly_query,
|
|
640
|
+
)
|
|
641
|
+
else:
|
|
642
|
+
# must be connection_method == ConnectionMethod.PG_WIRE_PROTOCOL
|
|
643
|
+
db_connection = PsycopgPoolConnection(
|
|
644
|
+
host=db_endpoint,
|
|
645
|
+
port=port,
|
|
646
|
+
database=database,
|
|
647
|
+
readonly=readonly_query,
|
|
648
|
+
secret_arn=secret_arn,
|
|
649
|
+
db_user='',
|
|
650
|
+
region=region,
|
|
651
|
+
is_iam_auth=False,
|
|
652
|
+
)
|
|
277
653
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
region=args.region,
|
|
295
|
-
)
|
|
296
|
-
except Exception as e:
|
|
297
|
-
logger.exception(f'Failed to create PostgreSQL connection: {str(e)}')
|
|
298
|
-
sys.exit(1)
|
|
299
|
-
|
|
300
|
-
except BotoCoreError as e:
|
|
301
|
-
logger.exception(f'Failed to create database connection: {str(e)}')
|
|
302
|
-
sys.exit(1)
|
|
303
|
-
|
|
304
|
-
# Test database connection
|
|
305
|
-
ctx = DummyCtx()
|
|
306
|
-
response = asyncio.run(run_query('SELECT 1', ctx, db_connection))
|
|
307
|
-
if (
|
|
308
|
-
isinstance(response, list)
|
|
309
|
-
and len(response) == 1
|
|
310
|
-
and isinstance(response[0], dict)
|
|
311
|
-
and 'error' in response[0]
|
|
312
|
-
):
|
|
313
|
-
logger.error('Failed to validate database connection to Postgres. Exit the MCP server')
|
|
314
|
-
sys.exit(1)
|
|
654
|
+
if db_connection:
|
|
655
|
+
db_connection_map.set(
|
|
656
|
+
connection_method, cluster_identifier, db_endpoint, database, db_connection
|
|
657
|
+
)
|
|
658
|
+
llm_response = json.dumps(
|
|
659
|
+
{
|
|
660
|
+
'connection_method': connection_method,
|
|
661
|
+
'cluster_identifier': cluster_identifier,
|
|
662
|
+
'db_endpoint': db_endpoint,
|
|
663
|
+
'database': database,
|
|
664
|
+
'port': port,
|
|
665
|
+
},
|
|
666
|
+
indent=2,
|
|
667
|
+
default=str,
|
|
668
|
+
)
|
|
669
|
+
return (db_connection, llm_response)
|
|
315
670
|
|
|
316
|
-
|
|
671
|
+
raise ValueError("Can't create connection because invalid input parameter combination")
|
|
317
672
|
|
|
318
|
-
|
|
319
|
-
|
|
673
|
+
|
|
674
|
+
def main():
|
|
675
|
+
"""Main entry point for the MCP server application.
|
|
676
|
+
|
|
677
|
+
Runs the MCP server with CLI argument support for PostgreSQL connections.
|
|
678
|
+
"""
|
|
679
|
+
global db_connection_map
|
|
680
|
+
global readonly_query
|
|
681
|
+
|
|
682
|
+
parser = argparse.ArgumentParser(
|
|
683
|
+
description='An AWS Labs Model Context Protocol (MCP) server for postgres'
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
parser.add_argument(
|
|
687
|
+
'--connection_method',
|
|
688
|
+
help='Connection method to the database. It can be RDS_API, PG_WIRE_PROTOCOL OR PG_WIRE_IAM_PROTOCOL)',
|
|
689
|
+
)
|
|
690
|
+
parser.add_argument('--db_cluster_arn', help='ARN of the RDS or Aurora Postgres cluster')
|
|
691
|
+
parser.add_argument('--db_type', help='APG for Aurora Postgres or RPG for RDS Postgres')
|
|
692
|
+
parser.add_argument('--db_endpoint', help='Instance endpoint address')
|
|
693
|
+
parser.add_argument('--region', help='AWS region')
|
|
694
|
+
parser.add_argument(
|
|
695
|
+
'--allow_write_query', action='store_true', help='Enforce readonly SQL statements'
|
|
696
|
+
)
|
|
697
|
+
parser.add_argument('--database', help='Database name')
|
|
698
|
+
parser.add_argument('--port', type=int, default=5432, help='Database port (default: 5432)')
|
|
699
|
+
args = parser.parse_args()
|
|
700
|
+
|
|
701
|
+
logger.info(
|
|
702
|
+
f'MCP configuration:\n'
|
|
703
|
+
f'db_type:{args.db_type}\n'
|
|
704
|
+
f'db_cluster_arn:{args.db_cluster_arn}\n'
|
|
705
|
+
f'connection_method:{args.connection_method}\n'
|
|
706
|
+
f'db_endpoint:{args.db_endpoint}\n'
|
|
707
|
+
f'region:{args.region}\n'
|
|
708
|
+
f'allow_write_query:{args.allow_write_query}\n'
|
|
709
|
+
f'database:{args.database}\n'
|
|
710
|
+
f'port:{args.port}\n'
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
readonly_query = not args.allow_write_query
|
|
714
|
+
|
|
715
|
+
try:
|
|
716
|
+
if args.db_type:
|
|
717
|
+
# Create the appropriate database connection based on the provided parameters
|
|
718
|
+
db_connection: Optional[AbstractDBConnection] = None
|
|
719
|
+
|
|
720
|
+
cluster_identifier = args.db_cluster_arn.split(':')[-1]
|
|
721
|
+
db_connection, llm_response = internal_connect_to_database(
|
|
722
|
+
region=args.region,
|
|
723
|
+
database_type=DatabaseType[args.db_type],
|
|
724
|
+
connection_method=ConnectionMethod[args.connection_method],
|
|
725
|
+
cluster_identifier=cluster_identifier,
|
|
726
|
+
db_endpoint=args.hostname,
|
|
727
|
+
port=args.port,
|
|
728
|
+
database=args.database,
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
# Test database connection
|
|
732
|
+
if db_connection:
|
|
733
|
+
ctx = DummyCtx()
|
|
734
|
+
response = asyncio.run(
|
|
735
|
+
run_query(
|
|
736
|
+
'SELECT 1',
|
|
737
|
+
ctx,
|
|
738
|
+
ConnectionMethod[args.connection_method],
|
|
739
|
+
cluster_identifier,
|
|
740
|
+
args.db_endpoint,
|
|
741
|
+
args.database,
|
|
742
|
+
)
|
|
743
|
+
)
|
|
744
|
+
if (
|
|
745
|
+
isinstance(response, list)
|
|
746
|
+
and len(response) == 1
|
|
747
|
+
and isinstance(response[0], dict)
|
|
748
|
+
and 'error' in response[0]
|
|
749
|
+
):
|
|
750
|
+
logger.error(
|
|
751
|
+
'Failed to validate database connection to Postgres. Exit the MCP server'
|
|
752
|
+
)
|
|
753
|
+
sys.exit(1)
|
|
754
|
+
else:
|
|
755
|
+
logger.success('Successfully validated database connection to Postgres')
|
|
756
|
+
|
|
757
|
+
logger.info('Postgres MCP server started')
|
|
758
|
+
mcp.run()
|
|
759
|
+
logger.info('Postgres MCP server stopped')
|
|
760
|
+
finally:
|
|
761
|
+
db_connection_map.close_all()
|
|
320
762
|
|
|
321
763
|
|
|
322
764
|
if __name__ == '__main__':
|