awslabs.redshift-mcp-server 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/__init__.py +16 -0
- awslabs/redshift_mcp_server/__init__.py +17 -0
- awslabs/redshift_mcp_server/consts.py +136 -0
- awslabs/redshift_mcp_server/models.py +141 -0
- awslabs/redshift_mcp_server/redshift.py +630 -0
- awslabs/redshift_mcp_server/server.py +621 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/METADATA +432 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/RECORD +12 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/WHEEL +4 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/entry_points.txt +2 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/licenses/LICENSE +175 -0
- awslabs_redshift_mcp_server-0.0.1.dist-info/licenses/NOTICE +2 -0
|
@@ -0,0 +1,630 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""AWS client management for Redshift MCP Server."""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import boto3
|
|
19
|
+
import os
|
|
20
|
+
import regex
|
|
21
|
+
from awslabs.redshift_mcp_server import __version__
|
|
22
|
+
from awslabs.redshift_mcp_server.consts import (
|
|
23
|
+
CLIENT_TIMEOUT,
|
|
24
|
+
DEFAULT_AWS_REGION,
|
|
25
|
+
QUERY_POLL_INTERVAL,
|
|
26
|
+
QUERY_TIMEOUT,
|
|
27
|
+
SUSPICIOUS_QUERY_REGEXP,
|
|
28
|
+
SVV_ALL_COLUMNS_QUERY,
|
|
29
|
+
SVV_ALL_SCHEMAS_QUERY,
|
|
30
|
+
SVV_ALL_TABLES_QUERY,
|
|
31
|
+
SVV_REDSHIFT_DATABASES_QUERY,
|
|
32
|
+
)
|
|
33
|
+
from botocore.config import Config
|
|
34
|
+
from loguru import logger
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RedshiftClientManager:
    """Manages AWS clients for Redshift operations.

    Each service client (Redshift, Redshift Serverless, Redshift Data API) is
    created lazily on first access and cached on the instance. If creation
    fails, the error is logged and re-raised, and the next access retries.
    """

    def __init__(self, config: Config, aws_region: str, aws_profile: str | None = None):
        """Initialize the client manager.

        Args:
            config: botocore configuration applied to every client created.
            aws_region: Region for the clients. When a named profile is used and
                that profile resolves its own region, the profile's region wins.
            aws_profile: Optional named AWS profile; when falsy, the default
                credential chain is used.
        """
        self.aws_region = aws_region
        self.aws_profile = aws_profile
        self._redshift_client = None
        self._redshift_serverless_client = None
        self._redshift_data_client = None
        self._config = config

    def _create_client(self, service_name: str, label: str):
        """Create a boto3 client for ``service_name``; ``label`` names it in log messages."""
        try:
            if self.aws_profile:
                session = boto3.Session(profile_name=self.aws_profile)
                # Honor the profile's region if it defines one; otherwise fall
                # back to self.aws_region so a region-less profile does not fail
                # with NoRegionError.
                client = session.client(
                    service_name,
                    config=self._config,
                    region_name=session.region_name or self.aws_region,
                )
                logger.info(f'Created {label} client with profile: {self.aws_profile}')
            else:
                client = boto3.client(
                    service_name, config=self._config, region_name=self.aws_region
                )
                logger.info(f'Created {label} client with default credentials')
        except Exception as e:
            logger.error(f'Error creating {label} client: {str(e)}')
            raise
        return client

    def redshift_client(self):
        """Get or create the Redshift client for provisioned clusters."""
        if self._redshift_client is None:
            self._redshift_client = self._create_client('redshift', 'Redshift')
        return self._redshift_client

    def redshift_serverless_client(self):
        """Get or create the Redshift Serverless client."""
        if self._redshift_serverless_client is None:
            self._redshift_serverless_client = self._create_client(
                'redshift-serverless', 'Redshift Serverless'
            )
        return self._redshift_serverless_client

    def redshift_data_client(self):
        """Get or create the Redshift Data API client."""
        if self._redshift_data_client is None:
            self._redshift_data_client = self._create_client(
                'redshift-data', 'Redshift Data API'
            )
        return self._redshift_data_client
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def quote_literal_string(value: str | None) -> str:
|
|
116
|
+
"""Quote a string value as a SQL literal.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
value: The string value to quote.
|
|
120
|
+
"""
|
|
121
|
+
if value is None:
|
|
122
|
+
return 'NULL'
|
|
123
|
+
|
|
124
|
+
# TODO Reimplement a proper way.
|
|
125
|
+
# A lazy hack for SQL literal quoting.
|
|
126
|
+
return "'" + repr('"' + value)[2:]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def protect_sql(sql: str, allow_read_write: bool) -> list[str]:
    """Wrap SQL in a transaction block whose access mode matches the permission.

    The statement becomes a three-element batch:
        BEGIN [READ ONLY|READ WRITE];
        <sql>
        END;

    In read-only mode the SQL is first screened with SUSPICIOUS_QUERY_REGEXP
    and rejected if it matches, to guard against attempts to break out of the
    surrounding READ ONLY transaction.

    Args:
        sql: The SQL statement to protect.
        allow_read_write: Indicates if read-write mode should be activated.

    Returns:
        List of strings to execute by batch_execute_statement.

    Raises:
        Exception: In read-only mode, when the SQL matches the suspicious pattern.
    """
    if not allow_read_write:
        pattern = regex.compile(SUSPICIOUS_QUERY_REGEXP)
        if pattern.search(sql):
            message = f'SQL contains suspicious pattern, execution rejected: {sql}'
            logger.error(message)
            raise Exception(message)

    mode = 'READ WRITE' if allow_read_write else 'READ ONLY'
    return [f'BEGIN {mode};', sql, 'END;']
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
async def execute_statement(
    cluster_identifier: str, database_name: str, sql: str, allow_read_write: bool = False
) -> tuple[dict, str]:
    """Execute a SQL statement against a Redshift cluster using the Data API.

    This is a common function used by other functions in this module.
    - The SQL is wrapped by ``protect_sql`` into a BEGIN / <sql> / END batch.
    - The batch is submitted with ``batch_execute_statement``.
    - ``describe_statement`` is polled until the batch finishes, fails, or
      exceeds ``QUERY_TIMEOUT`` seconds.

    Args:
        cluster_identifier: The cluster identifier (provisioned cluster id or
            serverless workgroup name) to query.
        database_name: The database to execute the query against.
        sql: The SQL statement to execute.
        allow_read_write: Indicates if read-write mode should be activated.

    Returns:
        Tuple containing:
            - Dictionary with the raw results_response from get_statement_result
              for the user's statement (the second statement of the batch).
            - String with the Data API id of that sub-statement.

    Raises:
        Exception: Raised in any of these cases:
            - the cluster is not found;
            - the SQL is rejected by protect_sql;
            - the cluster type is unknown;
            - the query fails or is aborted;
            - the query times out (cancellation is attempted first).
    """
    data_client = client_manager.redshift_data_client()

    # Determine whether the identifier is a provisioned cluster or a serverless
    # workgroup; the two are addressed by different batch_execute_statement args.
    clusters = await discover_clusters()
    cluster_info = next(
        (cluster for cluster in clusters if cluster['identifier'] == cluster_identifier), None
    )

    if not cluster_info:
        raise Exception(
            f'Cluster {cluster_identifier} not found. Please use list_clusters to get valid cluster identifiers.'
        )

    # Guard from executing read-write statements if not allowed
    protected_sqls = protect_sql(sql, allow_read_write)
    logger.debug(f'Protected SQL: {" ".join(protected_sqls)}')

    # Execute the query using Data API
    if cluster_info['type'] == 'provisioned':
        logger.debug(f'Using ClusterIdentifier for provisioned cluster: {cluster_identifier}')
        response = data_client.batch_execute_statement(
            ClusterIdentifier=cluster_identifier, Database=database_name, Sqls=protected_sqls
        )
    elif cluster_info['type'] == 'serverless':
        logger.debug(f'Using WorkgroupName for serverless workgroup: {cluster_identifier}')
        response = data_client.batch_execute_statement(
            WorkgroupName=cluster_identifier, Database=database_name, Sqls=protected_sqls
        )
    else:
        raise Exception(f'Unknown cluster type: {cluster_info["type"]}')

    query_id = response['Id']
    logger.debug(f'Started query execution: {query_id}')

    # Poll until a terminal state. The status is checked again after the final
    # sleep, so a statement that completes during the last interval is not
    # misreported as a timeout.
    wait_time = 0
    while True:
        status_response = data_client.describe_statement(Id=query_id)
        status = status_response['Status']

        if status == 'FINISHED':
            logger.debug(f'Query execution completed: {query_id}')
            break
        if status in ('FAILED', 'ABORTED'):
            error_msg = status_response.get('Error', 'Unknown error')
            logger.error(f'Query execution failed: {error_msg}')
            raise Exception(f'Query failed: {error_msg}')
        if wait_time >= QUERY_TIMEOUT:
            logger.error(f'Query execution timed out: {query_id}')
            # Best effort: do not leave the statement running on the cluster
            # after reporting it as timed out to the caller.
            try:
                data_client.cancel_statement(Id=query_id)
            except Exception as cancel_error:
                logger.warning(f'Failed to cancel timed-out query {query_id}: {cancel_error}')
            raise Exception(f'Query timed out after {QUERY_TIMEOUT} seconds')

        # Wait before polling again
        await asyncio.sleep(QUERY_POLL_INTERVAL)
        wait_time += QUERY_POLL_INTERVAL

    # Sub-statement 0 is the BEGIN added by protect_sql; 1 is the user's SQL.
    subquery1_id = status_response['SubStatements'][1]['Id']
    results_response = data_client.get_statement_result(Id=subquery1_id)
    return results_response, subquery1_id
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _provisioned_cluster_info(cluster: dict) -> dict:
    """Build the common cluster-info dict from one ``describe_clusters`` entry."""
    endpoint = cluster.get('Endpoint', {})
    return {
        'identifier': cluster['ClusterIdentifier'],
        'type': 'provisioned',
        'status': cluster['ClusterStatus'],
        'database_name': cluster['DBName'],
        'endpoint': endpoint.get('Address'),
        'port': endpoint.get('Port'),
        'vpc_id': cluster.get('VpcId'),
        'node_type': cluster.get('NodeType'),
        'number_of_nodes': cluster.get('NumberOfNodes'),
        'creation_time': cluster.get('ClusterCreateTime'),
        'master_username': cluster.get('MasterUsername'),
        'publicly_accessible': cluster.get('PubliclyAccessible'),
        'encrypted': cluster.get('Encrypted'),
        'tags': {tag['Key']: tag['Value'] for tag in cluster.get('Tags', [])},
    }


def _serverless_database_name(serverless_client, workgroup_detail: dict) -> str:
    """Return the database name of a workgroup's namespace, defaulting to 'dev'."""
    # A workgroup does not carry a database name itself: its configParameters
    # are session settings (e.g. datestyle, search_path). The database name
    # lives on the namespace the workgroup belongs to.
    namespace_name = workgroup_detail.get('namespaceName')
    if not namespace_name:
        return 'dev'
    try:
        namespace = serverless_client.get_namespace(namespaceName=namespace_name)['namespace']
    except Exception as e:
        # Best effort enrichment (e.g. missing redshift-serverless:GetNamespace).
        logger.warning(
            f'Could not read namespace {namespace_name}, assuming database "dev": {str(e)}'
        )
        return 'dev'
    return namespace.get('dbName') or 'dev'


