awslabs.postgres-mcp-server 1.0.8__py3-none-any.whl → 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,25 +16,45 @@
16
16
 
17
17
  import argparse
18
18
  import asyncio
19
+ import json
19
20
  import sys
20
- from awslabs.postgres_mcp_server.connection import DBConnectionSingleton
21
+ import threading
22
+ import traceback
23
+ from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
24
+ from awslabs.postgres_mcp_server.connection.cp_api_connection import (
25
+ internal_create_serverless_cluster,
26
+ internal_get_cluster_properties,
27
+ internal_get_instance_properties,
28
+ setup_aurora_iam_policy_for_current_user,
29
+ )
30
+ from awslabs.postgres_mcp_server.connection.db_connection_map import (
31
+ ConnectionMethod,
32
+ DatabaseType,
33
+ DBConnectionMap,
34
+ )
21
35
  from awslabs.postgres_mcp_server.connection.psycopg_pool_connection import PsycopgPoolConnection
36
+ from awslabs.postgres_mcp_server.connection.rds_api_connection import RDSDataAPIConnection
22
37
  from awslabs.postgres_mcp_server.mutable_sql_detector import (
23
38
  check_sql_injection_risk,
24
39
  detect_mutating_keywords,
25
40
  )
26
- from botocore.exceptions import BotoCoreError, ClientError
41
+ from botocore.exceptions import ClientError
42
+ from datetime import datetime
27
43
  from loguru import logger
28
44
  from mcp.server.fastmcp import Context, FastMCP
29
45
  from pydantic import Field
30
- from typing import Annotated, Any, Dict, List, Optional
46
+ from typing import Annotated, Any, Dict, List, Optional, Tuple
31
47
 
32
48
 
49
+ db_connection_map = DBConnectionMap()
50
+ async_job_status: Dict[str, dict] = {}
51
+ async_job_status_lock = threading.Lock()
33
52
  client_error_code_key = 'run_query ClientError code'
34
53
  unexpected_error_key = 'run_query unexpected error'
35
54
  write_query_prohibited_key = 'Your MCP tool only allows readonly query. If you want to write, change the MCP configuration per README.md'
36
55
  query_comment_prohibited_key = 'The comment in query is prohibited because of injection risk'
37
56
  query_injection_risk_key = 'Your query contains risky injection patterns'
57
+ readonly_query = True
38
58
 
39
59
 
40
60
  class DummyCtx:
