awslabs.documentdb-mcp-server 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awslabs/documentdb_mcp_server/__init__.py +14 -0
- awslabs/documentdb_mcp_server/analytic_tools.py +375 -0
- awslabs/documentdb_mcp_server/config.py +30 -0
- awslabs/documentdb_mcp_server/connection_tools.py +223 -0
- awslabs/documentdb_mcp_server/db_management_tools.py +176 -0
- awslabs/documentdb_mcp_server/query_tools.py +121 -0
- awslabs/documentdb_mcp_server/server.py +159 -0
- awslabs/documentdb_mcp_server/write_tools.py +202 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/METADATA +202 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/RECORD +14 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/WHEEL +4 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/entry_points.txt +2 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/licenses/LICENSE +175 -0
- awslabs_documentdb_mcp_server-0.0.1.dist-info/licenses/NOTICE +2 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
|
|
4
|
+
# with the License. A copy of the License is located at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
|
|
9
|
+
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
|
|
10
|
+
# and limitations under the License.
|
|
11
|
+
|
|
12
|
+
"""AWS Labs DocumentDB MCP Server package."""
|
|
13
|
+
|
|
14
|
+
# Package version; keep in sync with the published distribution version
# (awslabs_documentdb_mcp_server-0.0.1).
__version__ = '0.0.1'
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
|
|
4
|
+
# with the License. A copy of the License is located at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
|
|
9
|
+
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
|
|
10
|
+
# and limitations under the License.
|
|
11
|
+
|
|
12
|
+
"""Analytic tools for DocumentDB MCP Server."""
|
|
13
|
+
|
|
14
|
+
from awslabs.documentdb_mcp_server.connection_tools import DocumentDBConnection
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from typing import Annotated, Any, Dict, List, Optional
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def count_documents(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
    database: Annotated[str, Field(description='Name of the database')],
    collection: Annotated[str, Field(description='Name of the collection')],
    filter: Annotated[
        Optional[Dict[str, Any]], Field(description='Query filter to count specific documents')
    ] = None,
) -> Dict[str, Any]:
    """Count documents in a DocumentDB collection.

    This tool counts the number of documents in a collection that match the provided filter.
    If no filter is provided, it counts all documents.

    The client is looked up through ``DocumentDBConnection.get_connection`` so that the
    connection's ``last_used`` timestamp is refreshed; otherwise
    ``DocumentDBConnection.close_idle_connections`` could close a connection that is
    actively being used by this tool.

    Args:
        connection_id: ID previously returned by the ``connect`` tool.
        database: Name of the database holding the collection.
        collection: Name of the collection to count in.
        filter: Optional query filter; ``None`` counts every document.

    Returns:
        Dict[str, Any]: ``count`` together with ``database``, ``collection`` and the
        ``filter`` that was actually applied (``{}`` when none was given).

    Raises:
        ValueError: If the connection ID is unknown or the count fails.
    """
    try:
        # get_connection raises ValueError for unknown IDs and refreshes last_used.
        client = DocumentDBConnection.get_connection(connection_id)
        coll = client[database][collection]

        # Use empty filter if none provided
        if filter is None:
            filter = {}

        count = coll.count_documents(filter)

        logger.info(f"Counted {count} documents in '{database}.{collection}'")
        return {'count': count, 'database': database, 'collection': collection, 'filter': filter}
    except ValueError as e:
        logger.error(f'Connection error: {str(e)}')
        # Re-raise as-is to keep the original traceback.
        raise
    except Exception as e:
        logger.error(f'Error counting documents: {str(e)}')
        raise ValueError(f'Failed to count documents: {str(e)}') from e
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
async def get_database_stats(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
    database: Annotated[str, Field(description='Name of the database')],
) -> Dict[str, Any]:
    """Get statistics about a DocumentDB database.

    This tool retrieves statistics about the specified database,
    including storage information and collection data.

    The client is looked up through ``DocumentDBConnection.get_connection`` so that the
    connection's ``last_used`` timestamp is refreshed and the connection is not
    mistaken for an idle one by ``DocumentDBConnection.close_idle_connections``.

    Args:
        connection_id: ID previously returned by the ``connect`` tool.
        database: Name of the database to inspect.

    Returns:
        Dict[str, Any]: ``stats`` (raw result of the ``dbStats`` command) and ``database``.

    Raises:
        ValueError: If the connection ID is unknown or the command fails.
    """
    try:
        # get_connection raises ValueError for unknown IDs and refreshes last_used.
        client = DocumentDBConnection.get_connection(connection_id)
        db = client[database]

        # Get database stats
        stats = db.command('dbStats')

        logger.info(f"Retrieved database statistics for '{database}'")
        return {'stats': stats, 'database': database}
    except ValueError as e:
        logger.error(f'Connection error: {str(e)}')
        # Re-raise as-is to keep the original traceback.
        raise
    except Exception as e:
        logger.error(f'Error retrieving database statistics: {str(e)}')
        raise ValueError(f'Failed to get database statistics: {str(e)}') from e
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
async def get_collection_stats(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
    database: Annotated[str, Field(description='Name of the database')],
    collection: Annotated[str, Field(description='Name of the collection')],
) -> Dict[str, Any]:
    """Get statistics about a DocumentDB collection.

    This tool retrieves detailed statistics about the specified collection,
    including size, document count, and storage information.

    The client is looked up through ``DocumentDBConnection.get_connection`` so that the
    connection's ``last_used`` timestamp is refreshed and the connection is not
    mistaken for an idle one by ``DocumentDBConnection.close_idle_connections``.

    Args:
        connection_id: ID previously returned by the ``connect`` tool.
        database: Name of the database holding the collection.
        collection: Name of the collection to inspect.

    Returns:
        Dict[str, Any]: ``stats`` (raw result of the ``collStats`` command),
        ``database`` and ``collection``.

    Raises:
        ValueError: If the connection ID is unknown or the command fails.
    """
    try:
        # get_connection raises ValueError for unknown IDs and refreshes last_used.
        client = DocumentDBConnection.get_connection(connection_id)
        db = client[database]

        # Get collection stats
        stats = db.command('collStats', collection)

        logger.info(f"Retrieved collection statistics for '{database}.{collection}'")
        return {'stats': stats, 'database': database, 'collection': collection}
    except ValueError as e:
        logger.error(f'Connection error: {str(e)}')
        # Re-raise as-is to keep the original traceback.
        raise
    except Exception as e:
        logger.error(f'Error retrieving collection statistics: {str(e)}')
        raise ValueError(f'Failed to get collection statistics: {str(e)}') from e
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def get_field_type(docs, path):
    """Determine the data type(s) of a field across a set of documents.

    ``path`` uses dotted notation: nested keys are joined with ``.`` and an array
    element is addressed as ``name[index]`` (e.g. ``'items[0].price'``). Only the
    first ``[...]`` index inside a segment is honoured.

    Args:
        docs: Iterable of documents (dicts) to inspect.
        path: Dotted field path to resolve in each document.

    Returns:
        ``'null'`` if the field is missing or ``None`` in every document; the single
        type name when all documents agree; otherwise a sorted list of the distinct
        type names. ``dict`` values are reported as ``'object'``, ``list`` values as
        ``'array'``, and anything else by its Python type name (e.g. ``'str'``).
    """
    parts = path.split('.')
    types = set()

    for doc in docs:
        value = doc
        try:
            for part in parts:
                if '[' in part:
                    # Handle array indexing: 'name[i]' -> value['name'][i]
                    array_part = part.split('[')[0]
                    # Only descend into mappings: `in` on a str is a substring test
                    # and would then fail when indexing with a string key.
                    if isinstance(value, dict) and array_part in value:
                        value = value[array_part]
                        # Try to get array item
                        if isinstance(value, list) and len(value) > 0:
                            index = int(part.split('[')[1].split(']')[0])
                            if len(value) > index:
                                value = value[index]
                            else:
                                value = None
                                break
                        else:
                            value = None
                            break
                    else:
                        value = None
                        break
                else:
                    if isinstance(value, dict) and part in value:
                        value = value[part]
                    else:
                        value = None
                        break

            if value is not None:
                if isinstance(value, dict):
                    types.add('object')
                elif isinstance(value, list):
                    types.add('array')
                else:
                    types.add(type(value).__name__)
        except (ValueError, IndexError, KeyError, TypeError, AttributeError) as e:
            # Log only the path, not the whole document, which may be large or sensitive.
            logger.warning(f'Error processing document for path {path!r}. Error: {e}')
            continue

    if not types:
        return 'null'
    if len(types) == 1:
        return next(iter(types))
    # Sort so the result is deterministic (set iteration order is not).
    return sorted(types)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