def _serverless_workgroup_info(serverless_client, workgroup: dict, workgroup_detail: dict) -> dict:
    """Build the common cluster-info dict for one Redshift Serverless workgroup."""
    endpoint = workgroup_detail.get('endpoint', {})
    return {
        'identifier': workgroup['workgroupName'],
        'type': 'serverless',
        'status': workgroup['status'],
        'database_name': _serverless_database_name(serverless_client, workgroup_detail),
        'endpoint': endpoint.get('address'),
        'port': endpoint.get('port'),
        # Approximate VPC from subnet: this is the first subnet ID (None when
        # there are no subnets), not an actual VPC ID.
        'vpc_id': (workgroup_detail.get('subnetIds') or [None])[0],
        'node_type': None,  # Not applicable for serverless
        'number_of_nodes': None,  # Not applicable for serverless
        'creation_time': workgroup.get('creationDate'),
        'master_username': None,  # Serverless uses IAM
        'publicly_accessible': workgroup_detail.get('publiclyAccessible'),
        'encrypted': True,  # Serverless is always encrypted
        'tags': {tag['key']: tag['value'] for tag in workgroup_detail.get('tags', [])},
    }


async def discover_clusters() -> list[dict]:
    """Discover all Redshift clusters and serverless workgroups.

    Provisioned clusters are listed with the ``describe_clusters`` paginator.
    Serverless workgroups are listed with ``list_workgroups``, then each is
    enriched via ``get_workgroup`` and its namespace. Both kinds are returned
    in one list with a shared set of keys; the ``type`` key is 'provisioned'
    or 'serverless'.

    Returns:
        List of cluster information dictionaries, provisioned clusters first.

    Raises:
        Exception: Any error while listing clusters or workgroups is logged
            and re-raised.
    """
    clusters = []

    try:
        # Get provisioned clusters
        logger.debug('Discovering provisioned Redshift clusters')
        redshift_client = client_manager.redshift_client()

        paginator = redshift_client.get_paginator('describe_clusters')
        for page in paginator.paginate():
            clusters.extend(
                _provisioned_cluster_info(cluster) for cluster in page.get('Clusters', [])
            )

        logger.info(f'Found {len(clusters)} provisioned clusters')

    except Exception as e:
        logger.error(f'Error discovering provisioned clusters: {str(e)}')
        raise

    try:
        # Get serverless workgroups
        logger.debug('Discovering Redshift Serverless workgroups')
        serverless_client = client_manager.redshift_serverless_client()

        serverless_count = 0
        paginator = serverless_client.get_paginator('list_workgroups')
        for page in paginator.paginate():
            for workgroup in page.get('workgroups', []):
                # Get detailed workgroup information
                workgroup_detail = serverless_client.get_workgroup(
                    workgroupName=workgroup['workgroupName']
                )['workgroup']
                clusters.append(
                    _serverless_workgroup_info(serverless_client, workgroup, workgroup_detail)
                )
                serverless_count += 1

        logger.info(f'Found {serverless_count} serverless workgroups')

    except Exception as e:
        logger.error(f'Error discovering serverless workgroups: {str(e)}')
        raise

    logger.info(f'Total clusters discovered: {len(clusters)}')
    return clusters
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
async def discover_databases(cluster_identifier: str, database_name: str = 'dev') -> list[dict]:
    """List the databases of a Redshift cluster or workgroup via the Data API.

    Runs SVV_REDSHIFT_DATABASES_QUERY through execute_statement and maps each
    returned record positionally onto a dictionary.

    Args:
        cluster_identifier: The cluster identifier to query.
        database_name: The database to connect to for querying system views.

    Returns:
        One dictionary per record with keys database_name, database_owner,
        database_type, database_acl, database_options and
        database_isolation_level.

    Raises:
        Exception: Any error from execute_statement is logged and re-raised.
    """
    # (result key, Data API field attribute) for each column, in record order.
    layout = (
        ('database_name', 'stringValue'),
        ('database_owner', 'longValue'),
        ('database_type', 'stringValue'),
        ('database_acl', 'stringValue'),
        ('database_options', 'stringValue'),
        ('database_isolation_level', 'stringValue'),
    )

    try:
        logger.info(f'Discovering databases in cluster {cluster_identifier}')

        results_response, _ = await execute_statement(
            cluster_identifier=cluster_identifier,
            database_name=database_name,
            sql=SVV_REDSHIFT_DATABASES_QUERY,
        )

        databases = [
            {key: record[position].get(attr) for position, (key, attr) in enumerate(layout)}
            for record in results_response.get('Records', [])
        ]

        logger.info(f'Found {len(databases)} databases in cluster {cluster_identifier}')
        return databases

    except Exception as e:
        logger.error(f'Error discovering databases in cluster {cluster_identifier}: {str(e)}')
        raise
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
async def discover_schemas(cluster_identifier: str, schema_database_name: str) -> list[dict]:
    """List the schemas of one Redshift database via the Data API.

    Runs SVV_ALL_SCHEMAS_QUERY, filtered by the quoted database name, through
    execute_statement while connected to that same database, and maps each
    returned record positionally onto a dictionary.

    Args:
        cluster_identifier: The cluster identifier to query.
        schema_database_name: The database name to filter schemas for. Also used to connect to.

    Returns:
        One dictionary per record with keys database_name, schema_name,
        schema_owner, schema_type, schema_acl, source_database and
        schema_option.

    Raises:
        Exception: Any error from execute_statement is logged and re-raised.
    """
    # (result key, Data API field attribute) for each column, in record order.
    layout = (
        ('database_name', 'stringValue'),
        ('schema_name', 'stringValue'),
        ('schema_owner', 'longValue'),
        ('schema_type', 'stringValue'),
        ('schema_acl', 'stringValue'),
        ('source_database', 'stringValue'),
        ('schema_option', 'stringValue'),
    )

    try:
        logger.info(
            f'Discovering schemas in database {schema_database_name} in cluster {cluster_identifier}'
        )

        query = SVV_ALL_SCHEMAS_QUERY.format(quote_literal_string(schema_database_name))
        results_response, _ = await execute_statement(
            cluster_identifier=cluster_identifier,
            database_name=schema_database_name,
            sql=query,
        )

        schemas = [
            {key: record[position].get(attr) for position, (key, attr) in enumerate(layout)}
            for record in results_response.get('Records', [])
        ]

        logger.info(
            f'Found {len(schemas)} schemas in database {schema_database_name} in cluster {cluster_identifier}'
        )
        return schemas

    except Exception as e:
        logger.error(
            f'Error discovering schemas in database {schema_database_name} in cluster {cluster_identifier}: {str(e)}'
        )
        raise
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
async def discover_tables(
    cluster_identifier: str, table_database_name: str, table_schema_name: str
) -> list[dict]:
    """List the tables of one Redshift schema via the Data API.

    Runs SVV_ALL_TABLES_QUERY, filtered by the quoted database and schema
    names, through execute_statement while connected to that database, and
    maps each returned record positionally onto a dictionary.

    Args:
        cluster_identifier: The cluster identifier to query.
        table_database_name: The database name to filter tables for. Also used to connect to.
        table_schema_name: The schema name to filter tables for.

    Returns:
        One dictionary per record with keys database_name, schema_name,
        table_name, table_acl, table_type and remarks.

    Raises:
        Exception: Any error from execute_statement is logged and re-raised.
    """
    # (result key, Data API field attribute) for each column, in record order.
    layout = (
        ('database_name', 'stringValue'),
        ('schema_name', 'stringValue'),
        ('table_name', 'stringValue'),
        ('table_acl', 'stringValue'),
        ('table_type', 'stringValue'),
        ('remarks', 'stringValue'),
    )

    try:
        logger.info(
            f'Discovering tables in schema {table_schema_name} in database {table_database_name} in cluster {cluster_identifier}'
        )

        query = SVV_ALL_TABLES_QUERY.format(
            quote_literal_string(table_database_name), quote_literal_string(table_schema_name)
        )
        results_response, _ = await execute_statement(
            cluster_identifier=cluster_identifier,
            database_name=table_database_name,
            sql=query,
        )

        tables = [
            {key: record[position].get(attr) for position, (key, attr) in enumerate(layout)}
            for record in results_response.get('Records', [])
        ]

        logger.info(
            f'Found {len(tables)} tables in schema {table_schema_name} in database {table_database_name} in cluster {cluster_identifier}'
        )
        return tables

    except Exception as e:
        logger.error(
            f'Error discovering tables in schema {table_schema_name} in database {table_database_name} in cluster {cluster_identifier}: {str(e)}'
        )
        raise
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
async def discover_columns(
    cluster_identifier: str,
    column_database_name: str,
    column_schema_name: str,
    column_table_name: str,
) -> list[dict]:
    """List the columns of one Redshift table via the Data API.

    Runs SVV_ALL_COLUMNS_QUERY, filtered by the quoted database, schema and
    table names, through execute_statement while connected to that database,
    and maps each returned record positionally onto a dictionary.

    Args:
        cluster_identifier: The cluster identifier to query.
        column_database_name: The database name to filter columns for. Also used to connect to.
        column_schema_name: The schema name to filter columns for.
        column_table_name: The table name to filter columns for.

    Returns:
        One dictionary per record with keys database_name, schema_name,
        table_name, column_name, ordinal_position, column_default,
        is_nullable, data_type, character_maximum_length, numeric_precision,
        numeric_scale and remarks.

    Raises:
        Exception: Any error from execute_statement is logged and re-raised.
    """
    # (result key, Data API field attribute) for each column, in record order.
    layout = (
        ('database_name', 'stringValue'),
        ('schema_name', 'stringValue'),
        ('table_name', 'stringValue'),
        ('column_name', 'stringValue'),
        ('ordinal_position', 'longValue'),
        ('column_default', 'stringValue'),
        ('is_nullable', 'stringValue'),
        ('data_type', 'stringValue'),
        ('character_maximum_length', 'longValue'),
        ('numeric_precision', 'longValue'),
        ('numeric_scale', 'longValue'),
        ('remarks', 'stringValue'),
    )

    try:
        logger.info(
            f'Discovering columns in table {column_table_name} in schema {column_schema_name} in database {column_database_name} in cluster {cluster_identifier}'
        )

        query = SVV_ALL_COLUMNS_QUERY.format(
            quote_literal_string(column_database_name),
            quote_literal_string(column_schema_name),
            quote_literal_string(column_table_name),
        )
        results_response, _ = await execute_statement(
            cluster_identifier=cluster_identifier,
            database_name=column_database_name,
            sql=query,
        )

        columns = [
            {key: record[position].get(attr) for position, (key, attr) in enumerate(layout)}
            for record in results_response.get('Records', [])
        ]

        logger.info(
            f'Found {len(columns)} columns in table {column_table_name} in schema {column_schema_name} in database {column_database_name} in cluster {cluster_identifier}'
        )
        return columns

    except Exception as e:
        logger.error(
            f'Error discovering columns in table {column_table_name} in schema {column_schema_name} in database {column_database_name} in cluster {cluster_identifier}: {str(e)}'
        )
        raise
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _field_to_python(field: dict):
    """Convert one Data API result ``Field`` into a plain Python value."""
    # Typed values take priority in this order; a true isNull maps to None.
    for value_key in ('stringValue', 'longValue', 'doubleValue', 'booleanValue'):
        if value_key in field:
            return field[value_key]
    if field.get('isNull'):
        return None
    # Fallback for unknown field types (e.g. arrayValue or blobValue).
    return str(field)


