awslabs.postgres-mcp-server 1.0.9__py3-none-any.whl → 1.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/postgres_mcp_server/__init__.py +8 -1
- awslabs/postgres_mcp_server/connection/__init__.py +0 -1
- awslabs/postgres_mcp_server/connection/cp_api_connection.py +592 -0
- awslabs/postgres_mcp_server/connection/db_connection_map.py +128 -0
- awslabs/postgres_mcp_server/connection/psycopg_pool_connection.py +101 -54
- awslabs/postgres_mcp_server/connection/rds_api_connection.py +5 -1
- awslabs/postgres_mcp_server/server.py +562 -120
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.12.dist-info}/METADATA +50 -81
- awslabs_postgres_mcp_server-1.0.12.dist-info/RECORD +16 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.12.dist-info}/WHEEL +1 -1
- awslabs/postgres_mcp_server/connection/db_connection_singleton.py +0 -117
- awslabs_postgres_mcp_server-1.0.9.dist-info/RECORD +0 -15
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.12.dist-info}/entry_points.txt +0 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.12.dist-info}/licenses/LICENSE +0 -0
- {awslabs_postgres_mcp_server-1.0.9.dist-info → awslabs_postgres_mcp_server-1.0.12.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Database connection map for postgres MCP Server."""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import threading
|
|
19
|
+
from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from loguru import logger
|
|
22
|
+
from typing import List
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DatabaseType(str, Enum):
    """Database type enumeration.

    Members are plain strings (the class mixes in ``str``), so they compare
    equal to their raw values and serialise as ordinary JSON strings.
    """

    # NOTE(review): presumably "Aurora PostgreSQL" and "RDS PostgreSQL" —
    # confirm against the callers that consume these values.
    # The value is a plain string: the previous ``('APG',)`` tuple literal only
    # produced 'APG' because Enum unpacks tuples into ``str(...)``.
    APG = 'APG'
    RPG = 'RPG'
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ConnectionMethod(str, Enum):
    """String-valued identifiers for the ways a database can be reached.

    Because the class mixes in ``str``, each member compares equal to its raw
    value and can be used directly as a dictionary-key component or JSON value.
    """

    # NOTE(review): the meanings below are inferred from the member names —
    # confirm against the code that dispatches on them.
    RDS_API = 'rdsapi'  # presumably the RDS Data API
    PG_WIRE_PROTOCOL = 'pgwire'  # presumably the native PostgreSQL wire protocol
    PG_WIRE_IAM_PROTOCOL = 'pgwire_iam'  # presumably wire protocol with IAM auth
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DBConnectionMap:
    """Manages Postgres DB connection map.

    Connections are stored in ``self.map`` under the tuple key
    ``(method, cluster_identifier, db_endpoint, database, port)``.  Every read
    or write of that dictionary happens while holding ``self._lock``.
    """

    def __init__(self):
        """Initialize the connection map."""
        self.map = {}
        self._lock = threading.Lock()
        # Strong references to close tasks scheduled on a running event loop,
        # so they are not garbage-collected before they finish.
        self._closing_tasks = set()

    def get(
        self,
        method: 'ConnectionMethod',
        cluster_identifier: str,
        db_endpoint: str,
        database: str,
        port: int = 5432,
    ) -> 'AbstractDBConnection | None':
        """Get a database connection from the map.

        Args:
            method: Connection method component of the key.
            cluster_identifier: Cluster identifier component of the key.
            db_endpoint: Endpoint component of the key.
            database: Database name component of the key.
            port: Port component of the key.

        Returns:
            The stored connection, or ``None`` when no entry matches the key.

        Raises:
            ValueError: If ``method`` or ``database`` is falsy.
        """
        if not method:
            raise ValueError('method cannot be None')

        if not database:
            raise ValueError('database cannot be None or empty')

        key = (method, cluster_identifier, db_endpoint, database, port)
        with self._lock:
            return self.map.get(key)

    def set(
        self,
        method: 'ConnectionMethod',
        cluster_identifier: str,
        db_endpoint: str,
        database: str,
        conn: 'AbstractDBConnection',
        port: int = 5432,
    ) -> None:
        """Set a database connection in the map, replacing any existing entry.

        Args:
            method: Connection method component of the key.
            cluster_identifier: Cluster identifier component of the key.
            db_endpoint: Endpoint component of the key.
            database: Database name component of the key.
            conn: Connection object to store.
            port: Port component of the key.

        Raises:
            ValueError: If ``database`` or ``conn`` is falsy.
        """
        if not database:
            raise ValueError('database cannot be None or empty')

        if not conn:
            raise ValueError('conn cannot be None')

        key = (method, cluster_identifier, db_endpoint, database, port)
        with self._lock:
            self.map[key] = conn

    def remove(
        self,
        method: 'ConnectionMethod',
        cluster_identifier: str,
        db_endpoint: str,
        database: str,
        port: int = 5432,
    ) -> None:
        """Remove a database connection from the map.

        The connection itself is not closed.  Removing a key that is not
        present is not an error; it is only logged.

        Raises:
            ValueError: If ``database`` is falsy.
        """
        if not database:
            raise ValueError('database cannot be None or empty')

        key = (method, cluster_identifier, db_endpoint, database, port)
        with self._lock:
            try:
                self.map.pop(key)
            except KeyError:
                logger.info(
                    f'Try to remove a non-existing connection. {method} {cluster_identifier} {db_endpoint} {database} {port}'
                )

    def get_keys_json(self) -> str:
        """Get all connection keys as JSON string.

        Returns:
            A JSON array (indented by two spaces) with one object per stored
            key; the connection objects themselves are not included.
        """
        with self._lock:
            entries: List[dict] = [
                {
                    'connection_method': key[0],
                    'cluster_identifier': key[1],
                    'db_endpoint': key[2],
                    'database': key[3],
                    'port': key[4],
                }
                for key in self.map
            ]
        return json.dumps(entries, indent=2)

    @staticmethod
    async def _await_close(key, pending) -> None:
        """Await an asynchronous ``close()`` result, logging any failure."""
        try:
            await pending
        except Exception as e:
            logger.warning(f'Failed to close connection {key}: {e}')

    def close_all(self) -> None:
        """Close all connections and clear the map.

        ``close()`` may be synchronous or may return an awaitable (the pooled
        connection's ``close`` is a coroutine function).  An awaitable result is
        run to completion with ``asyncio.run`` when no event loop is running in
        this thread; otherwise it is scheduled as a task on the running loop,
        because this synchronous method cannot block on it there.  Failures are
        logged and do not stop the remaining connections from being closed.
        """
        import asyncio
        import inspect

        # Detach the entries under the lock, then close without holding it so
        # that slow or failing close calls cannot block other map users.
        with self._lock:
            detached = list(self.map.items())
            self.map.clear()

        for key, conn in detached:
            try:
                outcome = conn.close()
                if not inspect.isawaitable(outcome):
                    continue

                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # No loop is running in this thread: drive the close here.
                    asyncio.run(self._await_close(key, outcome))
                else:
                    task = loop.create_task(self._await_close(key, outcome))
                    self._closing_tasks.add(task)
                    task.add_done_callback(self._closing_tasks.discard)
            except Exception as e:
                logger.warning(f'Failed to close connection {key}: {e}')
|
|
@@ -21,7 +21,11 @@ parameters (host, port, database, user, password) or via AWS Secrets Manager.
|
|
|
21
21
|
|
|
22
22
|
import boto3
|
|
23
23
|
import json
|
|
24
|
+
from aiorwlock import RWLock
|
|
25
|
+
from awslabs.postgres_mcp_server import __user_agent__
|
|
24
26
|
from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
|
|
27
|
+
from botocore.config import Config
|
|
28
|
+
from datetime import datetime, timedelta
|
|
25
29
|
from loguru import logger
|
|
26
30
|
from psycopg_pool import AsyncConnectionPool
|
|
27
31
|
from typing import Any, Dict, List, Optional, Tuple
|
|
@@ -45,7 +49,10 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
45
49
|
database: str,
|
|
46
50
|
readonly: bool,
|
|
47
51
|
secret_arn: str,
|
|
52
|
+
db_user: str,
|
|
48
53
|
region: str,
|
|
54
|
+
is_iam_auth: bool = False,
|
|
55
|
+
pool_expiry_min: int = 30,
|
|
49
56
|
min_size: int = 1,
|
|
50
57
|
max_size: int = 10,
|
|
51
58
|
is_test: bool = False,
|
|
@@ -58,7 +65,10 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
58
65
|
database: Database name
|
|
59
66
|
readonly: Whether connections should be read-only
|
|
60
67
|
secret_arn: ARN of the secret containing credentials
|
|
68
|
+
db_user: Database username
|
|
61
69
|
region: AWS region for Secrets Manager
|
|
70
|
+
is_iam_auth: Whether to use IAM authentication
|
|
71
|
+
pool_expiry_min: Pool expiry time in minutes
|
|
62
72
|
min_size: Minimum number of connections in the pool
|
|
63
73
|
max_size: Maximum number of connections in the pool
|
|
64
74
|
is_test: Whether this is a test connection
|
|
@@ -69,58 +79,83 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
69
79
|
self.database = database
|
|
70
80
|
self.min_size = min_size
|
|
71
81
|
self.max_size = max_size
|
|
82
|
+
self.region = region
|
|
83
|
+
self.is_iam_auth = is_iam_auth
|
|
84
|
+
self.user = db_user
|
|
85
|
+
self.pool_expiry_min = pool_expiry_min
|
|
86
|
+
self.secret_arn = secret_arn
|
|
87
|
+
self.is_test = is_test
|
|
72
88
|
self.pool: Optional['AsyncConnectionPool[Any]'] = None
|
|
89
|
+
self.rw_lock = RWLock()
|
|
90
|
+
self.created_time = datetime.now()
|
|
73
91
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
92
|
+
if is_iam_auth:
|
|
93
|
+
# IAM authentication needs an explicit db_user; the auth token itself is generated later, when the pool is initialized
|
|
94
|
+
if not db_user:
|
|
95
|
+
raise ValueError('db_user must be set when is_iam_auth is True')
|
|
78
96
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
logger.info('Connection parameters stored')
|
|
97
|
+
# set pool expiry before IAM auth token expiry of 15 minutes
|
|
98
|
+
self.pool_expiry_min = 14
|
|
99
|
+
logger.info(f'Use IAM auth for user: {db_user}')
|
|
83
100
|
|
|
84
101
|
async def initialize_pool(self):
    """Initialize the connection pool unless one already exists.

    The password is an IAM auth token when ``self.is_iam_auth`` is set;
    otherwise user name and password are read from the Secrets Manager secret
    ``self.secret_arn``.  ``self.pool`` is assigned only after the pool has
    opened successfully, so a failed attempt leaves it ``None`` and a later
    call can retry instead of finding an unusable pool.

    Raises:
        Exception: Whatever credential retrieval or opening the pool raises
            (the pool is given up to 30 seconds to fill).
    """
    # Cheap check first: most callers find the pool already built.
    async with self.rw_lock.reader_lock:
        if self.pool is not None:
            return

    async with self.rw_lock.writer_lock:
        # Re-check: another task may have built the pool while this one was
        # waiting for the writer lock.
        if self.pool is not None:
            return

        logger.info(
            f'initialize_pool:\n'
            f'endpoint:{self.host}\n'
            f'port:{self.port}\n'
            f'region:{self.region}\n'
            f'db:{self.database}\n'
            f'user:{self.user}\n'
            f'is_iam_auth:{self.is_iam_auth}\n'
        )

        if self.is_iam_auth:
            logger.info(f'Retrieving IAM auth token for {self.user}')
            password = self.get_iam_auth_token()
        else:
            logger.info(f'Retrieving credentials from Secrets Manager: {self.secret_arn}')
            self.user, password = self._get_credentials_from_secret(
                self.secret_arn, self.region, self.is_test
            )

        # Stamp the time the credentials were obtained (before opening), so the
        # expiry check measures age from credential retrieval.
        self.created_time = datetime.now()
        # NOTE(review): values are interpolated unquoted; a password containing
        # whitespace, quotes or backslashes would not parse as a conninfo
        # string — consider passing credentials through the pool's kwargs.
        self.conninfo = f'host={self.host} port={self.port} dbname={self.database} user={self.user} password={password}'
        pool = AsyncConnectionPool(
            self.conninfo, min_size=self.min_size, max_size=self.max_size, open=False
        )

        try:
            # wait up to 30 seconds to fill the pool with connections
            await pool.open(True, 30)
        except Exception:
            # Release whatever the half-opened pool started, then report the
            # original failure; self.pool stays None so a retry is possible.
            try:
                await pool.close()
            except Exception as close_error:
                logger.warning(
                    f'Failed to close connection pool after open failure: {close_error}'
                )
            raise

        self.pool = pool
        logger.info('Connection pool initialized successfully')
|
|
98
139
|
|
|
99
140
|
async def _get_connection(self):
|
|
100
141
|
"""Get a database connection from the pool."""
|
|
101
|
-
|
|
102
|
-
await self.initialize_pool()
|
|
103
|
-
|
|
104
|
-
if self.pool is None:
|
|
105
|
-
raise ValueError('Failed to initialize connection pool')
|
|
142
|
+
await self.check_expiry()
|
|
106
143
|
|
|
107
|
-
|
|
144
|
+
async with self.rw_lock.reader_lock:
|
|
145
|
+
if self.pool is None:
|
|
146
|
+
raise ValueError('Failed to initialize connection pool')
|
|
147
|
+
return self.pool.connection(timeout=15.0)
|
|
108
148
|
|
|
109
|
-
async def
|
|
110
|
-
"""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
149
|
+
async def check_expiry(self):
    """Rebuild the pool when it is missing or older than ``pool_expiry_min``.

    If a pool exists and less than ``self.pool_expiry_min`` minutes have passed
    since ``self.created_time``, nothing happens.  Otherwise the pool is closed
    and initialized again.
    """
    async with self.rw_lock.reader_lock:
        if self.pool:
            age = datetime.now() - self.created_time
            if age < timedelta(minutes=self.pool_expiry_min):
                return

    # NOTE(review): the lock is released between the check above and the
    # rebuild below, so concurrent callers may each perform a rebuild.
    await self.close()
    await self.initialize_pool()
|
|
124
159
|
|
|
125
160
|
async def execute_query(
|
|
126
161
|
self, sql: str, parameters: Optional[List[Dict[str, Any]]] = None
|
|
@@ -130,7 +165,8 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
130
165
|
async with await self._get_connection() as conn:
|
|
131
166
|
async with conn.transaction():
|
|
132
167
|
if self.readonly_query:
|
|
133
|
-
|
|
168
|
+
logger.info('SET TRANSACTION READ ONLY')
|
|
169
|
+
await conn.execute('SET TRANSACTION READ ONLY')
|
|
134
170
|
|
|
135
171
|
# Create a cursor for better control
|
|
136
172
|
async with conn.cursor() as cursor:
|
|
@@ -161,12 +197,12 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
161
197
|
record.append({'isNull': True})
|
|
162
198
|
elif isinstance(value, str):
|
|
163
199
|
record.append({'stringValue': value})
|
|
200
|
+
elif isinstance(value, bool):
|
|
201
|
+
record.append({'booleanValue': value})
|
|
164
202
|
elif isinstance(value, int):
|
|
165
203
|
record.append({'longValue': value})
|
|
166
204
|
elif isinstance(value, float):
|
|
167
205
|
record.append({'doubleValue': value})
|
|
168
|
-
elif isinstance(value, bool):
|
|
169
|
-
record.append({'booleanValue': value})
|
|
170
206
|
elif isinstance(value, bytes):
|
|
171
207
|
record.append({'blobValue': value})
|
|
172
208
|
else:
|
|
@@ -258,11 +294,12 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
258
294
|
|
|
259
295
|
async def close(self) -> None:
    """Close the pool, if there is one, and forget it.

    Runs under the writer lock.  When ``self.pool`` is already ``None`` the
    call does nothing; otherwise the pool is closed and ``self.pool`` is reset
    to ``None``.
    """
    async with self.rw_lock.writer_lock:
        if self.pool is None:
            return

        logger.info('Closing connection pool')
        await self.pool.close()
        self.pool = None
        logger.info('Connection pool closed successfully')
|
|
266
303
|
|
|
267
304
|
async def check_connection_health(self) -> bool:
|
|
268
305
|
"""Check if the connection is healthy."""
|
|
@@ -273,15 +310,25 @@ class PsycopgPoolConnection(AbstractDBConnection):
|
|
|
273
310
|
logger.error(f'Connection health check failed: {str(e)}')
|
|
274
311
|
return False
|
|
275
312
|
|
|
276
|
-
def get_pool_stats(self) -> Dict[str, int]:
|
|
313
|
+
async def get_pool_stats(self) -> Dict[str, int]:
    """Get current connection pool statistics.

    Returns:
        A dict with the keys ``size``, ``min_size``, ``max_size`` and ``idle``.
        Without a pool, ``size`` and ``idle`` are 0 and the limits come from
        this object's own configuration.
    """
    async with self.rw_lock.reader_lock:
        pool = getattr(self, 'pool', None)
        if pool is None:
            return {'size': 0, 'min_size': self.min_size, 'max_size': self.max_size, 'idle': 0}

        absent = object()
        size = getattr(pool, 'size', absent)
        idle = getattr(pool, 'idle', absent)

        if size is absent or idle is absent:
            # psycopg_pool reports live counters through get_stats()
            # ('pool_size', 'pool_available') rather than as attributes, so
            # fall back to it instead of silently reporting 0.
            read_stats = getattr(pool, 'get_stats', None)
            live = read_stats() if callable(read_stats) else {}
            if not isinstance(live, dict):
                live = {}
            if size is absent:
                size = live.get('pool_size', 0)
            if idle is absent:
                idle = live.get('pool_available', 0)

        min_size = getattr(pool, 'min_size', self.min_size)
        max_size = getattr(pool, 'max_size', self.max_size)

        return {'size': size, 'min_size': min_size, 'max_size': max_size, 'idle': idle}
|
|
326
|
+
|
|
327
|
+
def get_iam_auth_token(self) -> str:
    """Return an IAM authentication token for this connection's database user.

    A boto3 ``rds`` client is created for ``self.region`` (tagged with the
    package user agent) and asked for a token bound to ``self.host``,
    ``self.port`` and ``self.user``.
    """
    client_config = Config(user_agent_extra=__user_agent__)
    rds_client = boto3.client('rds', region_name=self.region, config=client_config)
    token = rds_client.generate_db_auth_token(
        DBHostname=self.host,
        Port=self.port,
        DBUsername=self.user,
        Region=self.region,
    )
    return token
|
|
@@ -16,7 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
18
|
import boto3
|
|
19
|
+
from awslabs.postgres_mcp_server import __user_agent__
|
|
19
20
|
from awslabs.postgres_mcp_server.connection.abstract_db_connection import AbstractDBConnection
|
|
21
|
+
from botocore.config import Config
|
|
20
22
|
from loguru import logger
|
|
21
23
|
from typing import Any, Dict, List, Optional
|
|
22
24
|
|
|
@@ -48,7 +50,9 @@ class RDSDataAPIConnection(AbstractDBConnection):
|
|
|
48
50
|
self.secret_arn = secret_arn
|
|
49
51
|
self.database = database
|
|
50
52
|
if not is_test:
|
|
51
|
-
self.data_client = boto3.client(
|
|
53
|
+
self.data_client = boto3.client(
|
|
54
|
+
'rds-data', region_name=region, config=Config(user_agent_extra=__user_agent__)
|
|
55
|
+
)
|
|
52
56
|
|
|
53
57
|
async def execute_query(
|
|
54
58
|
self, sql: str, parameters: Optional[List[Dict[str, Any]]] = None
|