awslabs.postgres-mcp-server 1.0.8__py3-none-any.whl → 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,128 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Database connection map for postgres MCP Server."""
16
+
17
+ import json
18
+ import threading
19
+ from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
20
+ from enum import Enum
21
+ from loguru import logger
22
+ from typing import List
23
+
24
+
25
+ class DatabaseType(str, Enum):
26
+ """Database type enumeration."""
27
+
28
+ APG = ('APG',)
29
+ RPG = 'RPG'
30
+
31
+
32
+ class ConnectionMethod(str, Enum):
33
+ """Connection method enumeration."""
34
+
35
+ RDS_API = 'rdsapi'
36
+ PG_WIRE_PROTOCOL = 'pgwire'
37
+ PG_WIRE_IAM_PROTOCOL = 'pgwire_iam'
38
+
39
+
40
+ class DBConnectionMap:
41
+ """Manages Postgres DB connection map."""
42
+
43
+ def __init__(self):
44
+ """Initialize the connection map."""
45
+ self.map = {}
46
+ self._lock = threading.Lock()
47
+
48
+ def get(
49
+ self,
50
+ method: ConnectionMethod,
51
+ cluster_identifier: str,
52
+ db_endpoint: str,
53
+ database: str,
54
+ port: int = 5432,
55
+ ) -> AbstractDBConnection | None:
56
+ """Get a database connection from the map."""
57
+ if not method:
58
+ raise ValueError('method cannot be None')
59
+
60
+ if not database:
61
+ raise ValueError('database cannot be None or empty')
62
+
63
+ with self._lock:
64
+ return self.map.get((method, cluster_identifier, db_endpoint, database, port))
65
+
66
+ def set(
67
+ self,
68
+ method: ConnectionMethod,
69
+ cluster_identifier: str,
70
+ db_endpoint: str,
71
+ database: str,
72
+ conn: AbstractDBConnection,
73
+ port: int = 5432,
74
+ ) -> None:
75
+ """Set a database connection in the map."""
76
+ if not database:
77
+ raise ValueError('database cannot be None or empty')
78
+
79
+ if not conn:
80
+ raise ValueError('conn cannot be None')
81
+
82
+ with self._lock:
83
+ self.map[(method, cluster_identifier, db_endpoint, database, port)] = conn
84
+
85
+ def remove(
86
+ self,
87
+ method: ConnectionMethod,
88
+ cluster_identifier: str,
89
+ db_endpoint: str,
90
+ database: str,
91
+ port: int = 5432,
92
+ ) -> None:
93
+ """Remove a database connection from the map."""
94
+ if not database:
95
+ raise ValueError('database cannot be None or empty')
96
+
97
+ with self._lock:
98
+ try:
99
+ self.map.pop((method, cluster_identifier, db_endpoint, database, port))
100
+ except KeyError:
101
+ logger.info(
102
+ f'Try to remove a non-existing connection. {method} {cluster_identifier} {db_endpoint} {database} {port}'
103
+ )
104
+
105
+ def get_keys_json(self) -> str:
106
+ """Get all connection keys as JSON string."""
107
+ entries: List[dict] = []
108
+ with self._lock:
109
+ for key in self.map.keys():
110
+ entry = {
111
+ 'connection_method': key[0],
112
+ 'cluster_identifier': key[1],
113
+ 'db_endpoint': key[2],
114
+ 'database': key[3],
115
+ 'port': key[4],
116
+ }
117
+ entries.append(entry)
118
+ return json.dumps(entries, indent=2)
119
+
120
+ def close_all(self) -> None:
121
+ """Close all connections and clear the map."""
122
+ with self._lock:
123
+ for key, conn in self.map.items():
124
+ try:
125
+ conn.close()
126
+ except Exception as e:
127
+ logger.warning(f'Failed to close connection {key}: {e}')
128
+ self.map.clear()
@@ -21,7 +21,11 @@ parameters (host, port, database, user, password) or via AWS Secrets Manager.
21
21
 
22
22
  import boto3
23
23
  import json
24
+ from aiorwlock import RWLock
25
+ from awslabs.postgres_mcp_server import __user_agent__
24
26
  from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
27
+ from botocore.config import Config
28
+ from datetime import datetime, timedelta
25
29
  from loguru import logger
26
30
  from psycopg_pool import AsyncConnectionPool
27
31
  from typing import Any, Dict, List, Optional, Tuple
@@ -45,7 +49,10 @@ class PsycopgPoolConnection(AbstractDBConnection):
45
49
  database: str,
46
50
  readonly: bool,
47
51
  secret_arn: str,
52
+ db_user: str,
48
53
  region: str,
54
+ is_iam_auth: bool = False,
55
+ pool_expiry_min: int = 30,
49
56
  min_size: int = 1,
50
57
  max_size: int = 10,
51
58
  is_test: bool = False,
@@ -58,7 +65,10 @@ class PsycopgPoolConnection(AbstractDBConnection):
58
65
  database: Database name
59
66
  readonly: Whether connections should be read-only
60
67
  secret_arn: ARN of the secret containing credentials
68
+ db_user: Database username
61
69
  region: AWS region for Secrets Manager
70
+ is_iam_auth: Whether to use IAM authentication
71
+ pool_expiry_min: Pool expiry time in minutes
62
72
  min_size: Minimum number of connections in the pool
63
73
  max_size: Maximum number of connections in the pool
64
74
  is_test: Whether this is a test connection
@@ -69,58 +79,83 @@ class PsycopgPoolConnection(AbstractDBConnection):
69
79
  self.database = database
70
80
  self.min_size = min_size
71
81
  self.max_size = max_size
82
+ self.region = region
83
+ self.is_iam_auth = is_iam_auth
84
+ self.user = db_user
85
+ self.pool_expiry_min = pool_expiry_min
86
+ self.secret_arn = secret_arn
87
+ self.is_test = is_test
72
88
  self.pool: Optional['AsyncConnectionPool[Any]'] = None
89
+ self.rw_lock = RWLock()
90
+ self.created_time = datetime.now()
73
91
 
74
- # Get credentials from Secrets Manager
75
- logger.info(f'Retrieving credentials from Secrets Manager: {secret_arn}')
76
- self.user, self.password = self._get_credentials_from_secret(secret_arn, region, is_test)
77
- logger.info(f'Successfully retrieved credentials for user: {self.user}')
92
+ if is_iam_auth:
93
+ # if db_user is set, then it is IAM auth scenario and iam_auth_token must be set
94
+ if not db_user:
95
+ raise ValueError('db_user must be set when is_iam_auth is True')
78
96
 
79
- # Store connection info
80
- if not is_test:
81
- self.conninfo = f'host={host} port={port} dbname={database} user={self.user} password={self.password}'
82
- logger.info('Connection parameters stored')
97
+ # set pool expiry before IAM auth token expiry of 15 minutes
98
+ self.pool_expiry_min = 14
99
+ logger.info(f'Use IAM auth for user: {db_user}')
83
100
 
84
101
  async def initialize_pool(self):
85
102
  """Initialize the connection pool."""
86
- if self.pool is None:
103
+ async with self.rw_lock.reader_lock:
104
+ if self.pool is not None:
105
+ return
106
+
107
+ async with self.rw_lock.writer_lock:
108
+ if self.pool is not None:
109
+ return
110
+
87
111
  logger.info(
88
- f'Initializing connection pool with min_size={self.min_size}, max_size={self.max_size}'
112
+ f'initialize_pool:\n'
113
+ f'endpoint:{self.host}\n'
114
+ f'port:{self.port}\n'
115
+ f'region:{self.region}\n'
116
+ f'db:{self.database}\n'
117
+ f'user:{self.user}\n'
118
+ f'is_iam_auth:{self.is_iam_auth}\n'
89
119
  )
120
+
121
+ if self.is_iam_auth:
122
+ logger.info(f'Retrieving IAM auth token for {self.user}')
123
+ password = self.get_iam_auth_token()
124
+ else:
125
+ logger.info(f'Retrieving credentials from Secrets Manager: {self.secret_arn}')
126
+ self.user, password = self._get_credentials_from_secret(
127
+ self.secret_arn, self.region, self.is_test
128
+ )
129
+
130
+ self.created_time = datetime.now()
131
+ self.conninfo = f'host={self.host} port={self.port} dbname={self.database} user={self.user} password={password}'
90
132
  self.pool = AsyncConnectionPool(
91
- self.conninfo, min_size=self.min_size, max_size=self.max_size, open=True
133
+ self.conninfo, min_size=self.min_size, max_size=self.max_size, open=False
92
134
  )
93
- logger.info('Connection pool initialized successfully')
94
135
 
95
- # Set read-only mode if needed
96
- if self.readonly_query:
97
- await self._set_all_connections_readonly()
136
+ # wait up to 30 seconds to fill the pool with connections
137
+ await self.pool.open(True, 30)
138
+ logger.info('Connection pool initialized successfully')
98
139
 
99
140
  async def _get_connection(self):
100
141
  """Get a database connection from the pool."""
101
- if self.pool is None:
102
- await self.initialize_pool()
103
-
104
- if self.pool is None:
105
- raise ValueError('Failed to initialize connection pool')
142
+ await self.check_expiry()
106
143
 
107
- return self.pool.connection(timeout=15.0)
144
+ async with self.rw_lock.reader_lock:
145
+ if self.pool is None:
146
+ raise ValueError('Failed to initialize connection pool')
147
+ return self.pool.connection(timeout=15.0)
108
148
 
109
- async def _set_all_connections_readonly(self):
110
- """Set all connections in the pool to read-only mode."""
111
- if self.pool is None:
112
- logger.warning('Connection pool is not initialized, cannot set read-only mode')
113
- return
149
+ async def check_expiry(self):
150
+ """Check and handle pool expiry."""
151
+ async with self.rw_lock.reader_lock:
152
+ if self.pool and datetime.now() - self.created_time < timedelta(
153
+ minutes=self.pool_expiry_min
154
+ ):
155
+ return
114
156
 
115
- try:
116
- async with self.pool.connection(timeout=15.0) as conn:
117
- await conn.execute(
118
- 'ALTER ROLE CURRENT_USER SET default_transaction_read_only = on'
119
- ) # type: ignore
120
- logger.info('Successfully set connection to read-only mode')
121
- except Exception as e:
122
- logger.warning(f'Failed to set connections to read-only mode: {str(e)}')
123
- logger.warning('Continuing without setting read-only mode')
157
+ await self.close()
158
+ await self.initialize_pool()
124
159
 
125
160
  async def execute_query(
126
161
  self, sql: str, parameters: Optional[List[Dict[str, Any]]] = None
@@ -130,7 +165,8 @@ class PsycopgPoolConnection(AbstractDBConnection):
130
165
  async with await self._get_connection() as conn:
131
166
  async with conn.transaction():
132
167
  if self.readonly_query:
133
- await conn.execute('SET TRANSACTION READ ONLY') # type: ignore
168
+ logger.info('SET TRANSACTION READ ONLY')
169
+ await conn.execute('SET TRANSACTION READ ONLY')
134
170
 
135
171
  # Create a cursor for better control
136
172
  async with conn.cursor() as cursor:
@@ -161,12 +197,12 @@ class PsycopgPoolConnection(AbstractDBConnection):
161
197
  record.append({'isNull': True})
162
198
  elif isinstance(value, str):
163
199
  record.append({'stringValue': value})
200
+ elif isinstance(value, bool):
201
+ record.append({'booleanValue': value})
164
202
  elif isinstance(value, int):
165
203
  record.append({'longValue': value})
166
204
  elif isinstance(value, float):
167
205
  record.append({'doubleValue': value})
168
- elif isinstance(value, bool):
169
- record.append({'booleanValue': value})
170
206
  elif isinstance(value, bytes):
171
207
  record.append({'blobValue': value})
172
208
  else:
@@ -258,11 +294,12 @@ class PsycopgPoolConnection(AbstractDBConnection):
258
294
 
259
295
  async def close(self) -> None:
260
296
  """Close all connections in the pool."""
261
- if self.pool is not None:
262
- logger.info('Closing connection pool')
263
- await self.pool.close()
264
- self.pool = None
265
- logger.info('Connection pool closed successfully')
297
+ async with self.rw_lock.writer_lock:
298
+ if self.pool is not None:
299
+ logger.info('Closing connection pool')
300
+ await self.pool.close()
301
+ self.pool = None
302
+ logger.info('Connection pool closed successfully')
266
303
 
267
304
  async def check_connection_health(self) -> bool:
268
305
  """Check if the connection is healthy."""
@@ -273,15 +310,25 @@ class PsycopgPoolConnection(AbstractDBConnection):
273
310
  logger.error(f'Connection health check failed: {str(e)}')
274
311
  return False
275
312
 
276
- def get_pool_stats(self) -> Dict[str, int]:
313
+ async def get_pool_stats(self) -> Dict[str, int]:
277
314
  """Get current connection pool statistics."""
278
- if not hasattr(self, 'pool') or self.pool is None:
279
- return {'size': 0, 'min_size': self.min_size, 'max_size': self.max_size, 'idle': 0}
280
-
281
- # Access pool attributes safely
282
- size = getattr(self.pool, 'size', 0)
283
- min_size = getattr(self.pool, 'min_size', self.min_size)
284
- max_size = getattr(self.pool, 'max_size', self.max_size)
285
- idle = getattr(self.pool, 'idle', 0)
286
-
287
- return {'size': size, 'min_size': min_size, 'max_size': max_size, 'idle': idle}
315
+ async with self.rw_lock.reader_lock:
316
+ if not hasattr(self, 'pool') or self.pool is None:
317
+ return {'size': 0, 'min_size': self.min_size, 'max_size': self.max_size, 'idle': 0}
318
+
319
+ # Access pool attributes safely
320
+ size = getattr(self.pool, 'size', 0)
321
+ min_size = getattr(self.pool, 'min_size', self.min_size)
322
+ max_size = getattr(self.pool, 'max_size', self.max_size)
323
+ idle = getattr(self.pool, 'idle', 0)
324
+
325
+ return {'size': size, 'min_size': min_size, 'max_size': max_size, 'idle': idle}
326
+
327
+ def get_iam_auth_token(self) -> str:
328
+ """Generate an IAM authentication token for RDS database access."""
329
+ rds_client = boto3.client(
330
+ 'rds', region_name=self.region, config=Config(user_agent_extra=__user_agent__)
331
+ )
332
+ return rds_client.generate_db_auth_token(
333
+ DBHostname=self.host, Port=self.port, DBUsername=self.user, Region=self.region
334
+ )
@@ -16,7 +16,9 @@
16
16
 
17
17
  import asyncio
18
18
  import boto3
19
+ from awslabs.postgres_mcp_server import __user_agent__
19
20
  from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
21
+ from botocore.config import Config
20
22
  from loguru import logger
21
23
  from typing import Any, Dict, List, Optional
22
24
 
@@ -48,7 +50,9 @@ class RDSDataAPIConnection(AbstractDBConnection):
48
50
  self.secret_arn = secret_arn
49
51
  self.database = database
50
52
  if not is_test:
51
- self.data_client = boto3.client('rds-data', region_name=region)
53
+ self.data_client = boto3.client(
54
+ 'rds-data', region_name=region, config=Config(user_agent_extra=__user_agent__)
55
+ )
52
56
 
53
57
  async def execute_query(
54
58
  self, sql: str, parameters: Optional[List[Dict[str, Any]]] = None