async def execute_query(cluster_identifier: str, database_name: str, sql: str) -> dict:
    """Execute a SQL query against a Redshift cluster using the Data API.

    The query runs through execute_statement with its default
    allow_read_write=False, i.e. inside a READ ONLY transaction.

    Args:
        cluster_identifier: The cluster identifier to query.
        database_name: The database to execute the query against.
        sql: The SQL statement to execute.

    Returns:
        Dictionary with these keys:
            - columns: column names;
            - rows: lists of plain Python values;
            - row_count;
            - execution_time_ms: elapsed time of the execute_statement call,
              which includes cluster discovery and status polling;
            - query_id.

    Raises:
        Exception: Any error from execute_statement is logged and re-raised.
    """
    import time

    try:
        logger.info(f'Executing query on cluster {cluster_identifier} in database {database_name}')
        logger.debug(f'SQL: {sql}')

        # perf_counter is monotonic, so system clock adjustments cannot
        # distort (or make negative) the measured duration.
        start_time = time.perf_counter()

        # Execute the query using the common function
        results_response, query_id = await execute_statement(
            cluster_identifier=cluster_identifier, database_name=database_name, sql=sql
        )

        execution_time_ms = int((time.perf_counter() - start_time) * 1000)

        columns = [
            col_meta.get('name') for col_meta in results_response.get('ColumnMetadata', [])
        ]
        rows = [
            [_field_to_python(field) for field in record]
            for record in results_response.get('Records', [])
        ]

        query_result = {
            'columns': columns,
            'rows': rows,
            'row_count': len(rows),
            'execution_time_ms': execution_time_ms,
            'query_id': query_id,
        }

        logger.info(
            f'Query executed successfully: {query_id}, returned {len(rows)} rows in {execution_time_ms}ms'
        )
        return query_result

    except Exception as e:
        logger.error(f'Error executing query on cluster {cluster_identifier}: {str(e)}')
        raise
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
# Global client manager instance, used by execute_statement and
# discover_clusters in this module. Region and profile are read from the
# AWS_REGION / AWS_PROFILE environment variables once, at import time.
client_manager = RedshiftClientManager(
    aws_region=os.environ.get('AWS_REGION', DEFAULT_AWS_REGION),
    aws_profile=os.environ.get('AWS_PROFILE'),
    config=Config(
        user_agent_extra=f'awslabs/mcp/redshift-mcp-server/{__version__}',
        retries={'mode': 'adaptive', 'max_attempts': 3},
        connect_timeout=CLIENT_TIMEOUT,
        read_timeout=CLIENT_TIMEOUT,
    ),
)
|