@@ -91,7 +111,10 @@ mcp = FastMCP(
91
111
  async def run_query(
92
112
  sql: Annotated[str, Field(description='The SQL query to run')],
93
113
  ctx: Context,
94
- db_connection=None,
114
+ connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
115
+ cluster_identifier: Annotated[str, Field(description='Cluster identifier')],
116
+ db_endpoint: Annotated[str, Field(description='database endpoint')],
117
+ database: Annotated[str, Field(description='database name')],
95
118
  query_parameters: Annotated[
96
119
  Optional[List[Dict[str, Any]]], Field(description='Parameters for the SQL query')
97
120
  ] = None,
@@ -101,7 +124,10 @@ async def run_query(
101
124
  Args:
102
125
  sql: The sql statement to run
103
126
  ctx: MCP context for logging and state management
104
- db_connection: DB connection object passed by unit test. It should be None if called by MCP server.
127
+ connection_method: connection method
128
+ cluster_identifier: Cluster identifier
129
+ db_endpoint: database endpoint
130
+ database: database name
105
131
  query_parameters: Parameters for the SQL query
106
132
 
107
133
  Returns:
@@ -110,25 +136,38 @@ async def run_query(
110
136
  global client_error_code_key
111
137
  global unexpected_error_key
112
138
  global write_query_prohibited_key
139
+ global db_connection_map
113
140
 
114
- if db_connection is None:
115
- try:
116
- # Try to get the connection from the singleton
117
- db_connection = DBConnectionSingleton.get().db_connection
118
- except RuntimeError:
119
- # If the singleton is not initialized, this might be a direct connection
120
- logger.error('No database connection available')
121
- await ctx.error('No database connection available')
122
- return [{'error': 'No database connection available'}]
141
+ logger.info(
142
+ f'Entered run_query with '
143
+ f'method:{connection_method}, cluster_identifier:{cluster_identifier}, '
144
+ f'db_endpoint:{db_endpoint}, database:{database}, '
145
+ f'sql:{sql}'
146
+ )
123
147
 
124
- if db_connection is None:
125
- raise AssertionError('db_connection should never be None')
148
+ db_connection = db_connection_map.get(
149
+ method=connection_method,
150
+ cluster_identifier=cluster_identifier,
151
+ db_endpoint=db_endpoint,
152
+ database=database,
153
+ )
154
+ if not db_connection:
155
+ err = (
156
+ f'No database connection available for method:{connection_method}, '
157
+ f'cluster_identifier:{cluster_identifier}, db_endpoint:{db_endpoint}, database:{database}'
158
+ )
159
+ logger.error(err)
160
+ await ctx.error(err)
161
+ return [{'error': err}]
126
162
 
127
163
  if db_connection.readonly_query:
128
164
  matches = detect_mutating_keywords(sql)
129
165
  if (bool)(matches):
130
166
  logger.info(
131
- f'query is rejected because current setting only allows readonly query. detected keywords: {matches}, SQL query: {sql}'
167
+ (
168
+ f'query is rejected because current setting only allows readonly query. '
169
+ f'detected keywords: {matches}, SQL query: {sql}'
170
+ )
132
171
  )
133
172
  await ctx.error(write_query_prohibited_key)
134
173
  return [{'error': write_query_prohibited_key}]
@@ -144,9 +183,15 @@ async def run_query(
144
183
  return [{'error': query_injection_risk_key}]
145
184
 
146
185
  try:
147
- logger.info(f'run_query: readonly:{db_connection.readonly_query}, SQL:{sql}')
186
+ logger.info(
187
+ (
188
+ f'run_query: sql:{sql} method:{connection_method}, '
189
+ f'cluster_identifier:{cluster_identifier} database:{database} '
190
+ f'db_endpoint:{db_endpoint} '
191
+ f'readonly:{db_connection.readonly_query} query_parameters:{query_parameters}'
192
+ )
193
+ )
148
194
 
149
- # Execute the query using the abstract connection interface
150
195
  response = await db_connection.execute_query(sql, query_parameters)
151
196
 
152
197
  logger.success(f'run_query successfully executed query:{sql}')
@@ -164,23 +209,34 @@ async def run_query(
164
209
  return [{'error': unexpected_error_key}]
165
210
 
166
211
 
167
- @mcp.tool(
168
- name='get_table_schema',
169
- description='Fetch table columns and comments from Postgres',
170
- )
212
+ @mcp.tool(name='get_table_schema', description='Fetch table columns and comments from Postgres')
171
213
  async def get_table_schema(
172
- table_name: Annotated[str, Field(description='name of the table')], ctx: Context
214
+ connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
215
+ cluster_identifier: Annotated[str, Field(description='Cluster identifier')],
216
+ db_endpoint: Annotated[str, Field(description='database endpoint')],
217
+ database: Annotated[str, Field(description='database name')],
218
+ table_name: Annotated[str, Field(description='name of the table')],
219
+ ctx: Context,
173
220
  ) -> list[dict]:
174
221
  """Get a table's schema information given the table name.
175
222
 
176
223
  Args:
224
+ connection_method: connection method
225
+ cluster_identifier: Cluster identifier
226
+ db_endpoint: database endpoint
227
+ database: database name
177
228
  table_name: name of the table
178
229
  ctx: MCP context for logging and state management
179
230
 
180
231
  Returns:
181
232
  List of dictionary that contains query response rows
182
233
  """
183
- logger.info(f'get_table_schema: {table_name}')
234
+ logger.info(
235
+ (
236
+ f'Entered get_table_schema: table_name:{table_name} connection_method:{connection_method}, '
237
+ f'cluster_identifier:{cluster_identifier}, db_endpoint:{db_endpoint}, database:{database}'
238
+ )
239
+ )
184
240
 
185
241
  sql = """
186
242
  SELECT
@@ -190,7 +246,7 @@ async def get_table_schema(
190
246
  FROM
191
247
  pg_attribute a
192
248
  WHERE
193
- a.attrelid = to_regclass(:table_name)
249
+ a.attrelid = to_regclass(%(table_name)s)
194
250
  AND a.attnum > 0
195
251
  AND NOT a.attisdropped
196
252
  ORDER BY a.attnum
@@ -198,125 +254,511 @@ async def get_table_schema(
198
254
 
199
255
  params = [{'name': 'table_name', 'value': {'stringValue': table_name}}]
200
256
 
201
- return await run_query(sql=sql, ctx=ctx, query_parameters=params)
257
+ return await run_query(
258
+ sql=sql,
259
+ ctx=ctx,
260
+ connection_method=connection_method,
261
+ cluster_identifier=cluster_identifier,
262
+ db_endpoint=db_endpoint,
263
+ database=database,
264
+ query_parameters=params,
265
+ )
266
+
202
267
 
268
+ @mcp.tool(
269
+ name='connect_to_database',
270
+ description='Connect to a specific database and save the connection internally',
271
+ )
272
+ def connect_to_database(
273
+ region: Annotated[str, Field(description='region')],
274
+ database_type: Annotated[DatabaseType, Field(description='database type')],
275
+ connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
276
+ cluster_identifier: Annotated[str, Field(description='cluster identifier')],
277
+ db_endpoint: Annotated[str, Field(description='database endpoint')],
278
+ port: Annotated[int, Field(description='Postgres port')],
279
+ database: Annotated[str, Field(description='database name')],
280
+ ) -> str:
281
+ """Connect to a specific database save the connection internally.
203
282
 
204
- def main():
205
- """Main entry point for the MCP server application."""
206
- global client_error_code_key
283
+ Args:
284
+ region: region of the database. Required parameter.
285
+ database_type: Either APG for Aurora Postgres or RPG for RDS Postgres cluster. Required parameter
286
+ connection_method: Either RDS_API, PG_WIRE_PROTOCOL, or PG_WIRE_IAM_PROTOCOL. Required parameter
287
+ cluster_identifier: Either Aurora Postgres cluster identifier or RDS Postgres cluster identifier
288
+ db_endpoint: database endpoint
289
+ port: database port
290
+ database: database name. Required parameter
291
+
292
+ Supported scenario:
293
+ 1. Aurora Postgres database with RDS_API + Credential Manager:
294
+ cluster_identifier must be set
295
+ db_endpoint and port will be ignored
296
+ 2. Aurora Postgres database with direct connection + IAM:
297
+ cluster_identifier must be set
298
+ db_endpoint must be set
299
+ 3. Aurora Postgres database with direct connection + PG_AUTH (Credential Manager):
300
+ cluster_identifier must be set
301
+ db_endpoint must be set
302
+ 4. RDS Postgres database with direct connection + PG_AUTH (Credential Manager):
303
+ credential manager setting is either on instance or cluster
304
+ db_endpoint must be set
305
+ """
306
+ try:
307
+ db_connection, llm_response = internal_connect_to_database(
308
+ region=region,
309
+ database_type=database_type,
310
+ connection_method=connection_method,
311
+ cluster_identifier=cluster_identifier,
312
+ db_endpoint=db_endpoint,
313
+ port=port,
314
+ database=database,
315
+ )
207
316
 
208
- """Run the MCP server with CLI argument support."""
209
- parser = argparse.ArgumentParser(
210
- description='An AWS Labs Model Context Protocol (MCP) server for postgres'
317
+ return str(llm_response)
318
+
319
+ except Exception as e:
320
+ logger.error(f'connect_to_database failed with error: {str(e)}')
321
+ trace_msg = traceback.format_exc()
322
+ logger.error(f'Trace:{trace_msg}')
323
+ llm_response = {'status': 'Failed', 'error': str(e)}
324
+ return json.dumps(llm_response, indent=2)
325
+
326
+
327
+ @mcp.tool(name='is_database_connected', description='Check if a connection has been established')
328
+ def is_database_connected(
329
+ cluster_identifier: Annotated[str, Field(description='cluster identifier')],
330
+ db_endpoint: Annotated[str, Field(description='database endpoint')] = '',
331
+ database: Annotated[str, Field(description='database name')] = 'postgres',
332
+ ) -> bool:
333
+ """Check if a connection has been established.
334
+
335
+ Args:
336
+ cluster_identifier: cluster identifier
337
+ db_endpoint: database endpoint
338
+ database: database name
339
+
340
+ Returns:
341
+ result in boolean
342
+ """
343
+ global db_connection_map
344
+ if db_connection_map.get(ConnectionMethod.RDS_API, cluster_identifier, db_endpoint, database):
345
+ return True
346
+
347
+ if db_connection_map.get(
348
+ ConnectionMethod.PG_WIRE_PROTOCOL, cluster_identifier, db_endpoint, database
349
+ ):
350
+ return True
351
+
352
+ if db_connection_map.get(
353
+ ConnectionMethod.PG_WIRE_IAM_PROTOCOL, cluster_identifier, db_endpoint, database
354
+ ):
355
+ return True
356
+
357
+ return False
358
+
359
+
360
+ @mcp.tool(
361
+ name='get_database_connection_info',
362
+ description='Get all cached database connection information',
363
+ )
364
+ def get_database_connection_info() -> str:
365
+ """Get all cached database connection information.
366
+
367
+ Return:
368
+ A list of cached connection information.
369
+ """
370
+ global db_connection_map
371
+ return db_connection_map.get_keys_json()
372
+
373
+
374
+ @mcp.tool(name='create_cluster', description='Create an Aurora Postgres cluster')
375
+ def create_cluster(
376
+ region: Annotated[str, Field(description='region')],
377
+ cluster_identifier: Annotated[str, Field(description='cluster identifier')],
378
+ database: Annotated[str, Field(description='default database name')] = 'postgres',
379
+ engine_version: Annotated[str, Field(description='engine version')] = '17.5',
380
+ ) -> str:
381
+ """Create an RDS/Aurora cluster.
382
+
383
+ Args:
384
+ region: region
385
+ cluster_identifier: cluster identifier
386
+ database: database name
387
+ engine_version: engine version
388
+
389
+ Returns:
390
+ result
391
+ """
392
+ logger.info(
393
+ f'Entered create_cluster with region:{region}, '
394
+ f'cluster_identifier:{cluster_identifier} '
395
+ f'database:{database} '
396
+ f'engine_version:{engine_version}'
211
397
  )
212
398
 
213
- # Connection method 1: RDS Data API
214
- parser.add_argument('--resource_arn', help='ARN of the RDS cluster (for RDS Data API)')
399
+ database_type = DatabaseType.APG
400
+ connection_method = ConnectionMethod.RDS_API
215
401
 
216
- # Connection method 2: Psycopg Direct Connection
217
- parser.add_argument('--hostname', help='Database hostname (for direct PostgreSQL connection)')
218
- parser.add_argument('--port', type=int, default=5432, help='Database port (default: 5432)')
402
+ job_id = (
403
+ f'create-cluster-{cluster_identifier}-{datetime.now().isoformat(timespec="milliseconds")}'
404
+ )
219
405
 
220
- # Common parameters
221
- parser.add_argument(
222
- '--secret_arn',
223
- required=True,
224
- help='ARN of the Secrets Manager secret for database credentials',
406
+ try:
407
+ async_job_status_lock.acquire()
408
+ async_job_status[job_id] = {'state': 'pending', 'result': None}
409
+ finally:
410
+ async_job_status_lock.release()
411
+
412
+ t = threading.Thread(
413
+ target=create_cluster_worker,
414
+ args=(
415
+ job_id,
416
+ region,
417
+ database_type,
418
+ connection_method,
419
+ cluster_identifier,
420
+ engine_version,
421
+ database,
422
+ ),
423
+ daemon=False,
225
424
  )
226
- parser.add_argument('--database', required=True, help='Database name')
227
- parser.add_argument('--region', required=True, help='AWS region')
228
- parser.add_argument('--readonly', required=True, help='Enforce readonly SQL statements')
425
+ t.start()
426
+
427
+ logger.info(
428
+ f'start_create_cluster_job return with job_id:{job_id}'
429
+ f'region:{region} cluster_identifier:{cluster_identifier} database:{database} '
430
+ f'engine_version:{engine_version}'
431
+ )
432
+
433
+ result = {
434
+ 'status': 'Pending',
435
+ 'message': 'cluster creation started',
436
+ 'job_id': job_id,
437
+ 'cluster_identifier': cluster_identifier,
438
+ 'check_status_tool': 'get_job_status',
439
+ 'next_action': f"Use get_job_status(job_id='{job_id}') to get results",
440
+ }
441
+
442
+ return json.dumps(result, indent=2)
229
443
 
230
- args = parser.parse_args()
231
444
 
232
- # Validate connection parameters
233
- if not args.resource_arn and not args.hostname:
234
- parser.error(
235
- 'Either --resource_arn (for RDS Data API) or '
236
- '--hostname (for direct PostgreSQL) must be provided'
445
+ @mcp.tool(name='get_job_status', description='get background job status')
446
+ def get_job_status(job_id: str) -> dict:
447
+ """Get background job status.
448
+
449
+ Args:
450
+ job_id: job id
451
+ Returns:
452
+ job status
453
+ """
454
+ global async_job_status
455
+ global async_job_status_lock
456
+
457
+ try:
458
+ async_job_status_lock.acquire()
459
+ return async_job_status.get(job_id, {'state': 'not_found'})
460
+ finally:
461
+ async_job_status_lock.release()
462
+
463
+
464
+ def create_cluster_worker(
465
+ job_id: str,
466
+ region: str,
467
+ database_type: DatabaseType,
468
+ connection_method: ConnectionMethod,
469
+ cluster_identifier: str,
470
+ engine_version: str,
471
+ database: str,
472
+ ):
473
+ """Background worker to create a cluster asynchronously."""
474
+ global db_connection_map
475
+ global async_job_status
476
+ global async_job_status_lock
477
+ global readonly_query
478
+
479
+ try:
480
+ cluster_result = internal_create_serverless_cluster(
481
+ region=region,
482
+ cluster_identifier=cluster_identifier,
483
+ engine_version=engine_version,
484
+ database_name=database,
237
485
  )
238
486
 
239
- if args.resource_arn and args.hostname:
240
- parser.error(
241
- 'Cannot specify both --resource_arn and --hostname. Choose one connection method.'
487
+ setup_aurora_iam_policy_for_current_user(
488
+ db_user=cluster_result['MasterUsername'],
489
+ cluster_resource_id=cluster_result['DbClusterResourceId'],
490
+ cluster_region=region,
242
491
  )
243
492
 
244
- # Convert args to dict for easier handling
245
- connection_params = vars(args)
493
+ internal_connect_to_database(
494
+ region=region,
495
+ database_type=database_type,
496
+ connection_method=connection_method,
497
+ cluster_identifier=cluster_identifier,
498
+ db_endpoint=cluster_result['Endpoint'],
499
+ port=5432,
500
+ database=database,
501
+ )
246
502
 
247
- # Convert readonly string to boolean
248
- connection_params['readonly'] = args.readonly.lower() == 'true'
503
+ try:
504
+ async_job_status_lock.acquire()
505
+ async_job_status[job_id]['state'] = 'succeeded'
506
+ finally:
507
+ async_job_status_lock.release()
508
+ except Exception as e:
509
+ logger.error(f'create_cluster_worker failed with {e}')
510
+ try:
511
+ async_job_status_lock.acquire()
512
+ async_job_status[job_id]['state'] = 'failed'
513
+ async_job_status[job_id]['result'] = str(e)
514
+ finally:
515
+ async_job_status_lock.release()
516
+
517
+
518
+ def internal_connect_to_database(
519
+ region: Annotated[str, Field(description='region')],
520
+ database_type: Annotated[DatabaseType, Field(description='database type')],
521
+ connection_method: Annotated[ConnectionMethod, Field(description='connection method')],
522
+ cluster_identifier: Annotated[str, Field(description='cluster identifier')],
523
+ db_endpoint: Annotated[str, Field(description='database endpoint')],
524
+ port: Annotated[int, Field(description='Postgres port')],
525
+ database: Annotated[str, Field(description='database name')] = 'postgres',
526
+ ) -> Tuple:
527
+ """Connect to a specific database save the connection internally.
249
528
 
250
- # Log connection information
251
- connection_target = args.resource_arn if args.resource_arn else f'{args.hostname}:{args.port}'
529
+ Args:
530
+ region: region
531
+ database_type: database type (APG or RPG)
532
+ connection_method: connection method (RDS_API, PG_WIRE_PROTOCOL, or PG_WIRE_IAM_PROTOCOL)
533
+ cluster_identifier: cluster identifier
534
+ db_endpoint: database endpoint
535
+ port: database port
536
+ database: database name
537
+ """
538
+ global db_connection_map
539
+ global readonly_query
540
+
541
+ logger.info(
542
+ f'Enter internal_connect_to_database\n'
543
+ f'region:{region}\n'
544
+ f'database_type:{database_type}\n'
545
+ f'connection_method:{connection_method}\n'
546
+ f'cluster_identifier:{cluster_identifier}\n'
547
+ f'db_endpoint:{db_endpoint}\n'
548
+ f'database:{database}\n'
549
+ f'readonly_query:{readonly_query}'
550
+ )
252
551
 
253
- if args.resource_arn:
254
- logger.info(
255
- f'Postgres MCP init with RDS Data API: CONNECTION_TARGET:{connection_target}, SECRET_ARN:{args.secret_arn}, REGION:{args.region}, DATABASE:{args.database}, READONLY:{args.readonly}'
552
+ if not region:
553
+ raise ValueError("region can't be none or empty")
554
+
555
+ if not connection_method:
556
+ raise ValueError("connection_method can't be none or empty")
557
+
558
+ if not database_type:
559
+ raise ValueError("database_type can't be none or empty")
560
+
561
+ if database_type == DatabaseType.APG and not cluster_identifier:
562
+ raise ValueError("cluster_identifier can't be none or empty for Aurora Postgres Database")
563
+
564
+ existing_conn = db_connection_map.get(
565
+ connection_method, cluster_identifier, db_endpoint, database, port
566
+ )
567
+ if existing_conn:
568
+ llm_response = json.dumps(
569
+ {
570
+ 'connection_method': connection_method,
571
+ 'cluster_identifier': cluster_identifier,
572
+ 'db_endpoint': db_endpoint,
573
+ 'database': database,
574
+ 'port': port,
575
+ },
576
+ indent=2,
577
+ default=str,
256
578
  )
257
- else:
258
- logger.info(
259
- f'Postgres MCP init with psycopg: CONNECTION_TARGET:{connection_target}, PORT:{args.port}, DATABASE:{args.database}, READONLY:{args.readonly}'
579
+ return (existing_conn, llm_response)
580
+
581
+ enable_data_api: bool = False
582
+ masteruser: str = ''
583
+ cluster_arn: str = ''
584
+ secret_arn: str = ''
585
+
586
+ if cluster_identifier:
587
+ # Can be either APG (APG always requires cluster) or RPG multi-AZ cluster deployment case
588
+ cluster_properties = internal_get_cluster_properties(
589
+ cluster_identifier=cluster_identifier, region=region
260
590
  )
261
591
 
262
- # Create the appropriate database connection based on the provided parameters
592
+ enable_data_api = cluster_properties.get('HttpEndpointEnabled', False)
593
+ masteruser = cluster_properties.get('MasterUsername', '')
594
+ cluster_arn = cluster_properties.get('DBClusterArn', '')
595
+ secret_arn = cluster_properties.get('MasterUserSecret', {}).get('SecretArn')
596
+
597
+ if not db_endpoint:
598
+ # if db_endpoint not set, we will use cluster's endpoint
599
+ db_endpoint = cluster_properties.get('Endpoint', '')
600
+ port = int(cluster_properties.get('Port', ''))
601
+ else:
602
+ # Must be RPG instance only deployment case (i.e. without cluster)
603
+ instance_properties = internal_get_instance_properties(db_endpoint, region)
604
+ masteruser = instance_properties.get('MasterUsername', '')
605
+ secret_arn = instance_properties.get('MasterUserSecret', {}).get('SecretArn')
606
+ port = int(instance_properties.get('Endpoint', {}).get('Port'))
607
+
608
+ logger.info(
609
+ f'About to create internal DB connections with:'
610
+ f'enable_data_api:{enable_data_api}\n'
611
+ f'masteruser:{masteruser}\n'
612
+ f'cluster_arn:{cluster_arn}\n'
613
+ f'secret_arn:{secret_arn}\n'
614
+ f'db_endpoint:{db_endpoint}\n'
615
+ f'port:{port}\n'
616
+ f'region:{region}\n'
617
+ f'readonly:{readonly_query}'
618
+ )
619
+
263
620
  db_connection = None
621
+ if connection_method == ConnectionMethod.PG_WIRE_IAM_PROTOCOL:
622
+ db_connection = PsycopgPoolConnection(
623
+ host=db_endpoint,
624
+ port=port,
625
+ database=database,
626
+ readonly=readonly_query,
627
+ secret_arn='',
628
+ db_user=masteruser,
629
+ region=region,
630
+ is_iam_auth=True,
631
+ )
264
632
 
265
- try:
266
- if args.resource_arn:
267
- # Use RDS Data API with singleton pattern
268
- try:
269
- # Initialize the RDS Data API connection singleton
270
- DBConnectionSingleton.initialize(
271
- resource_arn=args.resource_arn,
272
- secret_arn=args.secret_arn,
273
- database=args.database,
274
- region=args.region,
275
- readonly=connection_params['readonly'],
276
- )
633
+ elif connection_method == ConnectionMethod.RDS_API:
634
+ db_connection = RDSDataAPIConnection(
635
+ cluster_arn=cluster_arn,
636
+ secret_arn=str(secret_arn),
637
+ database=database,
638
+ region=region,
639
+ readonly=readonly_query,
640
+ )
641
+ else:
642
+ # must be connection_method == ConnectionMethod.PG_WIRE_PROTOCOL
643
+ db_connection = PsycopgPoolConnection(
644
+ host=db_endpoint,
645
+ port=port,
646
+ database=database,
647
+ readonly=readonly_query,
648
+ secret_arn=secret_arn,
649
+ db_user='',
650
+ region=region,
651
+ is_iam_auth=False,
652
+ )
277
653
 
278
- # Get the connection from the singleton
279
- db_connection = DBConnectionSingleton.get().db_connection
280
- except Exception as e:
281
- logger.exception(f'Failed to create RDS Data API connection: {str(e)}')
282
- sys.exit(1)
283
-
284
- else:
285
- # Use Direct PostgreSQL connection using psycopg connection pool
286
- try:
287
- # Create a direct PostgreSQL connection pool
288
- db_connection = PsycopgPoolConnection(
289
- host=args.hostname,
290
- port=args.port,
291
- database=args.database,
292
- readonly=connection_params['readonly'],
293
- secret_arn=args.secret_arn,
294
- region=args.region,
295
- )
296
- except Exception as e:
297
- logger.exception(f'Failed to create PostgreSQL connection: {str(e)}')
298
- sys.exit(1)
299
-
300
- except BotoCoreError as e:
301
- logger.exception(f'Failed to create database connection: {str(e)}')
302
- sys.exit(1)
303
-
304
- # Test database connection
305
- ctx = DummyCtx()
306
- response = asyncio.run(run_query('SELECT 1', ctx, db_connection))
307
- if (
308
- isinstance(response, list)
309
- and len(response) == 1
310
- and isinstance(response[0], dict)
311
- and 'error' in response[0]
312
- ):
313
- logger.error('Failed to validate database connection to Postgres. Exit the MCP server')
314
- sys.exit(1)
654
+ if db_connection:
655
+ db_connection_map.set(
656
+ connection_method, cluster_identifier, db_endpoint, database, db_connection
657
+ )
658
+ llm_response = json.dumps(
659
+ {
660
+ 'connection_method': connection_method,
661
+ 'cluster_identifier': cluster_identifier,
662
+ 'db_endpoint': db_endpoint,
663
+ 'database': database,
664
+ 'port': port,
665
+ },
666
+ indent=2,
667
+ default=str,
668
+ )
669
+ return (db_connection, llm_response)
315
670
 
316
- logger.success('Successfully validated database connection to Postgres')
671
+ raise ValueError("Can't create connection because invalid input parameter combination")
317
672
 
318
- logger.info('Starting Postgres MCP server')
319
- mcp.run()
673
+
674
+ def main():
675
+ """Main entry point for the MCP server application.
676
+
677
+ Runs the MCP server with CLI argument support for PostgreSQL connections.
678
+ """
679
+ global db_connection_map
680
+ global readonly_query
681
+
682
+ parser = argparse.ArgumentParser(
683
+ description='An AWS Labs Model Context Protocol (MCP) server for postgres'
684
+ )
685
+
686
+ parser.add_argument(
687
+ '--connection_method',
688
+ help='Connection method to the database. It can be RDS_API, PG_WIRE_PROTOCOL, or PG_WIRE_IAM_PROTOCOL',
689
+ )
690
+ parser.add_argument('--db_cluster_arn', help='ARN of the RDS or Aurora Postgres cluster')
691
+ parser.add_argument('--db_type', help='APG for Aurora Postgres or RPG for RDS Postgres')
692
+ parser.add_argument('--db_endpoint', help='Instance endpoint address')
693
+ parser.add_argument('--region', help='AWS region')
694
+ parser.add_argument(
695
+ '--allow_write_query', action='store_true', help='Allow mutating SQL statements (default: readonly only)'
696
+ )
697
+ parser.add_argument('--database', help='Database name')
698
+ parser.add_argument('--port', type=int, default=5432, help='Database port (default: 5432)')
699
+ args = parser.parse_args()
700
+
701
+ logger.info(
702
+ f'MCP configuration:\n'
703
+ f'db_type:{args.db_type}\n'
704
+ f'db_cluster_arn:{args.db_cluster_arn}\n'
705
+ f'connection_method:{args.connection_method}\n'
706
+ f'db_endpoint:{args.db_endpoint}\n'
707
+ f'region:{args.region}\n'
708
+ f'allow_write_query:{args.allow_write_query}\n'
709
+ f'database:{args.database}\n'
710
+ f'port:{args.port}\n'
711
+ )
712
+
713
+ readonly_query = not args.allow_write_query
714
+
715
+ try:
716
+ if args.db_type:
717
+ # Create the appropriate database connection based on the provided parameters
718
+ db_connection: Optional[AbstractDBConnection] = None
719
+
720
+ cluster_identifier = args.db_cluster_arn.split(':')[-1]
721
+ db_connection, llm_response = internal_connect_to_database(
722
+ region=args.region,
723
+ database_type=DatabaseType[args.db_type],
724
+ connection_method=ConnectionMethod[args.connection_method],
725
+ cluster_identifier=cluster_identifier,
726
+ db_endpoint=args.db_endpoint,
727
+ port=args.port,
728
+ database=args.database,
729
+ )
730
+
731
+ # Test database connection
732
+ if db_connection:
733
+ ctx = DummyCtx()
734
+ response = asyncio.run(
735
+ run_query(
736
+ 'SELECT 1',
737
+ ctx,
738
+ ConnectionMethod[args.connection_method],
739
+ cluster_identifier,
740
+ args.db_endpoint,
741
+ args.database,
742
+ )
743
+ )
744
+ if (
745
+ isinstance(response, list)
746
+ and len(response) == 1
747
+ and isinstance(response[0], dict)
748
+ and 'error' in response[0]
749
+ ):
750
+ logger.error(
751
+ 'Failed to validate database connection to Postgres. Exit the MCP server'
752
+ )
753
+ sys.exit(1)
754
+ else:
755
+ logger.success('Successfully validated database connection to Postgres')
756
+
757
+ logger.info('Postgres MCP server started')
758
+ mcp.run()
759
+ logger.info('Postgres MCP server stopped')
760
+ finally:
761
+ db_connection_map.close_all()
320
762
 
321
763
 
322
764
  if __name__ == '__main__':