mdb-engine 0.1.6__py3-none-any.whl → 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +116 -11
- mdb_engine/auth/ARCHITECTURE.md +112 -0
- mdb_engine/auth/README.md +654 -11
- mdb_engine/auth/__init__.py +136 -29
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/base.py +252 -0
- mdb_engine/auth/casbin_factory.py +265 -70
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +19 -18
- mdb_engine/auth/cookie_utils.py +12 -16
- mdb_engine/auth/csrf.py +483 -0
- mdb_engine/auth/decorators.py +10 -16
- mdb_engine/auth/dependencies.py +69 -71
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +61 -88
- mdb_engine/auth/jwt.py +11 -15
- mdb_engine/auth/middleware.py +79 -35
- mdb_engine/auth/oso_factory.py +21 -41
- mdb_engine/auth/provider.py +270 -171
- mdb_engine/auth/rate_limiter.py +505 -0
- mdb_engine/auth/restrictions.py +21 -36
- mdb_engine/auth/session_manager.py +24 -41
- mdb_engine/auth/shared_middleware.py +977 -0
- mdb_engine/auth/shared_users.py +775 -0
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +17 -32
- mdb_engine/auth/users.py +99 -159
- mdb_engine/auth/utils.py +236 -42
- mdb_engine/cli/commands/generate.py +546 -10
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +7 -7
- mdb_engine/config.py +13 -28
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +31 -50
- mdb_engine/core/app_secrets.py +289 -0
- mdb_engine/core/connection.py +20 -12
- mdb_engine/core/encryption.py +222 -0
- mdb_engine/core/engine.py +2862 -115
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +628 -204
- mdb_engine/core/ray_integration.py +436 -0
- mdb_engine/core/seeding.py +13 -21
- mdb_engine/core/service_initialization.py +20 -30
- mdb_engine/core/types.py +40 -43
- mdb_engine/database/README.md +140 -17
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +37 -50
- mdb_engine/database/connection.py +51 -30
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +747 -237
- mdb_engine/dependencies.py +427 -0
- mdb_engine/di/__init__.py +34 -0
- mdb_engine/di/container.py +247 -0
- mdb_engine/di/providers.py +206 -0
- mdb_engine/di/scopes.py +139 -0
- mdb_engine/embeddings/README.md +54 -24
- mdb_engine/embeddings/__init__.py +31 -24
- mdb_engine/embeddings/dependencies.py +38 -155
- mdb_engine/embeddings/service.py +78 -75
- mdb_engine/exceptions.py +104 -12
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +59 -123
- mdb_engine/memory/README.md +95 -4
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +363 -1168
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +17 -17
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +40 -19
- mdb_engine/repositories/__init__.py +34 -0
- mdb_engine/repositories/base.py +325 -0
- mdb_engine/repositories/mongo.py +233 -0
- mdb_engine/repositories/unit_of_work.py +166 -0
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +41 -75
- mdb_engine/utils/__init__.py +3 -1
- mdb_engine/utils/mongo.py +117 -0
- mdb_engine-0.4.12.dist-info/METADATA +492 -0
- mdb_engine-0.4.12.dist-info/RECORD +97 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/WHEEL +1 -1
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/top_level.txt +0 -0
|
@@ -23,16 +23,20 @@ Usage:
|
|
|
23
23
|
import logging
|
|
24
24
|
import os
|
|
25
25
|
import threading
|
|
26
|
-
from typing import Any
|
|
26
|
+
from typing import Any
|
|
27
27
|
|
|
28
28
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
29
|
-
from pymongo.errors import (
|
|
30
|
-
|
|
29
|
+
from pymongo.errors import (
|
|
30
|
+
ConnectionFailure,
|
|
31
|
+
InvalidOperation,
|
|
32
|
+
OperationFailure,
|
|
33
|
+
ServerSelectionTimeoutError,
|
|
34
|
+
)
|
|
31
35
|
|
|
32
36
|
logger = logging.getLogger(__name__)
|
|
33
37
|
|
|
34
38
|
# Global singleton instance
|
|
35
|
-
_shared_client:
|
|
39
|
+
_shared_client: AsyncIOMotorClient | None = None
|
|
36
40
|
# Use threading.Lock for cross-thread safety in multi-threaded environments
|
|
37
41
|
# asyncio.Lock isn't sufficient for thread-safe initialization
|
|
38
42
|
_init_lock = threading.Lock()
|
|
@@ -40,8 +44,8 @@ _init_lock = threading.Lock()
|
|
|
40
44
|
|
|
41
45
|
def get_shared_mongo_client(
|
|
42
46
|
mongo_uri: str,
|
|
43
|
-
max_pool_size:
|
|
44
|
-
min_pool_size:
|
|
47
|
+
max_pool_size: int | None = None,
|
|
48
|
+
min_pool_size: int | None = None,
|
|
45
49
|
server_selection_timeout_ms: int = 5000,
|
|
46
50
|
max_idle_time_ms: int = 45000,
|
|
47
51
|
retry_writes: bool = True,
|
|
@@ -87,10 +91,7 @@ def get_shared_mongo_client(
|
|
|
87
91
|
# Verify client is still connected
|
|
88
92
|
try:
|
|
89
93
|
# Non-blocking check - if client was closed, it will be None or invalid
|
|
90
|
-
if (
|
|
91
|
-
hasattr(_shared_client, "_topology")
|
|
92
|
-
and _shared_client._topology is not None
|
|
93
|
-
):
|
|
94
|
+
if hasattr(_shared_client, "_topology") and _shared_client._topology is not None:
|
|
94
95
|
return _shared_client
|
|
95
96
|
except (AttributeError, RuntimeError):
|
|
96
97
|
# Client was closed or invalid, reset and recreate
|
|
@@ -103,10 +104,7 @@ def get_shared_mongo_client(
|
|
|
103
104
|
# Double-check pattern: another thread may have initialized while we waited
|
|
104
105
|
if _shared_client is not None:
|
|
105
106
|
try:
|
|
106
|
-
if (
|
|
107
|
-
hasattr(_shared_client, "_topology")
|
|
108
|
-
and _shared_client._topology is not None
|
|
109
|
-
):
|
|
107
|
+
if hasattr(_shared_client, "_topology") and _shared_client._topology is not None:
|
|
110
108
|
return _shared_client
|
|
111
109
|
except (AttributeError, RuntimeError):
|
|
112
110
|
# Client was closed or invalid, reset and recreate
|
|
@@ -180,7 +178,7 @@ async def verify_shared_client() -> bool:
|
|
|
180
178
|
OperationFailure,
|
|
181
179
|
InvalidOperation,
|
|
182
180
|
) as e:
|
|
183
|
-
logger.
|
|
181
|
+
logger.exception(f"Shared MongoDB client verification failed: {e}")
|
|
184
182
|
return False
|
|
185
183
|
|
|
186
184
|
|
|
@@ -205,8 +203,8 @@ def register_client_for_metrics(client: AsyncIOMotorClient) -> None:
|
|
|
205
203
|
|
|
206
204
|
|
|
207
205
|
async def get_pool_metrics(
|
|
208
|
-
client:
|
|
209
|
-
) ->
|
|
206
|
+
client: AsyncIOMotorClient | None = None,
|
|
207
|
+
) -> dict[str, Any]:
|
|
210
208
|
"""
|
|
211
209
|
Gets connection pool metrics for monitoring.
|
|
212
210
|
Returns information about pool size, active connections, etc.
|
|
@@ -236,10 +234,7 @@ async def get_pool_metrics(
|
|
|
236
234
|
for registered_client in _registered_clients:
|
|
237
235
|
try:
|
|
238
236
|
# Verify client is still valid
|
|
239
|
-
if (
|
|
240
|
-
hasattr(registered_client, "_topology")
|
|
241
|
-
and registered_client._topology is not None
|
|
242
|
-
):
|
|
237
|
+
if hasattr(registered_client, "_topology") and registered_client._topology is not None:
|
|
243
238
|
return await _get_client_pool_metrics(registered_client)
|
|
244
239
|
except (AttributeError, RuntimeError):
|
|
245
240
|
# Type 2: Recoverable - if this client is invalid, try next one
|
|
@@ -252,7 +247,7 @@ async def get_pool_metrics(
|
|
|
252
247
|
}
|
|
253
248
|
|
|
254
249
|
|
|
255
|
-
async def _get_client_pool_metrics(client: AsyncIOMotorClient) ->
|
|
250
|
+
async def _get_client_pool_metrics(client: AsyncIOMotorClient) -> dict[str, Any]:
|
|
256
251
|
"""
|
|
257
252
|
Internal helper to get pool metrics from a specific client.
|
|
258
253
|
|
|
@@ -304,10 +299,32 @@ async def _get_client_pool_metrics(client: AsyncIOMotorClient) -> Dict[str, Any]
|
|
|
304
299
|
|
|
305
300
|
try:
|
|
306
301
|
server_status = await client.admin.command("serverStatus")
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
302
|
+
if not isinstance(server_status, dict):
|
|
303
|
+
# Mock or invalid response - skip connection metrics
|
|
304
|
+
current_connections = None
|
|
305
|
+
available_connections = None
|
|
306
|
+
total_created = None
|
|
307
|
+
else:
|
|
308
|
+
connections = server_status.get("connections", {})
|
|
309
|
+
if not isinstance(connections, dict):
|
|
310
|
+
# Mock or invalid response - skip connection metrics
|
|
311
|
+
current_connections = None
|
|
312
|
+
available_connections = None
|
|
313
|
+
total_created = None
|
|
314
|
+
else:
|
|
315
|
+
# Get values, ensuring they're numeric (not MagicMocks)
|
|
316
|
+
current_raw = connections.get("current", 0)
|
|
317
|
+
available_raw = connections.get("available", 0)
|
|
318
|
+
total_raw = connections.get("totalCreated", 0)
|
|
319
|
+
|
|
320
|
+
# Only use if actually numeric
|
|
321
|
+
current_connections = (
|
|
322
|
+
int(current_raw) if isinstance(current_raw, int | float) else None
|
|
323
|
+
)
|
|
324
|
+
available_connections = (
|
|
325
|
+
int(available_raw) if isinstance(available_raw, int | float) else None
|
|
326
|
+
)
|
|
327
|
+
total_created = int(total_raw) if isinstance(total_raw, int | float) else None
|
|
311
328
|
except (
|
|
312
329
|
OperationFailure,
|
|
313
330
|
ConnectionFailure,
|
|
@@ -335,12 +352,16 @@ async def _get_client_pool_metrics(client: AsyncIOMotorClient) -> Dict[str, Any]
|
|
|
335
352
|
metrics["total_connections_created"] = total_created
|
|
336
353
|
|
|
337
354
|
# Calculate pool usage if we have max_pool_size and current connections
|
|
338
|
-
|
|
355
|
+
# Ensure both are numeric (not MagicMock or other types)
|
|
356
|
+
if (
|
|
357
|
+
max_pool_size
|
|
358
|
+
and current_connections is not None
|
|
359
|
+
and isinstance(max_pool_size, int | float)
|
|
360
|
+
and isinstance(current_connections, int | float)
|
|
361
|
+
):
|
|
339
362
|
usage_percent = (current_connections / max_pool_size) * 100
|
|
340
363
|
metrics["pool_usage_percent"] = round(usage_percent, 2)
|
|
341
|
-
metrics["active_connections"] =
|
|
342
|
-
current_connections # Alias for compatibility
|
|
343
|
-
)
|
|
364
|
+
metrics["active_connections"] = current_connections # Alias for compatibility
|
|
344
365
|
|
|
345
366
|
# Warn if pool usage is high
|
|
346
367
|
if usage_percent > 80:
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Query validation for MongoDB Engine.
|
|
3
|
+
|
|
4
|
+
This module provides comprehensive query validation to prevent NoSQL injection,
|
|
5
|
+
block dangerous operators, and enforce query complexity limits.
|
|
6
|
+
|
|
7
|
+
Security Features:
|
|
8
|
+
- Blocks dangerous MongoDB operators ($where, $eval, $function, $accumulator)
|
|
9
|
+
- Prevents deeply nested queries
|
|
10
|
+
- Limits regex complexity to prevent ReDoS attacks
|
|
11
|
+
- Validates aggregation pipelines
|
|
12
|
+
- Prevents NoSQL injection patterns
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import re
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from ..constants import (
|
|
20
|
+
DANGEROUS_OPERATORS,
|
|
21
|
+
MAX_PIPELINE_STAGES,
|
|
22
|
+
MAX_QUERY_DEPTH,
|
|
23
|
+
MAX_REGEX_COMPLEXITY,
|
|
24
|
+
MAX_REGEX_LENGTH,
|
|
25
|
+
MAX_SORT_FIELDS,
|
|
26
|
+
)
|
|
27
|
+
from ..exceptions import QueryValidationError
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QueryValidator:
    """
    Validates MongoDB queries for security and safety.

    This class provides comprehensive validation to prevent:
    - NoSQL injection attacks
    - Dangerous operator usage (e.g. ``$where``, ``$function``)
    - Resource exhaustion via overly deep or long queries
    - ReDoS attacks via complex regex patterns
    """

    def __init__(
        self,
        max_depth: int = MAX_QUERY_DEPTH,
        max_pipeline_stages: int = MAX_PIPELINE_STAGES,
        max_regex_length: int = MAX_REGEX_LENGTH,
        max_regex_complexity: int = MAX_REGEX_COMPLEXITY,
        dangerous_operators: set[str] | None = None,
    ):
        """
        Initialize the query validator.

        Args:
            max_depth: Maximum nesting depth for queries
            max_pipeline_stages: Maximum stages in aggregation pipelines
            max_regex_length: Maximum length for regex patterns
            max_regex_complexity: Maximum complexity score for regex patterns
            dangerous_operators: Extra operators to block, merged with the
                DANGEROUS_OPERATORS defaults
        """
        self.max_depth = max_depth
        self.max_pipeline_stages = max_pipeline_stages
        self.max_regex_length = max_regex_length
        self.max_regex_complexity = max_regex_complexity
        # set() accepts any iterable (tuple, set, ...) and always returns a
        # *fresh* set, so we never alias the module-level DANGEROUS_OPERATORS
        # constant — mutating self.dangerous_operators must not leak globally.
        default_ops = set(DANGEROUS_OPERATORS)
        if dangerous_operators is not None:
            self.dangerous_operators = default_ops | set(dangerous_operators)
        else:
            self.dangerous_operators = default_ops

    def validate_filter(self, filter: dict[str, Any] | None, path: str = "") -> None:
        """
        Validate a MongoDB query filter.

        Args:
            filter: The query filter to validate
            path: JSON path for error reporting (used recursively)

        Raises:
            QueryValidationError: If the filter contains dangerous operators or exceeds limits
        """
        if not filter:
            return

        if not isinstance(filter, dict):
            raise QueryValidationError(
                f"Query filter must be a dictionary, got {type(filter).__name__}",
                query_type="filter",
                path=path,
            )

        # _check_dangerous_operators enforces the max nesting depth as it
        # walks the filter, so a second _check_query_depth pass over the same
        # structure would be pure duplicate work and can never raise first.
        self._check_dangerous_operators(filter, path)

    def validate_pipeline(self, pipeline: list[dict[str, Any]]) -> None:
        """
        Validate an aggregation pipeline.

        Args:
            pipeline: The aggregation pipeline to validate

        Raises:
            QueryValidationError: If the pipeline exceeds limits or contains dangerous operators
        """
        if not pipeline:
            return

        if not isinstance(pipeline, list):
            raise QueryValidationError(
                f"Aggregation pipeline must be a list, got {type(pipeline).__name__}",
                query_type="pipeline",
            )

        # Check pipeline length
        if len(pipeline) > self.max_pipeline_stages:
            raise QueryValidationError(
                f"Aggregation pipeline exceeds maximum stages: "
                f"{len(pipeline)} > {self.max_pipeline_stages}",
                query_type="pipeline",
                context={
                    "stages": len(pipeline),
                    "max_stages": self.max_pipeline_stages,
                },
            )

        # Validate each stage
        for idx, stage in enumerate(pipeline):
            if not isinstance(stage, dict):
                raise QueryValidationError(
                    f"Pipeline stage {idx} must be a dictionary, got {type(stage).__name__}",
                    query_type="pipeline",
                    path=f"$[{idx}]",
                )

            # Check for dangerous operators in each stage; this walk also
            # enforces the nesting-depth limit, so no extra depth pass needed.
            stage_path = f"$[{idx}]"
            self._check_dangerous_operators(stage, stage_path)

    def validate_regex(self, pattern: str, path: str = "") -> None:
        """
        Validate a regex pattern to prevent ReDoS attacks.

        Args:
            pattern: The regex pattern to validate
            path: JSON path for error reporting

        Raises:
            QueryValidationError: If the regex pattern is too complex or long
        """
        if not isinstance(pattern, str):
            return  # Not a regex pattern

        # Check length
        if len(pattern) > self.max_regex_length:
            raise QueryValidationError(
                f"Regex pattern exceeds maximum length: "
                f"{len(pattern)} > {self.max_regex_length}",
                query_type="regex",
                path=path,
                context={
                    "length": len(pattern),
                    "max_length": self.max_regex_length,
                },
            )

        # Check complexity (simple heuristic: count quantifiers and alternations)
        complexity = self._calculate_regex_complexity(pattern)
        if complexity > self.max_regex_complexity:
            raise QueryValidationError(
                f"Regex pattern exceeds maximum complexity: "
                f"{complexity} > {self.max_regex_complexity}",
                query_type="regex",
                path=path,
                context={
                    "complexity": complexity,
                    "max_complexity": self.max_regex_complexity,
                },
            )

        # Try to compile the regex to catch syntax errors early
        try:
            re.compile(pattern)
        except re.error as e:
            raise QueryValidationError(
                f"Invalid regex pattern: {e}",
                query_type="regex",
                path=path,
            ) from e

    def validate_sort(self, sort: Any | None) -> None:
        """
        Validate a sort specification.

        Args:
            sort: The sort specification to validate

        Raises:
            QueryValidationError: If the sort specification exceeds limits
        """
        if not sort:
            return

        # Count sort fields
        sort_fields = self._extract_sort_fields(sort)
        if len(sort_fields) > MAX_SORT_FIELDS:
            raise QueryValidationError(
                f"Sort specification exceeds maximum fields: "
                f"{len(sort_fields)} > {MAX_SORT_FIELDS}",
                query_type="sort",
                context={
                    "fields": len(sort_fields),
                    "max_fields": MAX_SORT_FIELDS,
                },
            )

    def _check_dangerous_operators(
        self, query: dict[str, Any], path: str = "", depth: int = 0
    ) -> None:
        """
        Recursively check for dangerous operators in a query.

        Also enforces the maximum nesting depth while walking, so callers do
        not need a separate depth pass over the same structure.

        Args:
            query: The query dictionary to check
            path: Current JSON path for error reporting
            depth: Current nesting depth

        Raises:
            QueryValidationError: If a dangerous operator is found or depth exceeded
        """
        if depth > self.max_depth:
            raise QueryValidationError(
                f"Query exceeds maximum nesting depth: {depth} > {self.max_depth}",
                query_type="filter",
                path=path,
                context={"depth": depth, "max_depth": self.max_depth},
            )

        for key, value in query.items():
            current_path = f"{path}.{key}" if path else key

            # Check if key is a dangerous operator
            if key in self.dangerous_operators:
                # Lazy %-args so the message is only formatted when emitted
                logger.warning(
                    "Security: Dangerous operator '%s' detected in query at path '%s'",
                    key,
                    current_path,
                )
                raise QueryValidationError(
                    f"Dangerous operator '{key}' is not allowed for security reasons. "
                    f"Found at path: {current_path}",
                    query_type="filter",
                    operator=key,
                    path=current_path,
                )

            # Recursively check nested dictionaries
            if isinstance(value, dict):
                # Check for $regex operator and validate pattern
                if "$regex" in value:
                    regex_pattern = value["$regex"]
                    if isinstance(regex_pattern, str):
                        self.validate_regex(regex_pattern, f"{current_path}.$regex")
                self._check_dangerous_operators(value, current_path, depth + 1)
            elif isinstance(value, list):
                # Check list elements
                for idx, item in enumerate(value):
                    if isinstance(item, dict):
                        item_path = f"{current_path}[{idx}]"
                        # Check for $regex in list items
                        if "$regex" in item and isinstance(item["$regex"], str):
                            self.validate_regex(item["$regex"], f"{item_path}.$regex")
                        self._check_dangerous_operators(item, item_path, depth + 1)
            elif isinstance(value, str) and key == "$regex":
                # Direct $regex value (less common but possible)
                self.validate_regex(value, current_path)

    def _check_query_depth(self, query: dict[str, Any], path: str = "", depth: int = 0) -> None:
        """
        Check query nesting depth.

        Kept as a standalone helper for callers that only need depth
        enforcement without operator screening.

        Args:
            query: The query dictionary to check
            path: Current JSON path for error reporting
            depth: Current nesting depth

        Raises:
            QueryValidationError: If query depth exceeds maximum
        """
        if depth > self.max_depth:
            raise QueryValidationError(
                f"Query exceeds maximum nesting depth: {depth} > {self.max_depth}",
                query_type="filter",
                path=path,
                context={"depth": depth, "max_depth": self.max_depth},
            )

        # Recursively check nested dictionaries
        for key, value in query.items():
            current_path = f"{path}.{key}" if path else key

            if isinstance(value, dict):
                self._check_query_depth(value, current_path, depth + 1)
            elif isinstance(value, list):
                for idx, item in enumerate(value):
                    if isinstance(item, dict):
                        item_path = f"{current_path}[{idx}]"
                        self._check_query_depth(item, item_path, depth + 1)

    def _calculate_regex_complexity(self, pattern: str) -> int:
        """
        Calculate a complexity score for a regex pattern.

        This is a simple heuristic to detect potentially dangerous regex patterns
        that could cause ReDoS attacks.

        Args:
            pattern: The regex pattern

        Returns:
            Complexity score (higher = more complex)
        """
        complexity = 0

        # Count quantifiers (can cause backtracking)
        complexity += len(re.findall(r"[*+?{]", pattern))

        # Count alternations (can cause exponential growth)
        complexity += len(re.findall(r"\|", pattern))

        # Count nested groups (can cause deep backtracking)
        complexity += len(re.findall(r"\([^)]*\([^)]*\)", pattern))

        # Count lookahead/lookbehind (can be expensive)
        complexity += len(re.findall(r"\(\?[=!<>]", pattern))

        return complexity

    def _extract_sort_fields(self, sort: Any) -> list[str]:
        """
        Extract field names from a sort specification.

        Args:
            sort: Sort specification (list of tuples, dict, or single tuple)

        Returns:
            List of field names
        """
        if isinstance(sort, list):
            return [field for field, _ in sort if isinstance(field, str)]
        elif isinstance(sort, dict):
            return list(sort.keys())
        elif isinstance(sort, tuple) and len(sort) == 2:
            return [sort[0]] if isinstance(sort[0], str) else []
        return []