awslabs.redshift-mcp-server 0.0.7__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/PKG-INFO +1 -2
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/README.md +0 -1
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/redshift_mcp_server/__init__.py +1 -1
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/redshift_mcp_server/consts.py +5 -4
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/redshift_mcp_server/redshift.py +237 -91
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/redshift_mcp_server/server.py +16 -5
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/pyproject.toml +1 -1
- awslabs_redshift_mcp_server-0.0.8/tests/test_redshift.py +1090 -0
- awslabs_redshift_mcp_server-0.0.8/tests/test_server.py +530 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/uv.lock +2 -2
- awslabs_redshift_mcp_server-0.0.7/tests/test_redshift.py +0 -252
- awslabs_redshift_mcp_server-0.0.7/tests/test_server.py +0 -969
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/.gitignore +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/.python-version +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/CHANGELOG.md +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/Dockerfile +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/LICENSE +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/NOTICE +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/__init__.py +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/awslabs/redshift_mcp_server/models.py +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/docker-healthcheck.sh +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/tests/test_init.py +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/tests/test_main.py +0 -0
- {awslabs_redshift_mcp_server-0.0.7 → awslabs_redshift_mcp_server-0.0.8}/uv-requirements.txt +0 -0
--- awslabs_redshift_mcp_server-0.0.7/PKG-INFO
+++ awslabs_redshift_mcp_server-0.0.8/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: awslabs.redshift-mcp-server
-Version: 0.0.7
+Version: 0.0.8
 Summary: An AWS Labs Model Context Protocol (MCP) server for Redshift
 Project-URL: homepage, https://awslabs.github.io/mcp/
 Project-URL: docs, https://awslabs.github.io/mcp/servers/redshift-mcp-server/
@@ -449,7 +449,6 @@ Your AWS credentials need the following IAM permissions:
         "redshift-serverless:ListWorkgroups",
         "redshift-serverless:GetWorkgroup",
         "redshift-data:ExecuteStatement",
-        "redshift-data:BatchExecuteStatement",
         "redshift-data:DescribeStatement",
         "redshift-data:GetStatementResult"
       ],
--- awslabs_redshift_mcp_server-0.0.7/README.md
+++ awslabs_redshift_mcp_server-0.0.8/README.md
@@ -418,7 +418,6 @@ Your AWS credentials need the following IAM permissions:
         "redshift-serverless:ListWorkgroups",
         "redshift-serverless:GetWorkgroup",
         "redshift-data:ExecuteStatement",
-        "redshift-data:BatchExecuteStatement",
         "redshift-data:DescribeStatement",
         "redshift-data:GetStatementResult"
       ],
--- awslabs_redshift_mcp_server-0.0.7/awslabs/redshift_mcp_server/consts.py
+++ awslabs_redshift_mcp_server-0.0.8/awslabs/redshift_mcp_server/consts.py
@@ -21,7 +21,8 @@ CLIENT_RETRIES = {'max_attempts': 5, 'mode': 'adaptive'}
 CLIENT_USER_AGENT_NAME = 'awslabs/mcp/redshift-mcp-server'
 DEFAULT_LOG_LEVEL = 'WARNING'
 QUERY_TIMEOUT = 3600
-QUERY_POLL_INTERVAL =
+QUERY_POLL_INTERVAL = 1
+SESSION_KEEPALIVE = 600
 
 # Best practices
 
@@ -85,7 +86,7 @@ SELECT
     source_database,
     schema_option
 FROM pg_catalog.svv_all_schemas
-WHERE database_name =
+WHERE database_name = :database_name
 ORDER BY schema_name;
 """
 
@@ -98,7 +99,7 @@ SELECT
     table_type,
     remarks
 FROM pg_catalog.svv_all_tables
-WHERE database_name =
+WHERE database_name = :database_name AND schema_name = :schema_name
 ORDER BY table_name;
 """
 
@@ -117,7 +118,7 @@ SELECT
     numeric_scale,
     remarks
 FROM pg_catalog.svv_all_columns
-WHERE database_name =
+WHERE database_name = :database_name AND schema_name = :schema_name AND table_name = :table_name
 ORDER BY ordinal_position;
 """
 
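The catalog queries in `consts.py` now use Data API named parameters (`:database_name`, `:schema_name`, `:table_name`) rather than interpolating values into the SQL text. As a rough, self-contained sketch of how such a parameterized statement is submitted through boto3's `redshift-data` client (the workgroup and database names below are placeholders, not values from the package):

```python
import boto3

data_client = boto3.client('redshift-data')

# Placeholder identifiers for illustration only.
response = data_client.execute_statement(
    WorkgroupName='example-workgroup',  # or ClusterIdentifier=... for a provisioned cluster
    Database='dev',
    Sql='SELECT schema_name FROM pg_catalog.svv_all_schemas WHERE database_name = :database_name;',
    Parameters=[{'name': 'database_name', 'value': 'dev'}],
)
statement_id = response['Id']  # poll describe_statement(Id=statement_id) until FINISHED
```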
--- awslabs_redshift_mcp_server-0.0.7/awslabs/redshift_mcp_server/redshift.py
+++ awslabs_redshift_mcp_server-0.0.8/awslabs/redshift_mcp_server/redshift.py
@@ -18,6 +18,7 @@ import asyncio
 import boto3
 import os
 import regex
+import time
 from awslabs.redshift_mcp_server import __version__
 from awslabs.redshift_mcp_server.consts import (
     CLIENT_CONNECT_TIMEOUT,
@@ -26,6 +27,7 @@ from awslabs.redshift_mcp_server.consts import (
     CLIENT_USER_AGENT_NAME,
     QUERY_POLL_INTERVAL,
     QUERY_TIMEOUT,
+    SESSION_KEEPALIVE,
     SUSPICIOUS_QUERY_REGEXP,
     SVV_ALL_COLUMNS_QUERY,
     SVV_ALL_SCHEMAS_QUERY,
@@ -101,61 +103,124 @@ class RedshiftClientManager:
         return self._redshift_data_client
 
 
-
-    """
+class RedshiftSessionManager:
+    """Manages Redshift Data API sessions for connection reuse."""
+
+    def __init__(self, session_keepalive: int, app_name: str):
+        """Initialize the session manager.
+
+        Args:
+            session_keepalive: Session keepalive timeout in seconds.
+            app_name: Application name to set in sessions.
+        """
+        self._sessions = {}  # {cluster:database -> session_info}
+        self._session_keepalive = session_keepalive
+        self._app_name = app_name
+
+    async def session(
+        self, cluster_identifier: str, database_name: str, cluster_info: dict
+    ) -> str:
+        """Get or create a session for the given cluster and database.
+
+        Args:
+            cluster_identifier: The cluster identifier to get session for.
+            database_name: The database name to get session for.
+            cluster_info: Cluster information dictionary from discover_clusters.
+
+        Returns:
+            Session ID for use in ExecuteStatement calls.
+        """
+        # Check existing session
+        session_key = f'{cluster_identifier}:{database_name}'
+        if session_key in self._sessions:
+            session_info = self._sessions[session_key]
+            if not self._is_session_expired(session_info):
+                logger.debug(f'Reusing existing session: {session_info["session_id"]}')
+                return session_info['session_id']
+            else:
+                logger.debug(f'Session expired, removing: {session_info["session_id"]}')
+                del self._sessions[session_key]
+
+        # Create new session with application name
+        session_id = await self._create_session_with_app_name(
+            cluster_identifier, database_name, cluster_info
+        )
 
-
-
-    """
-    if value is None:
-        return 'NULL'
+        # Store session
+        self._sessions[session_key] = {'session_id': session_id, 'created_at': time.time()}
 
-
-
-    return "'" + repr('"' + value)[2:]
+        logger.info(f'Created new session: {session_id} for {cluster_identifier}:{database_name}')
+        return session_id
 
+    async def _create_session_with_app_name(
+        self, cluster_identifier: str, database_name: str, cluster_info: dict
+    ) -> str:
+        """Create a new session by executing SET application_name.
 
-
-
+        Args:
+            cluster_identifier: The cluster identifier.
+            database_name: The database name.
+            cluster_info: Cluster information dictionary.
 
-
-
-
+        Returns:
+            Session ID from the ExecuteStatement response.
+        """
+        # Set application name to create session
+        app_name_sql = f"SET application_name TO '{self._app_name}';"
 
-
-
-
-
+        # Execute statement to create session
+        statement_id = await _execute_statement(
+            cluster_info=cluster_info,
+            cluster_identifier=cluster_identifier,
+            database_name=database_name,
+            sql=app_name_sql,
+            session_keepalive=self._session_keepalive,
+        )
 
-
-
-
+        # Get session ID from the response
+        data_client = client_manager.redshift_data_client()
+        status_response = data_client.describe_statement(Id=statement_id)
+        session_id = status_response['SessionId']
 
-
-
-    """
-    if allow_read_write:
-        return ['BEGIN READ WRITE;', sql, 'END;']
-    else:
-        # Check if SQL contains suspicious patterns trying to break the transaction context
-        if regex.compile(SUSPICIOUS_QUERY_REGEXP).search(sql):
-            logger.error(f'SQL contains suspicious pattern, execution rejected: {sql}')
-            raise Exception(f'SQL contains suspicious pattern, execution rejected: {sql}')
+        logger.debug(f'Created session with application name: {session_id}')
+        return session_id
 
-
+    def _is_session_expired(self, session_info: dict) -> bool:
+        """Check if a session has expired based on keepalive timeout.
 
+        Args:
+            session_info: Session information dictionary.
 
-
-
+        Returns:
+            True if session is expired, False otherwise.
+        """
+        return (time.time() - session_info['created_at']) > self._session_keepalive
+
+
+async def _execute_protected_statement(
+    cluster_identifier: str,
+    database_name: str,
+    sql: str,
+    parameters: list[dict] | None = None,
+    allow_read_write: bool = False,
 ) -> tuple[dict, str]:
-    """Execute a SQL statement against a Redshift cluster
+    """Execute a SQL statement against a Redshift cluster in a protected fashion.
+
+    The SQL is protected by wrapping it in a transaction block with READ ONLY or READ WRITE mode
+    based on allow_read_write flag. Transaction breaker protection is implemented
+    to prevent unauthorized modifications.
 
-
+    The SQL execution takes the form:
+    1. Get or create session (with SET application_name)
+    2. BEGIN [READ ONLY|READ WRITE];
+    3. <user sql>
+    4. END;
 
     Args:
         cluster_identifier: The cluster identifier to query.
         database_name: The database to execute the query against.
         sql: The SQL statement to execute.
+        parameters: Optional list of parameter dictionaries with 'name' and 'value' keys.
         allow_read_write: Indicates if read-write mode should be activated.
 
     Returns:
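`RedshiftSessionManager` keeps one Data API session per `cluster:database` pair and recreates it once `SESSION_KEEPALIVE` seconds have elapsed. A minimal sketch of the underlying Data API flow it wraps, assuming a hypothetical serverless workgroup (in the server, each statement is polled to completion via `describe_statement` before the session ID is read):

```python
import boto3

data_client = boto3.client('redshift-data')

# The first statement sets the application name and opens a reusable session.
first = data_client.execute_statement(
    WorkgroupName='example-workgroup',
    Database='dev',
    Sql="SET application_name TO 'awslabs/mcp/redshift-mcp-server/0.0.8';",
    SessionKeepAliveSeconds=600,
)

# Once that statement has finished, DescribeStatement exposes the session ID.
session_id = data_client.describe_statement(Id=first['Id'])['SessionId']

# Follow-up statements reuse the session instead of naming the workgroup/database.
data_client.execute_statement(SessionId=session_id, Sql='SELECT 1;')
```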
@@ -166,9 +231,7 @@ async def execute_statement(
     Raises:
         Exception: If cluster not found, query fails, or times out.
     """
-
-
-    # First, check if this is a provisioned cluster or serverless workgroup
+    # Get cluster info
     clusters = await discover_clusters()
     cluster_info = None
     for cluster in clusters:
@@ -181,57 +244,131 @@ async def execute_statement(
             f'Cluster {cluster_identifier} not found. Please use list_clusters to get valid cluster identifiers.'
         )
 
-    #
-
-
-
+    # Get session (creates if needed, sets app name automatically)
+    session_id = await session_manager.session(cluster_identifier, database_name, cluster_info)
+
+    # Check for suspicious patterns in read-only mode
+    if not allow_read_write:
+        if regex.compile(SUSPICIOUS_QUERY_REGEXP).search(sql):
+            logger.error(f'SQL contains suspicious pattern, execution rejected: {sql}')
+            raise Exception(f'SQL contains suspicious pattern, execution rejected: {sql}')
 
-
+    # Execute BEGIN statement
+    begin_sql = 'BEGIN READ WRITE;' if allow_read_write else 'BEGIN READ ONLY;'
+    await _execute_statement(
+        cluster_info=cluster_info,
+        cluster_identifier=cluster_identifier,
+        database_name=database_name,
+        sql=begin_sql,
+        session_id=session_id,
+    )
+
+    # Execute user SQL with parameters
+    user_query_id = await _execute_statement(
+        cluster_info=cluster_info,
+        cluster_identifier=cluster_identifier,
+        database_name=database_name,
+        sql=sql,
+        parameters=parameters,
+        session_id=session_id,
+    )
+
+    # Execute END statement to close transaction
+    await _execute_statement(
+        cluster_info=cluster_info,
+        cluster_identifier=cluster_identifier,
+        database_name=database_name,
+        sql='END;',
+        session_id=session_id,
+    )
+
+    # Get results from user query
+    data_client = client_manager.redshift_data_client()
+    results_response = data_client.get_statement_result(Id=user_query_id)
+    return results_response, user_query_id
 
-    # Execute the query using Data API
-    if cluster_info['type'] == 'provisioned':
-        logger.debug(f'Using ClusterIdentifier for provisioned cluster: {cluster_identifier}')
-        response = data_client.batch_execute_statement(
-            ClusterIdentifier=cluster_identifier, Database=database_name, Sqls=sqls
-        )
-    elif cluster_info['type'] == 'serverless':
-        logger.debug(f'Using WorkgroupName for serverless workgroup: {cluster_identifier}')
-        response = data_client.batch_execute_statement(
-            WorkgroupName=cluster_identifier, Database=database_name, Sqls=sqls
-        )
-    else:
-        raise Exception(f'Unknown cluster type: {cluster_info["type"]}')
 
-
-
+async def _execute_statement(
+    cluster_info: dict,
+    cluster_identifier: str,
+    database_name: str,
+    sql: str,
+    parameters: list[dict] | None = None,
+    session_id: str | None = None,
+    session_keepalive: int | None = None,
+    query_poll_interval: float = QUERY_POLL_INTERVAL,
+    query_timeout: float = QUERY_TIMEOUT,
+) -> str:
+    """Execute a single statement with optional session support and parameters.
+
+    Args:
+        cluster_info: Cluster information dictionary.
+        cluster_identifier: The cluster identifier.
+        database_name: The database name.
+        sql: The SQL statement to execute.
+        parameters: Optional list of parameter dictionaries with 'name' and 'value' keys.
+        session_id: Optional session ID to use.
+        session_keepalive: Optional session keepalive seconds (only used when session_id is None).
+        query_poll_interval: Polling interval in seconds for checking query status.
+        query_timeout: Maximum time in seconds to wait for query completion.
+
+    Returns:
+        Statement ID from the ExecuteStatement response.
+    """
+    data_client = client_manager.redshift_data_client()
 
-    #
+    # Build request parameters
+    request_params: dict[str, str | int | list[dict]] = {'Sql': sql}
+
+    # Add database and cluster/workgroup identifier only if not using session
+    if not session_id:
+        request_params['Database'] = database_name
+        if cluster_info['type'] == 'provisioned':
+            request_params['ClusterIdentifier'] = cluster_identifier
+        elif cluster_info['type'] == 'serverless':
+            request_params['WorkgroupName'] = cluster_identifier
+        else:
+            raise Exception(f'Unknown cluster type: {cluster_info["type"]}')
+
+    # Add parameters if provided
+    if parameters:
+        request_params['Parameters'] = parameters
+
+    # Add session ID if provided, otherwise add session keepalive
+    if session_id:
+        request_params['SessionId'] = session_id
+    elif session_keepalive is not None:
+        request_params['SessionKeepAliveSeconds'] = session_keepalive
+
+    response = data_client.execute_statement(**request_params)
+    statement_id = response['Id']
+
+    logger.debug(
+        f'Executed statement: {statement_id}' + (f' in session {session_id}' if session_id else '')
+    )
+
+    # Wait for statement completion
     wait_time = 0
-
-
-        status_response = data_client.describe_statement(Id=query_id)
+    while wait_time < query_timeout:
+        status_response = data_client.describe_statement(Id=statement_id)
         status = status_response['Status']
 
         if status == 'FINISHED':
-            logger.debug(f'
+            logger.debug(f'Statement completed: {statement_id}')
             break
         elif status in ['FAILED', 'ABORTED']:
             error_msg = status_response.get('Error', 'Unknown error')
-            logger.error(f'
-            raise Exception(f'
+            logger.error(f'Statement failed: {error_msg}')
+            raise Exception(f'Statement failed: {error_msg}')
 
-
-
-        wait_time += QUERY_POLL_INTERVAL
+        await asyncio.sleep(query_poll_interval)
+        wait_time += query_poll_interval
 
-    if wait_time >=
-        logger.error(f'
-        raise Exception(f'
+    if wait_time >= query_timeout:
+        logger.error(f'Statement timed out: {statement_id}')
+        raise Exception(f'Statement timed out after {wait_time} seconds')
 
-
-    subquery2_id = status_response['SubStatements'][2]['Id']
-    results_response = data_client.get_statement_result(Id=subquery2_id)
-    return results_response, subquery2_id
+    return statement_id
 
 
 async def discover_clusters() -> list[dict]:
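The rewritten execution path (`_execute_protected_statement` plus the `_execute_statement` helper) replaces the old `batch_execute_statement` call with three separate statements on one session, so the user SQL always runs inside an explicit transaction. A sketch of the resulting call sequence, reusing the hypothetical `data_client` and `session_id` from the previous example; in the server each call is awaited to completion before the next is issued:

```python
# BEGIN the transaction (READ ONLY unless allow_read_write was requested).
data_client.execute_statement(SessionId=session_id, Sql='BEGIN READ ONLY;')

# The user statement, with Data API named parameters.
user = data_client.execute_statement(
    SessionId=session_id,
    Sql=(
        'SELECT table_name FROM pg_catalog.svv_all_tables '
        'WHERE database_name = :database_name AND schema_name = :schema_name;'
    ),
    Parameters=[
        {'name': 'database_name', 'value': 'dev'},
        {'name': 'schema_name', 'value': 'public'},
    ],
)

# Close the transaction.
data_client.execute_statement(SessionId=session_id, Sql='END;')

# Results are fetched for the user statement only, never for BEGIN/END.
rows = data_client.get_statement_result(Id=user['Id'])['Records']
```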
@@ -334,7 +471,7 @@ async def discover_databases(cluster_identifier: str, database_name: str = 'dev'
     logger.info(f'Discovering databases in cluster {cluster_identifier}')
 
     # Execute the query using the common function
-    results_response, _ = await
+    results_response, _ = await _execute_protected_statement(
         cluster_identifier=cluster_identifier,
         database_name=database_name,
         sql=SVV_REDSHIFT_DATABASES_QUERY,
@@ -379,10 +516,11 @@ async def discover_schemas(cluster_identifier: str, schema_database_name: str) -
     )
 
     # Execute the query using the common function
-    results_response, _ = await
+    results_response, _ = await _execute_protected_statement(
         cluster_identifier=cluster_identifier,
         database_name=schema_database_name,
-        sql=SVV_ALL_SCHEMAS_QUERY
+        sql=SVV_ALL_SCHEMAS_QUERY,
+        parameters=[{'name': 'database_name', 'value': schema_database_name}],
     )
 
     schemas = []
@@ -432,12 +570,14 @@ async def discover_tables(
     )
 
     # Execute the query using the common function
-    results_response, _ = await
+    results_response, _ = await _execute_protected_statement(
         cluster_identifier=cluster_identifier,
         database_name=table_database_name,
-        sql=SVV_ALL_TABLES_QUERY
-
-
+        sql=SVV_ALL_TABLES_QUERY,
+        parameters=[
+            {'name': 'database_name', 'value': table_database_name},
+            {'name': 'schema_name', 'value': table_schema_name},
+        ],
     )
 
     tables = []
@@ -490,14 +630,15 @@ async def discover_columns(
     )
 
     # Execute the query using the common function
-    results_response, _ = await
+    results_response, _ = await _execute_protected_statement(
         cluster_identifier=cluster_identifier,
         database_name=column_database_name,
-        sql=SVV_ALL_COLUMNS_QUERY
-
-
-
-
+        sql=SVV_ALL_COLUMNS_QUERY,
+        parameters=[
+            {'name': 'database_name', 'value': column_database_name},
+            {'name': 'schema_name', 'value': column_schema_name},
+            {'name': 'table_name', 'value': column_table_name},
+        ],
     )
 
     columns = []
@@ -554,7 +695,7 @@ async def execute_query(cluster_identifier: str, database_name: str, sql: str) -
     start_time = time.time()
 
     # Execute the query using the common function
-    results_response, query_id = await
+    results_response, query_id = await _execute_protected_statement(
         cluster_identifier=cluster_identifier, database_name=database_name, sql=sql
     )
 
@@ -620,3 +761,8 @@ client_manager = RedshiftClientManager(
     aws_region=os.environ.get('AWS_REGION'),
     aws_profile=os.environ.get('AWS_PROFILE'),
 )
+
+# Global session manager instance
+session_manager = RedshiftSessionManager(
+    session_keepalive=SESSION_KEEPALIVE, app_name=f'{CLIENT_USER_AGENT_NAME}/{__version__}'
+)
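With the module-level `client_manager` and `session_manager` wired up, callers of the discovery and query helpers are unchanged. A hedged usage sketch with made-up identifiers (the exact shape of the returned dictionary is defined by `execute_query` and feeds the `QueryResult` model):

```python
import asyncio

from awslabs.redshift_mcp_server.redshift import execute_query


async def main():
    # Made-up workgroup/database; the SQL runs read-only on a reused Data API session.
    result = await execute_query(
        cluster_identifier='example-workgroup',
        database_name='dev',
        sql='SELECT current_user;',
    )
    print(result)


asyncio.run(main())
```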
--- awslabs_redshift_mcp_server-0.0.7/awslabs/redshift_mcp_server/server.py
+++ awslabs_redshift_mcp_server-0.0.8/awslabs/redshift_mcp_server/server.py
@@ -219,7 +219,9 @@ async def list_databases_tool(
     """
     try:
         logger.info(f'Discovering databases on cluster: {cluster_identifier}')
-        databases_data = await discover_databases(
+        databases_data = await discover_databases(
+            cluster_identifier=cluster_identifier, database_name=database_name
+        )
 
         # Convert to RedshiftDatabase models
         databases = []
@@ -302,7 +304,9 @@ async def list_schemas_tool(
         logger.info(
             f'Discovering schemas in database {schema_database_name} on cluster {cluster_identifier}'
         )
-        schemas_data = await discover_schemas(
+        schemas_data = await discover_schemas(
+            cluster_identifier=cluster_identifier, schema_database_name=schema_database_name
+        )
 
         # Convert to RedshiftSchema models
         schemas = []
@@ -394,7 +398,9 @@ async def list_tables_tool(
             f'Discovering tables in schema {table_schema_name} in database {table_database_name} on cluster {cluster_identifier}'
         )
         tables_data = await discover_tables(
-            cluster_identifier,
+            cluster_identifier=cluster_identifier,
+            table_database_name=table_database_name,
+            table_schema_name=table_schema_name,
         )
 
         # Convert to RedshiftTable models
@@ -500,7 +506,10 @@ async def list_columns_tool(
             f'Discovering columns in table {column_table_name} in schema {column_schema_name} in database {column_database_name} on cluster {cluster_identifier}'
         )
         columns_data = await discover_columns(
-            cluster_identifier,
+            cluster_identifier=cluster_identifier,
+            column_database_name=column_database_name,
+            column_schema_name=column_schema_name,
+            column_table_name=column_table_name,
         )
 
         # Convert to RedshiftColumn models
@@ -594,7 +603,9 @@ async def execute_query_tool(
     """
     try:
         logger.info(f'Executing query on cluster {cluster_identifier} in database {database_name}')
-        query_result_data = await execute_query(
+        query_result_data = await execute_query(
+            cluster_identifier=cluster_identifier, database_name=database_name, sql=sql
+        )
 
         # Convert to QueryResult model
         query_result = QueryResult(**query_result_data)