mdb-engine 0.1.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +104 -11
- mdb_engine/auth/ARCHITECTURE.md +112 -0
- mdb_engine/auth/README.md +648 -11
- mdb_engine/auth/__init__.py +136 -29
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/base.py +252 -0
- mdb_engine/auth/casbin_factory.py +264 -69
- mdb_engine/auth/config_helpers.py +7 -6
- mdb_engine/auth/cookie_utils.py +3 -7
- mdb_engine/auth/csrf.py +373 -0
- mdb_engine/auth/decorators.py +3 -10
- mdb_engine/auth/dependencies.py +47 -50
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +53 -80
- mdb_engine/auth/jwt.py +2 -6
- mdb_engine/auth/middleware.py +77 -34
- mdb_engine/auth/oso_factory.py +18 -38
- mdb_engine/auth/provider.py +270 -171
- mdb_engine/auth/rate_limiter.py +504 -0
- mdb_engine/auth/restrictions.py +8 -24
- mdb_engine/auth/session_manager.py +14 -29
- mdb_engine/auth/shared_middleware.py +600 -0
- mdb_engine/auth/shared_users.py +759 -0
- mdb_engine/auth/token_store.py +14 -28
- mdb_engine/auth/users.py +54 -113
- mdb_engine/auth/utils.py +213 -15
- mdb_engine/cli/commands/generate.py +545 -9
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +3 -3
- mdb_engine/config.py +7 -21
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +22 -41
- mdb_engine/core/app_secrets.py +290 -0
- mdb_engine/core/connection.py +18 -9
- mdb_engine/core/encryption.py +223 -0
- mdb_engine/core/engine.py +1057 -93
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +459 -150
- mdb_engine/core/ray_integration.py +435 -0
- mdb_engine/core/seeding.py +10 -18
- mdb_engine/core/service_initialization.py +12 -23
- mdb_engine/core/types.py +2 -5
- mdb_engine/database/README.md +140 -17
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +25 -37
- mdb_engine/database/connection.py +11 -18
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +713 -196
- mdb_engine/dependencies.py +426 -0
- mdb_engine/di/__init__.py +34 -0
- mdb_engine/di/container.py +248 -0
- mdb_engine/di/providers.py +205 -0
- mdb_engine/di/scopes.py +139 -0
- mdb_engine/embeddings/README.md +54 -24
- mdb_engine/embeddings/__init__.py +31 -24
- mdb_engine/embeddings/dependencies.py +37 -154
- mdb_engine/embeddings/service.py +11 -25
- mdb_engine/exceptions.py +92 -0
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +1 -1
- mdb_engine/indexes/manager.py +50 -114
- mdb_engine/memory/README.md +2 -2
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +30 -87
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +8 -9
- mdb_engine/observability/metrics.py +32 -12
- mdb_engine/repositories/__init__.py +34 -0
- mdb_engine/repositories/base.py +325 -0
- mdb_engine/repositories/mongo.py +233 -0
- mdb_engine/repositories/unit_of_work.py +166 -0
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +25 -60
- mdb_engine-0.2.0.dist-info/METADATA +313 -0
- mdb_engine-0.2.0.dist-info/RECORD +96 -0
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Query validation for MongoDB Engine.
|
|
3
|
+
|
|
4
|
+
This module provides comprehensive query validation to prevent NoSQL injection,
|
|
5
|
+
block dangerous operators, and enforce query complexity limits.
|
|
6
|
+
|
|
7
|
+
Security Features:
|
|
8
|
+
- Blocks dangerous MongoDB operators ($where, $eval, $function, $accumulator)
|
|
9
|
+
- Prevents deeply nested queries
|
|
10
|
+
- Limits regex complexity to prevent ReDoS attacks
|
|
11
|
+
- Validates aggregation pipelines
|
|
12
|
+
- Prevents NoSQL injection patterns
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import re
|
|
17
|
+
from typing import Any, Dict, List, Optional, Set
|
|
18
|
+
|
|
19
|
+
from ..constants import (
|
|
20
|
+
DANGEROUS_OPERATORS,
|
|
21
|
+
MAX_PIPELINE_STAGES,
|
|
22
|
+
MAX_QUERY_DEPTH,
|
|
23
|
+
MAX_REGEX_COMPLEXITY,
|
|
24
|
+
MAX_REGEX_LENGTH,
|
|
25
|
+
MAX_SORT_FIELDS,
|
|
26
|
+
)
|
|
27
|
+
from ..exceptions import QueryValidationError
|
|
28
|
+
|
|
29
|
+
# Module logger; QueryValidator reports rejected dangerous operators through it
# before raising QueryValidationError.
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QueryValidator:
    """
    Validates MongoDB queries for security and safety.

    This class provides comprehensive validation to prevent:
    - NoSQL injection attacks
    - Dangerous operator usage
    - Resource exhaustion via complex queries
    - ReDoS attacks via complex regex patterns

    Filters and pipeline stages are walked recursively through nested
    dictionaries, lists and tuples. Tuples are included because the BSON
    encoder serializes them as arrays, so they reach the server exactly like
    lists do.
    """

    # Aggregation expression operators that carry their pattern under a plain
    # "regex" key instead of "$regex".
    _REGEX_EXPRESSION_OPERATORS = frozenset({"$regexMatch", "$regexFind", "$regexFindAll"})

    # Regex-complexity heuristics, compiled once per class instead of per call.
    # An escaped character (\*, \(, \| ...) is a literal and cannot backtrack.
    _ESCAPED_CHAR_RE = re.compile(r"\\.", re.DOTALL)
    # "*", "+" and "{" always quantify; "?" does too unless it directly follows
    # "(", where it starts a group extension such as "(?:" or "(?=".
    _QUANTIFIER_RE = re.compile(r"[*+{]|(?<!\()\?")
    _NESTED_GROUP_RE = re.compile(r"\([^)]*\([^)]*\)")
    _LOOKAROUND_RE = re.compile(r"\(\?[=!<>]")

    def __init__(
        self,
        max_depth: int = MAX_QUERY_DEPTH,
        max_pipeline_stages: int = MAX_PIPELINE_STAGES,
        max_regex_length: int = MAX_REGEX_LENGTH,
        max_regex_complexity: int = MAX_REGEX_COMPLEXITY,
        dangerous_operators: Optional[Set[str]] = None,
        max_sort_fields: int = MAX_SORT_FIELDS,
    ):
        """
        Initialize the query validator.

        Args:
            max_depth: Maximum nesting depth for queries
            max_pipeline_stages: Maximum stages in aggregation pipelines
            max_regex_length: Maximum length for regex patterns
            max_regex_complexity: Maximum complexity score for regex patterns
            dangerous_operators: Extra operators to block. They are added to
                DANGEROUS_OPERATORS; the default operators are always blocked.
            max_sort_fields: Maximum number of fields in a sort specification
                (defaults to MAX_SORT_FIELDS)
        """
        self.max_depth = max_depth
        self.max_pipeline_stages = max_pipeline_stages
        self.max_regex_length = max_regex_length
        self.max_regex_complexity = max_regex_complexity
        self.max_sort_fields = max_sort_fields
        # Always take a private copy: aliasing the module-level constant would
        # let a change to one validator's set leak into every other validator.
        self.dangerous_operators: Set[str] = set(DANGEROUS_OPERATORS)
        if dangerous_operators is not None:
            # Custom operators extend (never replace) the defaults.
            self.dangerous_operators |= set(dangerous_operators)

    def validate_filter(self, filter: Optional[Dict[str, Any]], path: str = "") -> None:
        """
        Validate a MongoDB query filter.

        ``None`` and ``{}`` are accepted as "no filter"; any other non-dict
        value is rejected.

        Args:
            filter: The query filter to validate
            path: JSON path for error reporting (used recursively)

        Raises:
            QueryValidationError: If the filter contains dangerous operators or exceeds limits
        """
        if filter is None:
            return

        if not isinstance(filter, dict):
            raise QueryValidationError(
                f"Query filter must be a dictionary, got {type(filter).__name__}",
                query_type="filter",
                path=path,
            )

        # One walk enforces the depth limit, blocks dangerous operators and
        # validates every regex pattern it meets.
        self._check_dangerous_operators(filter, path, depth=0, query_type="filter")

    def validate_pipeline(self, pipeline: List[Dict[str, Any]]) -> None:
        """
        Validate an aggregation pipeline.

        ``None`` and ``[]`` are accepted as "no pipeline"; any other non-list
        value is rejected.

        Args:
            pipeline: The aggregation pipeline to validate

        Raises:
            QueryValidationError: If the pipeline exceeds limits or contains dangerous operators
        """
        if pipeline is None:
            return

        if not isinstance(pipeline, list):
            raise QueryValidationError(
                f"Aggregation pipeline must be a list, got {type(pipeline).__name__}",
                query_type="pipeline",
            )

        # Check pipeline length
        if len(pipeline) > self.max_pipeline_stages:
            raise QueryValidationError(
                f"Aggregation pipeline exceeds maximum stages: "
                f"{len(pipeline)} > {self.max_pipeline_stages}",
                query_type="pipeline",
                context={
                    "stages": len(pipeline),
                    "max_stages": self.max_pipeline_stages,
                },
            )

        for idx, stage in enumerate(pipeline):
            stage_path = f"$[{idx}]"
            if not isinstance(stage, dict):
                raise QueryValidationError(
                    f"Pipeline stage {idx} must be a dictionary, got {type(stage).__name__}",
                    query_type="pipeline",
                    path=stage_path,
                )
            # Sub-pipelines ($lookup.pipeline, $facet, $unionWith) are nested
            # lists of dicts, so the recursive walk covers them too.
            self._check_dangerous_operators(stage, stage_path, depth=0, query_type="pipeline")

    def validate_regex(self, pattern: str, path: str = "") -> None:
        """
        Validate a regex pattern to prevent ReDoS attacks.

        Non-string values are ignored.

        NOTE(review): MongoDB evaluates patterns with PCRE while this check
        compiles them with Python's ``re``; a few PCRE-only constructs may be
        rejected here. Confirm that strictness is intended.

        Args:
            pattern: The regex pattern to validate
            path: JSON path for error reporting

        Raises:
            QueryValidationError: If the regex pattern is too complex or long
        """
        if not isinstance(pattern, str):
            return  # Not a regex pattern

        if len(pattern) > self.max_regex_length:
            raise QueryValidationError(
                f"Regex pattern exceeds maximum length: "
                f"{len(pattern)} > {self.max_regex_length}",
                query_type="regex",
                path=path,
                context={
                    "length": len(pattern),
                    "max_length": self.max_regex_length,
                },
            )

        complexity = self._calculate_regex_complexity(pattern)
        if complexity > self.max_regex_complexity:
            raise QueryValidationError(
                f"Regex pattern exceeds maximum complexity: "
                f"{complexity} > {self.max_regex_complexity}",
                query_type="regex",
                path=path,
                context={
                    "complexity": complexity,
                    "max_complexity": self.max_regex_complexity,
                },
            )

        # Compile to surface syntax errors before the query reaches the server.
        try:
            re.compile(pattern)
        except re.error as e:
            raise QueryValidationError(
                f"Invalid regex pattern: {e}",
                query_type="regex",
                path=path,
            ) from e

    def validate_sort(self, sort: Optional[Any]) -> None:
        """
        Validate a sort specification.

        Args:
            sort: The sort specification to validate (field name, dict,
                single ``(field, direction)`` pair, or a list/tuple of
                pairs and/or field names)

        Raises:
            QueryValidationError: If the sort specification exceeds limits
        """
        if not sort:
            return

        sort_fields = self._extract_sort_fields(sort)
        if len(sort_fields) > self.max_sort_fields:
            raise QueryValidationError(
                f"Sort specification exceeds maximum fields: "
                f"{len(sort_fields)} > {self.max_sort_fields}",
                query_type="sort",
                context={
                    "fields": len(sort_fields),
                    "max_fields": self.max_sort_fields,
                },
            )

    def _check_dangerous_operators(
        self,
        query: Dict[str, Any],
        path: str = "",
        depth: int = 0,
        query_type: str = "filter",
    ) -> None:
        """
        Recursively check a query for dangerous operators, depth and regexes.

        Args:
            query: The query dictionary to check
            path: Current JSON path for error reporting
            depth: Current nesting depth
            query_type: Label attached to raised errors ("filter" or "pipeline")

        Raises:
            QueryValidationError: If a dangerous operator is found, the depth
                limit is exceeded, or a regex pattern is rejected
        """
        self._walk_dict(query, path, depth, query_type, check_operators=True)

    def _check_query_depth(
        self,
        query: Dict[str, Any],
        path: str = "",
        depth: int = 0,
        query_type: str = "filter",
    ) -> None:
        """
        Check query nesting depth only (no operator or regex checks).

        ``validate_filter`` and ``validate_pipeline`` no longer need this,
        because ``_check_dangerous_operators`` enforces depth in the same walk.
        It is kept for existing callers.

        Args:
            query: The query dictionary to check
            path: Current JSON path for error reporting
            depth: Current nesting depth
            query_type: Label attached to raised errors

        Raises:
            QueryValidationError: If query depth exceeds maximum
        """
        self._walk_dict(query, path, depth, query_type, check_operators=False)

    def _walk_dict(
        self,
        query: Dict[str, Any],
        path: str,
        depth: int,
        query_type: str,
        check_operators: bool,
    ) -> None:
        """Enforce the depth limit on one mapping level, then descend into its values."""
        self._ensure_depth(depth, path, query_type)
        for key, value in query.items():
            current_path = f"{path}.{key}" if path else str(key)
            if check_operators:
                self._check_key(key, value, current_path, query_type)
            self._walk_value(value, current_path, depth + 1, query_type, check_operators)

    def _walk_value(
        self,
        value: Any,
        path: str,
        depth: int,
        query_type: str,
        check_operators: bool,
    ) -> None:
        """Descend into a dict or array value; scalars need no further checks."""
        if isinstance(value, dict):
            self._walk_dict(value, path, depth, query_type, check_operators)
        elif isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                item_path = f"{path}[{idx}]"
                if isinstance(item, (list, tuple)):
                    # An array directly inside an array has no dict level to
                    # count, so charge one here to keep [[[...]]] bounded.
                    self._ensure_depth(depth + 1, item_path, query_type)
                    self._walk_value(item, item_path, depth + 1, query_type, check_operators)
                else:
                    # Dict items sit at the same depth as the array itself.
                    self._walk_value(item, item_path, depth, query_type, check_operators)

    def _check_key(self, key: Any, value: Any, current_path: str, query_type: str) -> None:
        """Reject a dangerous operator key and validate regex patterns attached to it."""
        if key in self.dangerous_operators:
            logger.warning(
                "Security: Dangerous operator '%s' detected in query at path '%s'",
                key,
                current_path,
            )
            raise QueryValidationError(
                f"Dangerous operator '{key}' is not allowed for security reasons. "
                f"Found at path: {current_path}",
                query_type=query_type,
                operator=key,
                path=current_path,
            )

        if key == "$regex" and isinstance(value, str):
            self.validate_regex(value, current_path)
        elif key in self._REGEX_EXPRESSION_OPERATORS and isinstance(value, dict):
            # Aggregation form: {"$regexMatch": {"input": ..., "regex": "..."}}
            pattern = value.get("regex")
            if isinstance(pattern, str):
                self.validate_regex(pattern, f"{current_path}.regex")

    def _ensure_depth(self, depth: int, path: str, query_type: str) -> None:
        """Raise QueryValidationError when ``depth`` exceeds ``self.max_depth``."""
        if depth > self.max_depth:
            raise QueryValidationError(
                f"Query exceeds maximum nesting depth: {depth} > {self.max_depth}",
                query_type=query_type,
                path=path,
                context={"depth": depth, "max_depth": self.max_depth},
            )

    def _calculate_regex_complexity(self, pattern: str) -> int:
        """
        Calculate a complexity score for a regex pattern.

        This is a simple heuristic to detect potentially dangerous regex
        patterns that could cause ReDoS attacks. It adds up:

        - quantifiers (backtracking points),
        - alternations (exponential branching),
        - nested groups (deep backtracking),
        - lookaround assertions (extra scanning).

        Escaped characters are removed first because they are literals.

        Args:
            pattern: The regex pattern

        Returns:
            Complexity score (higher = more complex)
        """
        unescaped = self._ESCAPED_CHAR_RE.sub("", pattern)
        return (
            len(self._QUANTIFIER_RE.findall(unescaped))
            + unescaped.count("|")
            + len(self._NESTED_GROUP_RE.findall(unescaped))
            + len(self._LOOKAROUND_RE.findall(unescaped))
        )

    def _extract_sort_fields(self, sort: Any) -> List[str]:
        """
        Extract field names from a sort specification.

        Args:
            sort: Sort specification. Accepted forms are a field name, a dict,
                a single ``(field, direction)`` tuple, or a list/tuple whose
                items are ``(field, direction)`` pairs or bare field names.

        Returns:
            List of field names (unrecognised items are skipped, never raised on)
        """
        if isinstance(sort, str):
            return [sort]
        if isinstance(sort, dict):
            return list(sort.keys())
        if (
            isinstance(sort, tuple)
            and len(sort) == 2
            and isinstance(sort[0], str)
            and not isinstance(sort[1], (list, tuple))
        ):
            # A single (field, direction) pair.
            return [sort[0]]
        if isinstance(sort, (list, tuple)):
            fields: List[str] = []
            for item in sort:
                if isinstance(item, str):
                    fields.append(item)
                elif isinstance(item, (list, tuple)) and item and isinstance(item[0], str):
                    fields.append(item[0])
            return fields
        return []
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Resource limiting for MongoDB Engine.
|
|
3
|
+
|
|
4
|
+
This module provides resource limit enforcement to prevent resource exhaustion
|
|
5
|
+
and ensure fair resource usage across applications.
|
|
6
|
+
|
|
7
|
+
Features:
|
|
8
|
+
- Query timeout enforcement
|
|
9
|
+
- Result size limits
|
|
10
|
+
- Document size validation
|
|
11
|
+
- Connection limit tracking
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Any, Dict, Optional
|
|
16
|
+
|
|
17
|
+
from bson import encode as bson_encode
|
|
18
|
+
from bson.errors import InvalidDocument
|
|
19
|
+
|
|
20
|
+
from ..constants import (
|
|
21
|
+
DEFAULT_MAX_TIME_MS,
|
|
22
|
+
MAX_CURSOR_BATCH_SIZE,
|
|
23
|
+
MAX_DOCUMENT_SIZE,
|
|
24
|
+
MAX_QUERY_RESULT_SIZE,
|
|
25
|
+
MAX_QUERY_TIME_MS,
|
|
26
|
+
)
|
|
27
|
+
from ..exceptions import ResourceLimitExceeded
|
|
28
|
+
|
|
29
|
+
# Module logger; ResourceLimiter reports capped limits and BSON-encoding
# failures through it.
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ResourceLimiter:
    """
    Enforces resource limits on MongoDB operations.

    This class provides resource limit enforcement to prevent:
    - Query timeouts
    - Excessive result sizes
    - Oversized documents
    - Resource exhaustion
    """

    def __init__(
        self,
        default_timeout_ms: int = DEFAULT_MAX_TIME_MS,
        max_timeout_ms: int = MAX_QUERY_TIME_MS,
        max_result_size: int = MAX_QUERY_RESULT_SIZE,
        max_batch_size: int = MAX_CURSOR_BATCH_SIZE,
        max_document_size: int = MAX_DOCUMENT_SIZE,
    ):
        """
        Initialize the resource limiter.

        Args:
            default_timeout_ms: Default query timeout in milliseconds
            max_timeout_ms: Maximum allowed query timeout in milliseconds
            max_result_size: Maximum number of documents in a result set
            max_batch_size: Maximum batch size for cursor operations
            max_document_size: Maximum document size in bytes
        """
        self.default_timeout_ms = default_timeout_ms
        self.max_timeout_ms = max_timeout_ms
        self.max_result_size = max_result_size
        self.max_batch_size = max_batch_size
        self.max_document_size = max_document_size

    def enforce_query_timeout(
        self, kwargs: Dict[str, Any], default_timeout: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Enforce a query timeout by adding or capping ``maxTimeMS``.

        - Missing or ``None``: set to the default timeout, itself capped at
          ``max_timeout_ms``.
        - ``<= 0``: set to ``max_timeout_ms``. MongoDB treats 0 as "no time
          limit", so passing it through would bypass the cap.
        - ``> max_timeout_ms``: capped to ``max_timeout_ms``.

        NOTE(review): pymongo's ``Collection.find()`` spells this option
        ``max_time_ms``. Confirm callers translate ``maxTimeMS`` before
        forwarding these kwargs to ``find()``.

        Args:
            kwargs: Query keyword arguments (not mutated)
            default_timeout: Default timeout to use (defaults to self.default_timeout_ms)

        Returns:
            A copy of kwargs with ``maxTimeMS`` set
        """
        kwargs = dict(kwargs)  # Copy so the caller's dict is never mutated

        requested_default = (
            default_timeout if default_timeout is not None else self.default_timeout_ms
        )
        default = min(requested_default, self.max_timeout_ms)

        user_timeout = kwargs.get("maxTimeMS")
        if user_timeout is None:
            kwargs["maxTimeMS"] = default
        elif user_timeout <= 0:
            logger.warning(
                "Query timeout %sms is not a positive limit. Capping to %sms",
                user_timeout,
                self.max_timeout_ms,
            )
            kwargs["maxTimeMS"] = self.max_timeout_ms
        elif user_timeout > self.max_timeout_ms:
            logger.warning(
                "Query timeout %sms exceeds maximum %sms. Capping to %sms",
                user_timeout,
                self.max_timeout_ms,
                self.max_timeout_ms,
            )
            kwargs["maxTimeMS"] = self.max_timeout_ms

        return kwargs

    def enforce_result_limit(self, limit: Optional[int], max_limit: Optional[int] = None) -> int:
        """
        Enforce maximum result limit.

        Args:
            limit: Requested limit. ``None`` and ``0`` both mean "no limit"
                (the driver treats a limit of 0 as unlimited). A negative
                limit requests a single batch of ``abs(limit)`` documents;
                its sign is preserved while its magnitude is capped.
            max_limit: Maximum allowed limit (defaults to self.max_result_size)

        Returns:
            Enforced limit value (capped to maximum if needed)
        """
        max_allowed = max_limit if max_limit is not None else self.max_result_size

        if limit is None or limit == 0:
            # "No limit" must still be bounded.
            return max_allowed

        if abs(limit) > max_allowed:
            logger.warning(
                "Result limit %s exceeds maximum %s. Capping to %s",
                limit,
                max_allowed,
                max_allowed,
            )
            return max_allowed if limit > 0 else -max_allowed

        return limit

    def enforce_batch_size(self, batch_size: Optional[int], max_batch: Optional[int] = None) -> int:
        """
        Enforce maximum batch size for cursor operations.

        Args:
            batch_size: Requested batch size (None means use the maximum)
            max_batch: Maximum allowed batch size (defaults to self.max_batch_size)

        Returns:
            Enforced batch size
        """
        max_allowed = max_batch if max_batch is not None else self.max_batch_size

        if batch_size is None:
            return max_allowed

        if batch_size > max_allowed:
            logger.warning(
                "Batch size %s exceeds maximum %s. Capping to %s",
                batch_size,
                max_allowed,
                max_allowed,
            )
            return max_allowed

        return batch_size

    def validate_document_size(self, document: Dict[str, Any]) -> None:
        """
        Validate that a document doesn't exceed size limits.

        Uses actual BSON encoding for accurate size calculation. This is
        best-effort: if the document cannot be BSON-encoded (``InvalidDocument``,
        ``TypeError`` for non-mapping input or non-str keys, ``OverflowError``
        for integers wider than 64 bits), a warning is logged and the check is
        skipped. The driver will reject such a document on write.

        Args:
            document: Document to validate

        Raises:
            ResourceLimitExceeded: If document exceeds size limit
        """
        try:
            actual_size = len(bson_encode(document))
        except (InvalidDocument, TypeError, OverflowError) as e:
            logger.warning("Could not encode document as BSON for size validation: %s", e)
            return

        if actual_size > self.max_document_size:
            raise ResourceLimitExceeded(
                f"Document size {actual_size} bytes exceeds maximum "
                f"{self.max_document_size} bytes",
                limit_type="document_size",
                limit_value=self.max_document_size,
                actual_value=actual_size,
            )

    def validate_documents_size(self, documents: list[Dict[str, Any]]) -> None:
        """
        Validate that multiple documents don't exceed size limits.

        Args:
            documents: List of documents to validate

        Raises:
            ResourceLimitExceeded: If any document exceeds size limit; the
                message and context identify the offending document index
        """
        for idx, doc in enumerate(documents):
            try:
                self.validate_document_size(doc)
            except ResourceLimitExceeded as e:
                # Re-raise with the document index added to message and context.
                raise ResourceLimitExceeded(
                    f"{e.message} (document index: {idx})",
                    limit_type=e.limit_type,
                    limit_value=e.limit_value,
                    actual_value=e.actual_value,
                    context={**(e.context or {}), "document_index": idx},
                ) from e
|