async def analyze_schema(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
    database: Annotated[str, Field(description='Name of the database')],
    collection: Annotated[str, Field(description='Name of the collection to analyze')],
    sample_size: Annotated[
        int, Field(description='Number of documents to sample (default: 100)')
    ] = 100,
) -> Dict[str, Any]:
    """Analyze the schema of a collection by sampling documents.

    This tool samples documents from a collection and provides information about
    the document structure and field coverage across the sampled documents.
    ``_id`` is skipped, and for arrays only the first element is analysed
    (reported as ``name[0]``).

    Args:
        connection_id: ID previously returned by the ``connect`` tool.
        database: Name of the database holding the collection.
        collection: Name of the collection to analyze.
        sample_size: Maximum number of documents to sample; must be >= 1.

    Returns:
        Dict[str, Any]: Schema analysis results including field coverage. For an
        empty collection an ``error`` entry with empty ``field_coverage`` is returned.

    Raises:
        ValueError: If ``sample_size`` is not positive, the connection ID is unknown,
            or the analysis fails.
    """
    try:
        # Without this check, 0 would be misreported as "Collection is empty" and a
        # negative value would be sent to the server as an invalid $sample size.
        if sample_size < 1:
            raise ValueError(f'sample_size must be a positive integer, got {sample_size}')

        # get_connection raises ValueError for unknown IDs and refreshes last_used.
        client = DocumentDBConnection.get_connection(connection_id)
        coll = client[database][collection]

        # Count total documents to adjust sample size if needed
        total_docs = coll.count_documents({})
        actual_sample_size = min(sample_size, total_docs)

        sampled_docs = []
        if actual_sample_size > 0:
            # Sample documents (using aggregation with $sample stage)
            sample_pipeline = [{'$sample': {'size': actual_sample_size}}]
            sampled_docs = list(coll.aggregate(sample_pipeline))

        # Base coverage on what was actually returned: $sample can yield fewer
        # documents than requested if the collection shrinks after the count.
        sampled_count = len(sampled_docs)
        if sampled_count == 0:
            return {
                'error': 'Collection is empty',
                'field_coverage': {},
                'total_documents': 0,
                'sampled_documents': 0,
            }

        # Number of sampled documents in which each field path appears.
        field_counts = {}

        def extract_paths(obj, prefix=''):
            """Record every field path in *obj* (recursively) into ``field_counts``."""
            if isinstance(obj, dict):
                for key, value in obj.items():
                    if key == '_id':
                        continue  # Skip _id field

                    path = f'{prefix}.{key}' if prefix else key
                    field_counts[path] = field_counts.get(path, 0) + 1
                    extract_paths(value, path)
            elif isinstance(obj, list) and len(obj) > 0:
                # For arrays, we'll only analyze the first item to avoid complexity
                extract_paths(obj[0], f'{prefix}[0]')

        for doc in sampled_docs:
            extract_paths(doc)

        # Calculate coverage percentages
        coverage = {
            path: {
                'count': count,
                'percentage': round((count / sampled_count) * 100, 2),
                'data_type': get_field_type(sampled_docs, path),
            }
            for path, count in field_counts.items()
        }

        logger.info(
            f"Analyzed schema for '{database}.{collection}' with {sampled_count} documents"
        )
        return {
            'field_coverage': coverage,
            'total_documents': total_docs,
            'sampled_documents': sampled_count,
            'database': database,
            'collection': collection,
        }
    except ValueError as e:
        logger.error(f'Connection error or invalid parameters: {str(e)}')
        # Re-raise as-is to keep the original traceback.
        raise
    except Exception as e:
        logger.error(f'Error analyzing schema: {str(e)}')
        raise ValueError(f'Failed to analyze collection schema: {str(e)}') from e
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
async def explain_operation(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
    database: Annotated[str, Field(description='Name of the database')],
    collection: Annotated[str, Field(description='Name of the collection')],
    operation_type: Annotated[
        str, Field(description='Type of operation to explain (find, aggregate)')
    ],
    query: Annotated[
        Optional[Dict[str, Any]], Field(description='Query for find operations')
    ] = None,
    pipeline: Annotated[
        Optional[List[Dict[str, Any]]],
        Field(description='Pipeline for DocumentDB aggregation operations'),
    ] = None,
    verbosity: Annotated[
        str, Field(description='Explanation verbosity level (queryPlanner, executionStats)')
    ] = 'queryPlanner',
) -> Dict[str, Any]:
    """Get an explanation of how an operation will be executed.

    This tool returns the execution plan for a query or aggregation operation,
    helping you understand how DocumentDB will process your operations.

    Args:
        connection_id: ID previously returned by the ``connect`` tool.
        database: Name of the database holding the collection.
        collection: Name of the collection the operation targets.
        operation_type: ``'find'`` or ``'aggregate'`` (case-insensitive).
        query: Filter for ``find``; ``None``/empty means match everything.
        pipeline: Stages for ``aggregate``; required for that operation type.
        verbosity: ``queryPlanner`` or ``executionStats`` (case-insensitive);
            any other value falls back to ``queryPlanner``.

    Returns:
        Dict[str, Any]: Operation explanation

    Raises:
        ValueError: On an unknown connection ID, invalid operation type, missing
            pipeline for ``aggregate``, or a failed explain command.
    """
    try:
        # get_connection raises ValueError for unknown IDs and refreshes last_used.
        client = DocumentDBConnection.get_connection(connection_id)
        db = client[database]

        # Validate operation type
        operation_type = operation_type.lower()
        if operation_type not in ('find', 'aggregate'):
            raise ValueError('Operation type must be one of: find, aggregate')

        # Validation is case-insensitive, but the server expects the exact
        # camelCase spelling, so map to the canonical form (unknown -> queryPlanner).
        canonical_verbosity = {
            'queryplanner': 'queryPlanner',
            'executionstats': 'executionStats',
        }
        verbosity = canonical_verbosity.get(verbosity.lower(), 'queryPlanner')

        # Get explanation based on operation type
        if operation_type == 'find':
            explanation = db.command(
                {'explain': {'find': collection, 'filter': query or {}}, 'verbosity': verbosity}
            )
            logger.info(f"Explained find operation on '{database}.{collection}'")
        else:  # aggregate
            if not pipeline:
                raise ValueError('Pipeline is required for aggregate operations')

            explanation = db.command(
                {
                    'explain': {'aggregate': collection, 'pipeline': pipeline, 'cursor': {}},
                    'verbosity': verbosity,
                }
            )
            logger.info(f"Explained aggregate operation on '{database}.{collection}'")

        return {
            'explanation': explanation,
            'operation_type': operation_type,
            'database': database,
            'collection': collection,
        }
    except ValueError as e:
        logger.error(f'Connection error or invalid parameters: {str(e)}')
        # Re-raise as-is to keep the original traceback.
        raise
    except Exception as e:
        logger.error(f'Error explaining operation: {str(e)}')
        raise ValueError(f'Failed to explain operation: {str(e)}') from e
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
|
|
4
|
+
# with the License. A copy of the License is located at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
|
|
9
|
+
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
|
|
10
|
+
# and limitations under the License.
|
|
11
|
+
|
|
12
|
+
"""Configuration settings for DocumentDB MCP Server."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ServerConfig:
    """Settings that shape how the DocumentDB MCP Server behaves.

    Attributes:
        read_only_mode: Whether the server runs in read-only mode. It starts
            out ``True``, so the server is read-only unless changed.
    """

    def __init__(self):
        """Populate the configuration with its safe defaults (read-only mode on)."""
        safe_default = True
        self.read_only_mode = safe_default


# Shared module-level instance (intended to be used as a singleton).
serverConfig = ServerConfig()
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
|
|
4
|
+
# with the License. A copy of the License is located at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
|
|
9
|
+
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
|
|
10
|
+
# and limitations under the License.
|
|
11
|
+
|
|
12
|
+
"""Connection management tools for DocumentDB MCP Server."""
|
|
13
|
+
|
|
14
|
+
import uuid
|
|
15
|
+
from datetime import datetime, timedelta
|
|
16
|
+
from loguru import logger
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
from pymongo import MongoClient
|
|
19
|
+
from pymongo.errors import ConnectionFailure, OperationFailure
|
|
20
|
+
from typing import Annotated, Any, Dict
|
|
21
|
+
from urllib.parse import parse_qs, urlparse
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ConnectionInfo:
    """Record describing one DocumentDB client tracked by the server.

    Attributes:
        connection_string: The connection string the client was built from
            (may contain credentials).
        client: The ``MongoClient`` connected to DocumentDB.
        connection_id: Random UUID4, as a string, identifying this record.
        last_used: Naive local ``datetime``, initially the creation time.
    """

    def __init__(self, connection_string: str, client: MongoClient):
        """Build the record, assigning a fresh ID and the current timestamp.

        Args:
            connection_string: The connection string used to connect to DocumentDB
            client: The MongoDB client instance connected to DocumentDB
        """
        self.connection_string, self.client = connection_string, client
        generated_id = uuid.uuid4()
        self.connection_id = f'{generated_id}'
        self.last_used = datetime.now()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DocumentDBConnection:
    """Manages connections to DocumentDB.

    All state lives in class attributes, so every caller in the process shares a
    single registry of ``ConnectionInfo`` objects keyed by connection ID.
    """

    # Connection pool mapped by connection_id
    _connections = {}

    # Idle timeout in minutes (connections unused for this long will be closed)
    _idle_timeout = 30

    @classmethod
    def create_connection(cls, connection_string: str) -> ConnectionInfo:
        """Create a new connection to DocumentDB.

        The string is validated first (``retryWrites=false`` is required), then a
        client is created and checked with a ``ping``. The connection is registered
        only if the ping succeeds; on failure the client is closed before the error
        propagates, so no unregistered client is left open.

        Args:
            connection_string: DocumentDB connection string
                Example: "mongodb://username:password@docdb-cluster.cluster-xyz.us-west-2.docdb.amazonaws.com:27017/?tls=true&tlsCAFile=global-bundle.pem&retryWrites=false" # pragma: allowlist secret

        Returns:
            ConnectionInfo containing the connection ID and client

        Raises:
            ValueError: If ``retryWrites`` is missing or not ``false``.
            ConnectionFailure: If the server cannot be reached.
            OperationFailure: If the ping command is rejected (e.g. bad credentials).
        """
        logger.info('Creating new DocumentDB connection')
        cls.validate_retry_writes_false(connection_string)
        client = MongoClient(connection_string)

        # Test connection
        try:
            client.admin.command('ping')
        except (ConnectionFailure, OperationFailure) as e:
            logger.error(f'Failed to connect to DocumentDB: {str(e)}')
            client.close()
            raise
        except Exception:
            # Any other failure: still release the client's resources before propagating.
            client.close()
            raise
        logger.info('Connected successfully to DocumentDB')

        # Store connection info
        connection_info = ConnectionInfo(connection_string, client)
        cls._connections[connection_info.connection_id] = connection_info

        return connection_info

    @classmethod
    def get_connection(cls, connection_id: str) -> MongoClient:
        """Get an existing connection by ID.

        Also refreshes the connection's ``last_used`` timestamp so that it is not
        considered idle by ``close_idle_connections``.

        Args:
            connection_id: The connection ID returned by create_connection

        Returns:
            An active pymongo client connected to DocumentDB

        Raises:
            ValueError: If the connection ID is not found
        """
        if connection_id not in cls._connections:
            raise ValueError(f'Connection ID {connection_id} not found. You must connect first.')

        # Update last used timestamp
        connection_info = cls._connections[connection_id]
        connection_info.last_used = datetime.now()

        return connection_info.client

    @classmethod
    def close_connection(cls, connection_id: str) -> None:
        """Close a specific connection by ID.

        The connection is removed from the registry before its client is closed,
        so the entry cannot linger even if ``close()`` raises.

        Args:
            connection_id: The connection ID to close

        Raises:
            ValueError: If the connection ID is not found
        """
        connection_info = cls._connections.pop(connection_id, None)
        if connection_info is None:
            raise ValueError(f'Connection ID {connection_id} not found')

        logger.info(f'Closing DocumentDB connection {connection_id}')
        connection_info.client.close()

    @classmethod
    def close_idle_connections(cls) -> None:
        """Close connections whose ``last_used`` is older than ``_idle_timeout`` minutes."""
        idle_threshold = datetime.now() - timedelta(minutes=cls._idle_timeout)

        # Snapshot the IDs first; the registry is mutated in the loop below.
        idle_connections = [
            conn_id
            for conn_id, info in cls._connections.items()
            if info.last_used < idle_threshold
        ]

        for conn_id in idle_connections:
            logger.info(f'Closing idle DocumentDB connection {conn_id}')
            # Unregister first so a failing close() cannot leave a dead entry behind.
            cls._connections.pop(conn_id).client.close()

    @classmethod
    def close_all_connections(cls) -> None:
        """Close all open connections and empty the registry."""
        for conn_id, conn_info in list(cls._connections.items()):
            logger.info(f'Closing DocumentDB connection {conn_id}')
            conn_info.client.close()
        cls._connections.clear()

    @staticmethod
    def validate_retry_writes_false(conn_str: str) -> None:
        """Validate that retryWrites=false is specified in the connection string.

        DocumentDB requires retryWrites=false to be set in the connection string.
        This method ensures this setting is present to avoid potential data consistency issues.
        The option name is matched case-insensitively, as MongoDB URI option names are
        case-insensitive; the value must be ``false`` (any letter case).

        Args:
            conn_str: The connection string to validate

        Raises:
            ValueError: If retryWrites is missing or set to a value other than 'false'
        """
        parsed = urlparse(conn_str)
        query_params = parse_qs(parsed.query)

        # parse_qs keys are case-sensitive, so look the option up by lowercased name.
        retry_value = next(
            (
                values[0]
                for key, values in query_params.items()
                if key.lower() == 'retrywrites'
            ),
            None,
        )

        if retry_value is None:
            raise ValueError("Connection string is missing 'retryWrites=false'.")

        if retry_value.lower() != 'false':
            raise ValueError(f"Invalid retryWrites value: '{retry_value}'. Expected 'false'.")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
async def connect(
    connection_string: Annotated[
        str,
        Field(
            description='DocumentDB connection string. Example: "mongodb://user:pass@docdb-cluster.cluster-xyz.us-west-2.docdb.amazonaws.com:27017/?tls=true&tlsCAFile=global-bundle.pem&retryWrites=false"'  # pragma: allowlist secret
        ),
    ],
) -> Dict[str, Any]:
    """Connect to an AWS DocumentDB cluster.

    This tool establishes and validates a connection to DocumentDB.
    The returned connection_id can be used with other tools instead of providing
    the full connection string each time. The connection string must include
    ``retryWrites=false``.

    If the connection is created but a later step (listing databases) fails, the
    connection is closed again. Otherwise it would remain registered under an ID
    the caller never receives.

    Args:
        connection_string: DocumentDB connection string (must set retryWrites=false).

    Returns:
        Dict[str, Any]: Connection details including connection_id and available databases

    Raises:
        ValueError: If validation, connecting, or listing databases fails.
    """
    connection_info = None
    try:
        # Create connection and get connection info
        connection_info = DocumentDBConnection.create_connection(connection_string)

        # List available databases
        databases = connection_info.client.list_database_names()

        return {
            'connection_id': connection_info.connection_id,
            'message': 'Successfully connected to DocumentDB',
            'databases': databases,
        }
    except Exception as e:
        if connection_info is not None:
            try:
                DocumentDBConnection.close_connection(connection_info.connection_id)
            except Exception as close_error:
                logger.warning(
                    f'Failed to close connection after connect error: {str(close_error)}'
                )
        logger.error(f'Error connecting to DocumentDB: {str(e)}')
        raise ValueError(f'Failed to connect to DocumentDB: {str(e)}') from e
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
async def disconnect(
    connection_id: Annotated[
        str, Field(description='The connection ID returned by the connect tool')
    ],
) -> Dict[str, Any]:
    """Close a connection to DocumentDB.

    Delegates to ``DocumentDBConnection.close_connection`` for the given ID.

    Returns:
        Dict[str, Any]: ``success`` flag and a confirmation ``message``.

    Raises:
        ValueError: If the ID is unknown (message passed through unchanged) or if
            closing fails for any other reason (message prefixed with
            ``Failed to disconnect from DocumentDB``).
    """
    try:
        DocumentDBConnection.close_connection(connection_id)
    except ValueError as err:
        logger.error(f'Error disconnecting from DocumentDB: {str(err)}')
        raise ValueError(str(err))
    except Exception as err:
        logger.error(f'Error disconnecting from DocumentDB: {str(err)}')
        raise ValueError(f'Failed to disconnect from DocumentDB: {str(err)}')
    else:
        confirmation = f'Successfully closed connection {connection_id}'
        return {'success': True, 'message': confirmation}
|