mdb-engine 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +1 -1
- mdb_engine/auth/audit.py +40 -40
- mdb_engine/auth/base.py +3 -3
- mdb_engine/auth/casbin_factory.py +6 -6
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +12 -12
- mdb_engine/auth/cookie_utils.py +9 -9
- mdb_engine/auth/csrf.py +9 -8
- mdb_engine/auth/decorators.py +7 -6
- mdb_engine/auth/dependencies.py +22 -21
- mdb_engine/auth/integration.py +9 -9
- mdb_engine/auth/jwt.py +9 -9
- mdb_engine/auth/middleware.py +4 -3
- mdb_engine/auth/oso_factory.py +6 -6
- mdb_engine/auth/provider.py +4 -4
- mdb_engine/auth/rate_limiter.py +12 -11
- mdb_engine/auth/restrictions.py +16 -15
- mdb_engine/auth/session_manager.py +11 -13
- mdb_engine/auth/shared_middleware.py +16 -15
- mdb_engine/auth/shared_users.py +20 -20
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +4 -5
- mdb_engine/auth/users.py +51 -52
- mdb_engine/auth/utils.py +29 -33
- mdb_engine/cli/commands/generate.py +6 -6
- mdb_engine/cli/utils.py +4 -4
- mdb_engine/config.py +6 -7
- mdb_engine/core/app_registration.py +12 -12
- mdb_engine/core/app_secrets.py +1 -2
- mdb_engine/core/connection.py +3 -4
- mdb_engine/core/encryption.py +1 -2
- mdb_engine/core/engine.py +43 -44
- mdb_engine/core/manifest.py +59 -58
- mdb_engine/core/ray_integration.py +10 -9
- mdb_engine/core/seeding.py +3 -3
- mdb_engine/core/service_initialization.py +10 -9
- mdb_engine/core/types.py +40 -40
- mdb_engine/database/abstraction.py +15 -16
- mdb_engine/database/connection.py +40 -12
- mdb_engine/database/query_validator.py +8 -8
- mdb_engine/database/resource_limiter.py +7 -7
- mdb_engine/database/scoped_wrapper.py +51 -58
- mdb_engine/dependencies.py +14 -13
- mdb_engine/di/container.py +12 -13
- mdb_engine/di/providers.py +14 -13
- mdb_engine/di/scopes.py +5 -5
- mdb_engine/embeddings/dependencies.py +2 -2
- mdb_engine/embeddings/service.py +31 -43
- mdb_engine/exceptions.py +20 -20
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +9 -9
- mdb_engine/memory/service.py +30 -30
- mdb_engine/observability/health.py +10 -9
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +8 -7
- mdb_engine/repositories/base.py +25 -25
- mdb_engine/repositories/mongo.py +17 -17
- mdb_engine/repositories/unit_of_work.py +6 -6
- mdb_engine/routing/websockets.py +19 -18
- {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/METADATA +8 -8
- mdb_engine-0.2.3.dist-info/RECORD +96 -0
- mdb_engine-0.2.0.dist-info/RECORD +0 -96
- {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/WHEEL +0 -0
- {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/top_level.txt +0 -0
|
@@ -28,17 +28,12 @@ import asyncio
|
|
|
28
28
|
import logging
|
|
29
29
|
import re
|
|
30
30
|
import time
|
|
31
|
+
from collections.abc import Coroutine, Mapping
|
|
31
32
|
from typing import (
|
|
32
33
|
TYPE_CHECKING,
|
|
33
34
|
Any,
|
|
34
35
|
ClassVar,
|
|
35
|
-
Coroutine,
|
|
36
|
-
Dict,
|
|
37
|
-
List,
|
|
38
|
-
Mapping,
|
|
39
36
|
Optional,
|
|
40
|
-
Tuple,
|
|
41
|
-
Union,
|
|
42
37
|
)
|
|
43
38
|
|
|
44
39
|
if TYPE_CHECKING:
|
|
@@ -98,7 +93,7 @@ GEO2DSPHERE = "2dsphere"
|
|
|
98
93
|
|
|
99
94
|
|
|
100
95
|
# --- HELPER FUNCTION FOR MANAGED TASK CREATION ---
|
|
101
|
-
def _create_managed_task(coro: Coroutine[Any, Any, Any], task_name:
|
|
96
|
+
def _create_managed_task(coro: Coroutine[Any, Any, Any], task_name: str | None = None) -> None:
|
|
102
97
|
"""
|
|
103
98
|
Creates a background task using asyncio.create_task().
|
|
104
99
|
|
|
@@ -205,7 +200,7 @@ def _validate_collection_name(name: str, allow_prefixed: bool = False) -> None:
|
|
|
205
200
|
)
|
|
206
201
|
|
|
207
202
|
|
|
208
|
-
def _extract_app_slug_from_prefixed_name(prefixed_name: str) ->
|
|
203
|
+
def _extract_app_slug_from_prefixed_name(prefixed_name: str) -> str | None:
|
|
209
204
|
"""
|
|
210
205
|
Extract app slug from a prefixed collection name.
|
|
211
206
|
|
|
@@ -335,11 +330,11 @@ class AsyncAtlasIndexManager:
|
|
|
335
330
|
|
|
336
331
|
def _check_definition_changed(
|
|
337
332
|
self,
|
|
338
|
-
definition:
|
|
339
|
-
latest_def:
|
|
333
|
+
definition: dict[str, Any],
|
|
334
|
+
latest_def: dict[str, Any],
|
|
340
335
|
index_type: str,
|
|
341
336
|
name: str,
|
|
342
|
-
) ->
|
|
337
|
+
) -> tuple[bool, str]:
|
|
343
338
|
"""Check if index definition has changed."""
|
|
344
339
|
definition_changed = False
|
|
345
340
|
change_reason = ""
|
|
@@ -362,8 +357,8 @@ class AsyncAtlasIndexManager:
|
|
|
362
357
|
|
|
363
358
|
async def _handle_existing_index(
|
|
364
359
|
self,
|
|
365
|
-
existing_index:
|
|
366
|
-
definition:
|
|
360
|
+
existing_index: dict[str, Any],
|
|
361
|
+
definition: dict[str, Any],
|
|
367
362
|
index_type: str,
|
|
368
363
|
name: str,
|
|
369
364
|
) -> bool:
|
|
@@ -402,7 +397,7 @@ class AsyncAtlasIndexManager:
|
|
|
402
397
|
return False # Will wait below
|
|
403
398
|
|
|
404
399
|
async def _create_new_search_index(
|
|
405
|
-
self, name: str, definition:
|
|
400
|
+
self, name: str, definition: dict[str, Any], index_type: str
|
|
406
401
|
) -> None:
|
|
407
402
|
"""Create a new search index."""
|
|
408
403
|
try:
|
|
@@ -422,7 +417,7 @@ class AsyncAtlasIndexManager:
|
|
|
422
417
|
async def create_search_index(
|
|
423
418
|
self,
|
|
424
419
|
name: str,
|
|
425
|
-
definition:
|
|
420
|
+
definition: dict[str, Any],
|
|
426
421
|
index_type: str = "search",
|
|
427
422
|
wait_for_ready: bool = True,
|
|
428
423
|
timeout: int = DEFAULT_SEARCH_TIMEOUT,
|
|
@@ -472,7 +467,7 @@ class AsyncAtlasIndexManager:
|
|
|
472
467
|
context={"index_name": name, "operation": "create_search_index"},
|
|
473
468
|
) from e
|
|
474
469
|
|
|
475
|
-
async def get_search_index(self, name: str) ->
|
|
470
|
+
async def get_search_index(self, name: str) -> dict[str, Any] | None:
|
|
476
471
|
"""
|
|
477
472
|
Retrieves the definition and status of a single search index by name
|
|
478
473
|
using the $listSearchIndexes aggregation stage.
|
|
@@ -496,7 +491,7 @@ class AsyncAtlasIndexManager:
|
|
|
496
491
|
context={"index_name": name, "operation": "get_search_index"},
|
|
497
492
|
) from e
|
|
498
493
|
|
|
499
|
-
async def list_search_indexes(self) ->
|
|
494
|
+
async def list_search_indexes(self) -> list[dict[str, Any]]:
|
|
500
495
|
"""Lists all Atlas Search indexes for the collection."""
|
|
501
496
|
try:
|
|
502
497
|
return await self._collection.list_search_indexes().to_list(None)
|
|
@@ -552,7 +547,7 @@ class AsyncAtlasIndexManager:
|
|
|
552
547
|
async def update_search_index(
|
|
553
548
|
self,
|
|
554
549
|
name: str,
|
|
555
|
-
definition:
|
|
550
|
+
definition: dict[str, Any],
|
|
556
551
|
wait_for_ready: bool = True,
|
|
557
552
|
timeout: int = DEFAULT_SEARCH_TIMEOUT,
|
|
558
553
|
) -> bool:
|
|
@@ -675,7 +670,7 @@ class AsyncAtlasIndexManager:
|
|
|
675
670
|
# consistent async API with the search index methods.
|
|
676
671
|
|
|
677
672
|
async def create_index( # noqa: C901
|
|
678
|
-
self, keys:
|
|
673
|
+
self, keys: str | list[tuple[str, int | str]], **kwargs: Any
|
|
679
674
|
) -> str:
|
|
680
675
|
"""
|
|
681
676
|
Creates a standard (non-search) database index.
|
|
@@ -784,8 +779,8 @@ class AsyncAtlasIndexManager:
|
|
|
784
779
|
|
|
785
780
|
async def create_text_index(
|
|
786
781
|
self,
|
|
787
|
-
fields:
|
|
788
|
-
weights:
|
|
782
|
+
fields: list[str],
|
|
783
|
+
weights: dict[str, int] | None = None,
|
|
789
784
|
name: str = "text_index",
|
|
790
785
|
**kwargs: Any,
|
|
791
786
|
) -> str:
|
|
@@ -797,7 +792,7 @@ class AsyncAtlasIndexManager:
|
|
|
797
792
|
kwargs["name"] = name
|
|
798
793
|
return await self.create_index(keys, **kwargs)
|
|
799
794
|
|
|
800
|
-
async def create_geo_index(self, field: str, name:
|
|
795
|
+
async def create_geo_index(self, field: str, name: str | None = None, **kwargs: Any) -> str:
|
|
801
796
|
"""Helper to create a standard 2dsphere index."""
|
|
802
797
|
keys = [(field, GEO2DSPHERE)]
|
|
803
798
|
if name:
|
|
@@ -832,7 +827,7 @@ class AsyncAtlasIndexManager:
|
|
|
832
827
|
context={"index_name": name, "operation": "drop_index"},
|
|
833
828
|
) from e
|
|
834
829
|
|
|
835
|
-
async def list_indexes(self) ->
|
|
830
|
+
async def list_indexes(self) -> list[dict[str, Any]]:
|
|
836
831
|
"""Lists all standard (non-search) indexes on the collection."""
|
|
837
832
|
try:
|
|
838
833
|
return await self._collection.list_indexes().to_list(None)
|
|
@@ -844,7 +839,7 @@ class AsyncAtlasIndexManager:
|
|
|
844
839
|
logger.debug("Skipping list_indexes: MongoDB client is closed (likely during shutdown)")
|
|
845
840
|
return []
|
|
846
841
|
|
|
847
|
-
async def get_index(self, name: str) ->
|
|
842
|
+
async def get_index(self, name: str) -> dict[str, Any] | None:
|
|
848
843
|
"""Gets a single standard index by name."""
|
|
849
844
|
indexes = await self.list_indexes()
|
|
850
845
|
return next((index for index in indexes if index.get("name") == name), None)
|
|
@@ -919,17 +914,17 @@ class AutoIndexManager:
|
|
|
919
914
|
self._collection = collection
|
|
920
915
|
self._index_manager = index_manager
|
|
921
916
|
# Cache of index creation decisions (index_name -> bool)
|
|
922
|
-
self._creation_cache:
|
|
917
|
+
self._creation_cache: dict[str, bool] = {}
|
|
923
918
|
# Async lock to prevent race conditions during index creation
|
|
924
919
|
self._lock = asyncio.Lock()
|
|
925
920
|
# Track query patterns to determine which indexes to create
|
|
926
|
-
self._query_counts:
|
|
921
|
+
self._query_counts: dict[str, int] = {}
|
|
927
922
|
# Track in-flight index creation tasks to prevent duplicates
|
|
928
|
-
self._pending_tasks:
|
|
923
|
+
self._pending_tasks: dict[str, asyncio.Task] = {}
|
|
929
924
|
|
|
930
925
|
def _extract_index_fields_from_filter(
|
|
931
|
-
self, filter:
|
|
932
|
-
) ->
|
|
926
|
+
self, filter: Mapping[str, Any] | None
|
|
927
|
+
) -> list[tuple[str, int]]:
|
|
933
928
|
"""
|
|
934
929
|
Extracts potential index fields from a MongoDB query filter.
|
|
935
930
|
|
|
@@ -944,7 +939,7 @@ class AutoIndexManager:
|
|
|
944
939
|
if not filter:
|
|
945
940
|
return []
|
|
946
941
|
|
|
947
|
-
index_fields:
|
|
942
|
+
index_fields: list[tuple[str, int]] = []
|
|
948
943
|
|
|
949
944
|
def analyze_value(value: Any, field_name: str) -> None:
|
|
950
945
|
"""Recursively analyze filter values to extract index candidates."""
|
|
@@ -976,8 +971,8 @@ class AutoIndexManager:
|
|
|
976
971
|
return list(set(index_fields)) # Remove duplicates
|
|
977
972
|
|
|
978
973
|
def _extract_sort_fields(
|
|
979
|
-
self, sort:
|
|
980
|
-
) ->
|
|
974
|
+
self, sort: list[tuple[str, int]] | dict[str, int] | None
|
|
975
|
+
) -> list[tuple[str, int]]:
|
|
981
976
|
"""
|
|
982
977
|
Extracts index fields from sort specification.
|
|
983
978
|
|
|
@@ -993,7 +988,7 @@ class AutoIndexManager:
|
|
|
993
988
|
else:
|
|
994
989
|
return []
|
|
995
990
|
|
|
996
|
-
def _generate_index_name(self, fields:
|
|
991
|
+
def _generate_index_name(self, fields: list[tuple[str, int]]) -> str:
|
|
997
992
|
"""Generate a human-readable index name from field list."""
|
|
998
993
|
if not fields:
|
|
999
994
|
return "auto_idx_empty"
|
|
@@ -1006,7 +1001,7 @@ class AutoIndexManager:
|
|
|
1006
1001
|
return f"auto_{'_'.join(parts)}"
|
|
1007
1002
|
|
|
1008
1003
|
async def _create_index_safely(
|
|
1009
|
-
self, index_name: str, all_fields:
|
|
1004
|
+
self, index_name: str, all_fields: list[tuple[str, int]]
|
|
1010
1005
|
) -> None:
|
|
1011
1006
|
"""
|
|
1012
1007
|
Safely create an index, handling errors gracefully.
|
|
@@ -1052,8 +1047,8 @@ class AutoIndexManager:
|
|
|
1052
1047
|
|
|
1053
1048
|
async def ensure_index_for_query(
|
|
1054
1049
|
self,
|
|
1055
|
-
filter:
|
|
1056
|
-
sort:
|
|
1050
|
+
filter: Mapping[str, Any] | None = None,
|
|
1051
|
+
sort: list[tuple[str, int]] | dict[str, int] | None = None,
|
|
1057
1052
|
hint_threshold: int = AUTO_INDEX_HINT_THRESHOLD,
|
|
1058
1053
|
) -> None:
|
|
1059
1054
|
"""
|
|
@@ -1170,11 +1165,11 @@ class ScopedCollectionWrapper:
|
|
|
1170
1165
|
def __init__(
|
|
1171
1166
|
self,
|
|
1172
1167
|
real_collection: AsyncIOMotorCollection,
|
|
1173
|
-
read_scopes:
|
|
1168
|
+
read_scopes: list[str],
|
|
1174
1169
|
write_scope: str,
|
|
1175
1170
|
auto_index: bool = True,
|
|
1176
|
-
query_validator:
|
|
1177
|
-
resource_limiter:
|
|
1171
|
+
query_validator: QueryValidator | None = None,
|
|
1172
|
+
resource_limiter: ResourceLimiter | None = None,
|
|
1178
1173
|
parent_wrapper: Optional["ScopedMongoWrapper"] = None,
|
|
1179
1174
|
):
|
|
1180
1175
|
self._collection = real_collection
|
|
@@ -1182,8 +1177,8 @@ class ScopedCollectionWrapper:
|
|
|
1182
1177
|
self._write_scope = write_scope
|
|
1183
1178
|
self._auto_index_enabled = auto_index
|
|
1184
1179
|
# Lazily instantiated and cached
|
|
1185
|
-
self._index_manager:
|
|
1186
|
-
self._auto_index_manager:
|
|
1180
|
+
self._index_manager: AsyncAtlasIndexManager | None = None
|
|
1181
|
+
self._auto_index_manager: AutoIndexManager | None = None
|
|
1187
1182
|
# Query security and resource limits
|
|
1188
1183
|
self._query_validator = query_validator or QueryValidator()
|
|
1189
1184
|
self._resource_limiter = resource_limiter or ResourceLimiter()
|
|
@@ -1211,7 +1206,7 @@ class ScopedCollectionWrapper:
|
|
|
1211
1206
|
return self._index_manager
|
|
1212
1207
|
|
|
1213
1208
|
@property
|
|
1214
|
-
def auto_index_manager(self) ->
|
|
1209
|
+
def auto_index_manager(self) -> AutoIndexManager | None:
|
|
1215
1210
|
"""
|
|
1216
1211
|
Gets the AutoIndexManager for magical automatic index creation.
|
|
1217
1212
|
|
|
@@ -1267,7 +1262,7 @@ class ScopedCollectionWrapper:
|
|
|
1267
1262
|
)
|
|
1268
1263
|
super().__setattr__(name, value)
|
|
1269
1264
|
|
|
1270
|
-
def _inject_read_filter(self, filter:
|
|
1265
|
+
def _inject_read_filter(self, filter: Mapping[str, Any] | None = None) -> dict[str, Any]:
|
|
1271
1266
|
"""
|
|
1272
1267
|
Combines the user's filter with our mandatory scope filter.
|
|
1273
1268
|
|
|
@@ -1357,7 +1352,7 @@ class ScopedCollectionWrapper:
|
|
|
1357
1352
|
) from e
|
|
1358
1353
|
|
|
1359
1354
|
async def insert_many(
|
|
1360
|
-
self, documents:
|
|
1355
|
+
self, documents: list[Mapping[str, Any]], *args, **kwargs
|
|
1361
1356
|
) -> InsertManyResult:
|
|
1362
1357
|
"""
|
|
1363
1358
|
Injects the app_id into all documents before writing.
|
|
@@ -1378,8 +1373,8 @@ class ScopedCollectionWrapper:
|
|
|
1378
1373
|
return await self._collection.insert_many(docs_to_insert, *args, **kwargs_for_insert)
|
|
1379
1374
|
|
|
1380
1375
|
async def find_one(
|
|
1381
|
-
self, filter:
|
|
1382
|
-
) ->
|
|
1376
|
+
self, filter: Mapping[str, Any] | None = None, *args, **kwargs
|
|
1377
|
+
) -> dict[str, Any] | None:
|
|
1383
1378
|
"""
|
|
1384
1379
|
Applies the read scope to the filter.
|
|
1385
1380
|
Automatically ensures appropriate indexes exist for the query.
|
|
@@ -1437,9 +1432,7 @@ class ScopedCollectionWrapper:
|
|
|
1437
1432
|
)
|
|
1438
1433
|
raise
|
|
1439
1434
|
|
|
1440
|
-
def find(
|
|
1441
|
-
self, filter: Optional[Mapping[str, Any]] = None, *args, **kwargs
|
|
1442
|
-
) -> AsyncIOMotorCursor:
|
|
1435
|
+
def find(self, filter: Mapping[str, Any] | None = None, *args, **kwargs) -> AsyncIOMotorCursor:
|
|
1443
1436
|
"""
|
|
1444
1437
|
Applies the read scope to the filter.
|
|
1445
1438
|
Returns an async cursor, just like motor.
|
|
@@ -1550,7 +1543,7 @@ class ScopedCollectionWrapper:
|
|
|
1550
1543
|
return await self._collection.delete_many(scoped_filter, *args, **kwargs_for_delete)
|
|
1551
1544
|
|
|
1552
1545
|
async def count_documents(
|
|
1553
|
-
self, filter:
|
|
1546
|
+
self, filter: Mapping[str, Any] | None = None, *args, **kwargs
|
|
1554
1547
|
) -> int:
|
|
1555
1548
|
"""
|
|
1556
1549
|
Applies the read scope to the filter for counting.
|
|
@@ -1571,7 +1564,7 @@ class ScopedCollectionWrapper:
|
|
|
1571
1564
|
scoped_filter = self._inject_read_filter(filter)
|
|
1572
1565
|
return await self._collection.count_documents(scoped_filter, *args, **kwargs_for_count)
|
|
1573
1566
|
|
|
1574
|
-
def aggregate(self, pipeline:
|
|
1567
|
+
def aggregate(self, pipeline: list[dict[str, Any]], *args, **kwargs) -> AsyncIOMotorCursor:
|
|
1575
1568
|
"""
|
|
1576
1569
|
Injects a scope filter into the pipeline. For normal pipelines, we prepend
|
|
1577
1570
|
a $match stage. However, if the first stage is $vectorSearch, we embed
|
|
@@ -1639,7 +1632,7 @@ class ScopedMongoWrapper:
|
|
|
1639
1632
|
|
|
1640
1633
|
# Class-level cache for collections that have app_id index checked
|
|
1641
1634
|
# Key: collection name, Value: boolean (True if index exists, False if check is pending)
|
|
1642
|
-
_app_id_index_cache: ClassVar[
|
|
1635
|
+
_app_id_index_cache: ClassVar[dict[str, bool]] = {}
|
|
1643
1636
|
# Lock to prevent race conditions when multiple requests try to create the same index
|
|
1644
1637
|
_app_id_index_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
|
|
1645
1638
|
|
|
@@ -1661,13 +1654,13 @@ class ScopedMongoWrapper:
|
|
|
1661
1654
|
def __init__(
|
|
1662
1655
|
self,
|
|
1663
1656
|
real_db: AsyncIOMotorDatabase,
|
|
1664
|
-
read_scopes:
|
|
1657
|
+
read_scopes: list[str],
|
|
1665
1658
|
write_scope: str,
|
|
1666
1659
|
auto_index: bool = True,
|
|
1667
|
-
query_validator:
|
|
1668
|
-
resource_limiter:
|
|
1669
|
-
app_slug:
|
|
1670
|
-
app_token:
|
|
1660
|
+
query_validator: QueryValidator | None = None,
|
|
1661
|
+
resource_limiter: ResourceLimiter | None = None,
|
|
1662
|
+
app_slug: str | None = None,
|
|
1663
|
+
app_token: str | None = None,
|
|
1671
1664
|
app_secrets_manager: Optional["AppSecretsManager"] = None,
|
|
1672
1665
|
):
|
|
1673
1666
|
self._db = real_db
|
|
@@ -1687,7 +1680,7 @@ class ScopedMongoWrapper:
|
|
|
1687
1680
|
self._token_verification_lock = asyncio.Lock()
|
|
1688
1681
|
|
|
1689
1682
|
# Cache for created collection wrappers.
|
|
1690
|
-
self._wrapper_cache:
|
|
1683
|
+
self._wrapper_cache: dict[str, ScopedCollectionWrapper] = {}
|
|
1691
1684
|
|
|
1692
1685
|
async def _verify_token_if_needed(self) -> None:
|
|
1693
1686
|
"""
|
mdb_engine/dependencies.py
CHANGED
|
@@ -18,7 +18,8 @@ Usage:
|
|
|
18
18
|
|
|
19
19
|
import logging
|
|
20
20
|
import os
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
|
22
23
|
|
|
23
24
|
from fastapi import HTTPException, Request
|
|
24
25
|
|
|
@@ -62,7 +63,7 @@ async def get_app_slug(request: Request) -> str:
|
|
|
62
63
|
return slug
|
|
63
64
|
|
|
64
65
|
|
|
65
|
-
async def get_app_config(request: Request) ->
|
|
66
|
+
async def get_app_config(request: Request) -> dict[str, Any]:
|
|
66
67
|
"""Get the app's manifest configuration."""
|
|
67
68
|
manifest = getattr(request.app.state, "manifest", None)
|
|
68
69
|
if manifest is None:
|
|
@@ -169,12 +170,12 @@ async def get_authz_provider(request: Request) -> Optional["AuthorizationProvide
|
|
|
169
170
|
return getattr(request.app.state, "authz_provider", None)
|
|
170
171
|
|
|
171
172
|
|
|
172
|
-
async def get_current_user(request: Request) ->
|
|
173
|
+
async def get_current_user(request: Request) -> dict[str, Any] | None:
|
|
173
174
|
"""Get the current authenticated user."""
|
|
174
175
|
return getattr(request.state, "user", None)
|
|
175
176
|
|
|
176
177
|
|
|
177
|
-
async def get_user_roles(request: Request) ->
|
|
178
|
+
async def get_user_roles(request: Request) -> list[str]:
|
|
178
179
|
"""Get the current user's roles."""
|
|
179
180
|
return getattr(request.state, "user_roles", [])
|
|
180
181
|
|
|
@@ -182,7 +183,7 @@ async def get_user_roles(request: Request) -> List[str]:
|
|
|
182
183
|
def require_user():
|
|
183
184
|
"""Dependency that requires authentication."""
|
|
184
185
|
|
|
185
|
-
async def _require_user(request: Request) ->
|
|
186
|
+
async def _require_user(request: Request) -> dict[str, Any]:
|
|
186
187
|
user = await get_current_user(request)
|
|
187
188
|
if not user:
|
|
188
189
|
raise HTTPException(401, "Authentication required")
|
|
@@ -194,7 +195,7 @@ def require_user():
|
|
|
194
195
|
def require_role(*roles: str):
|
|
195
196
|
"""Dependency that requires specific roles."""
|
|
196
197
|
|
|
197
|
-
async def _require_role(request: Request) ->
|
|
198
|
+
async def _require_role(request: Request) -> dict[str, Any]:
|
|
198
199
|
user = await get_current_user(request)
|
|
199
200
|
if not user:
|
|
200
201
|
raise HTTPException(401, "Authentication required")
|
|
@@ -272,7 +273,7 @@ class RequestContext:
|
|
|
272
273
|
return self._uow
|
|
273
274
|
|
|
274
275
|
@property
|
|
275
|
-
def config(self) ->
|
|
276
|
+
def config(self) -> dict[str, Any]:
|
|
276
277
|
"""Get the app's manifest configuration."""
|
|
277
278
|
if self._config is None:
|
|
278
279
|
self._config = getattr(self.request.app.state, "manifest", None)
|
|
@@ -331,14 +332,14 @@ class RequestContext:
|
|
|
331
332
|
return get_llm_model_name()
|
|
332
333
|
|
|
333
334
|
@property
|
|
334
|
-
def user(self) ->
|
|
335
|
+
def user(self) -> dict[str, Any] | None:
|
|
335
336
|
"""Get the current authenticated user."""
|
|
336
337
|
if self._user is None:
|
|
337
338
|
self._user = getattr(self.request.state, "user", None)
|
|
338
339
|
return self._user
|
|
339
340
|
|
|
340
341
|
@property
|
|
341
|
-
def user_roles(self) ->
|
|
342
|
+
def user_roles(self) -> list[str]:
|
|
342
343
|
"""Get the current user's roles."""
|
|
343
344
|
return getattr(self.request.state, "user_roles", [])
|
|
344
345
|
|
|
@@ -349,13 +350,13 @@ class RequestContext:
|
|
|
349
350
|
self._authz = getattr(self.request.app.state, "authz_provider", None)
|
|
350
351
|
return self._authz
|
|
351
352
|
|
|
352
|
-
def require_user(self) ->
|
|
353
|
+
def require_user(self) -> dict[str, Any]:
|
|
353
354
|
"""Require authentication, raising 401 if not authenticated."""
|
|
354
355
|
if not self.user:
|
|
355
356
|
raise HTTPException(401, "Authentication required")
|
|
356
357
|
return self.user
|
|
357
358
|
|
|
358
|
-
def require_role(self, *roles: str) ->
|
|
359
|
+
def require_role(self, *roles: str) -> dict[str, Any]:
|
|
359
360
|
"""Require specific roles, raising 403 if not authorized."""
|
|
360
361
|
user = self.require_user()
|
|
361
362
|
user_roles = set(self.user_roles)
|
|
@@ -365,7 +366,7 @@ class RequestContext:
|
|
|
365
366
|
return user
|
|
366
367
|
|
|
367
368
|
async def check_permission(
|
|
368
|
-
self, resource: str, action: str, subject:
|
|
369
|
+
self, resource: str, action: str, subject: str | None = None
|
|
369
370
|
) -> bool:
|
|
370
371
|
"""Check if current user has permission for an action."""
|
|
371
372
|
if not self.authz:
|
|
@@ -390,7 +391,7 @@ RequestContext.__call__ = staticmethod(_get_request_context)
|
|
|
390
391
|
# =============================================================================
|
|
391
392
|
|
|
392
393
|
|
|
393
|
-
def inject(service_type:
|
|
394
|
+
def inject(service_type: type[T]) -> Callable[..., T]:
|
|
394
395
|
"""Create a dependency that resolves a service from the DI container."""
|
|
395
396
|
|
|
396
397
|
async def _resolve(request: Request) -> T:
|
mdb_engine/di/container.py
CHANGED
|
@@ -5,13 +5,12 @@ A lightweight, FastAPI-native DI container with proper service lifetimes.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import logging
|
|
8
|
-
from
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from typing import Any, Optional, TypeVar
|
|
9
10
|
|
|
11
|
+
from .providers import Provider
|
|
10
12
|
from .scopes import Scope
|
|
11
13
|
|
|
12
|
-
if TYPE_CHECKING:
|
|
13
|
-
from .providers import Provider
|
|
14
|
-
|
|
15
14
|
logger = logging.getLogger(__name__)
|
|
16
15
|
|
|
17
16
|
T = TypeVar("T")
|
|
@@ -50,8 +49,8 @@ class Container:
|
|
|
50
49
|
_global_instance: Optional["Container"] = None
|
|
51
50
|
|
|
52
51
|
def __init__(self):
|
|
53
|
-
self._providers:
|
|
54
|
-
self._instances:
|
|
52
|
+
self._providers: dict[type, Provider] = {}
|
|
53
|
+
self._instances: dict[type, Any] = {} # For register_instance
|
|
55
54
|
|
|
56
55
|
@classmethod
|
|
57
56
|
def get_global(cls) -> "Container":
|
|
@@ -72,8 +71,8 @@ class Container:
|
|
|
72
71
|
|
|
73
72
|
def register(
|
|
74
73
|
self,
|
|
75
|
-
service_type:
|
|
76
|
-
implementation:
|
|
74
|
+
service_type: type[T],
|
|
75
|
+
implementation: type[T] | None = None,
|
|
77
76
|
scope: Scope = Scope.SINGLETON,
|
|
78
77
|
) -> "Container":
|
|
79
78
|
"""
|
|
@@ -107,7 +106,7 @@ class Container:
|
|
|
107
106
|
|
|
108
107
|
def register_factory(
|
|
109
108
|
self,
|
|
110
|
-
service_type:
|
|
109
|
+
service_type: type[T],
|
|
111
110
|
factory: Callable[["Container"], T],
|
|
112
111
|
scope: Scope = Scope.SINGLETON,
|
|
113
112
|
) -> "Container":
|
|
@@ -137,7 +136,7 @@ class Container:
|
|
|
137
136
|
logger.debug(f"Registered factory for {service_type.__name__} as {scope.value}")
|
|
138
137
|
return self
|
|
139
138
|
|
|
140
|
-
def register_instance(self, service_type:
|
|
139
|
+
def register_instance(self, service_type: type[T], instance: T) -> "Container":
|
|
141
140
|
"""
|
|
142
141
|
Register an existing instance as a singleton.
|
|
143
142
|
|
|
@@ -154,7 +153,7 @@ class Container:
|
|
|
154
153
|
logger.debug(f"Registered instance for {service_type.__name__}")
|
|
155
154
|
return self
|
|
156
155
|
|
|
157
|
-
def resolve(self, service_type:
|
|
156
|
+
def resolve(self, service_type: type[T]) -> T:
|
|
158
157
|
"""
|
|
159
158
|
Resolve a service instance.
|
|
160
159
|
|
|
@@ -180,7 +179,7 @@ class Container:
|
|
|
180
179
|
|
|
181
180
|
return self._providers[service_type].get(self)
|
|
182
181
|
|
|
183
|
-
def try_resolve(self, service_type:
|
|
182
|
+
def try_resolve(self, service_type: type[T]) -> T | None:
|
|
184
183
|
"""
|
|
185
184
|
Try to resolve a service, returning None if not registered.
|
|
186
185
|
|
|
@@ -218,7 +217,7 @@ class Container:
|
|
|
218
217
|
|
|
219
218
|
|
|
220
219
|
# FastAPI integration helpers
|
|
221
|
-
def inject(service_type:
|
|
220
|
+
def inject(service_type: type[T]) -> T:
|
|
222
221
|
"""
|
|
223
222
|
FastAPI dependency that resolves a service from the global container.
|
|
224
223
|
|
mdb_engine/di/providers.py
CHANGED
|
@@ -8,7 +8,8 @@ according to their configured scope.
|
|
|
8
8
|
import inspect
|
|
9
9
|
import logging
|
|
10
10
|
from abc import ABC, abstractmethod
|
|
11
|
-
from
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
|
12
13
|
|
|
13
14
|
from .scopes import Scope, ScopeManager
|
|
14
15
|
|
|
@@ -30,9 +31,9 @@ class Provider(ABC, Generic[T]):
|
|
|
30
31
|
|
|
31
32
|
def __init__(
|
|
32
33
|
self,
|
|
33
|
-
service_type:
|
|
34
|
+
service_type: type[T],
|
|
34
35
|
scope: Scope,
|
|
35
|
-
factory:
|
|
36
|
+
factory: Callable[..., T] | None = None,
|
|
36
37
|
):
|
|
37
38
|
self.service_type = service_type
|
|
38
39
|
self.scope = scope
|
|
@@ -60,7 +61,7 @@ class Provider(ABC, Generic[T]):
|
|
|
60
61
|
"""
|
|
61
62
|
# Get constructor signature
|
|
62
63
|
sig = inspect.signature(self._factory)
|
|
63
|
-
kwargs:
|
|
64
|
+
kwargs: dict[str, Any] = {}
|
|
64
65
|
|
|
65
66
|
for param_name, param in sig.parameters.items():
|
|
66
67
|
if param_name == "self":
|
|
@@ -101,11 +102,11 @@ class SingletonProvider(Provider[T]):
|
|
|
101
102
|
|
|
102
103
|
def __init__(
|
|
103
104
|
self,
|
|
104
|
-
service_type:
|
|
105
|
-
factory:
|
|
105
|
+
service_type: type[T],
|
|
106
|
+
factory: Callable[..., T] | None = None,
|
|
106
107
|
):
|
|
107
108
|
super().__init__(service_type, Scope.SINGLETON, factory)
|
|
108
|
-
self._instance:
|
|
109
|
+
self._instance: T | None = None
|
|
109
110
|
|
|
110
111
|
def get(self, container: "Container") -> T:
|
|
111
112
|
if self._instance is None:
|
|
@@ -127,8 +128,8 @@ class RequestProvider(Provider[T]):
|
|
|
127
128
|
|
|
128
129
|
def __init__(
|
|
129
130
|
self,
|
|
130
|
-
service_type:
|
|
131
|
-
factory:
|
|
131
|
+
service_type: type[T],
|
|
132
|
+
factory: Callable[..., T] | None = None,
|
|
132
133
|
):
|
|
133
134
|
super().__init__(service_type, Scope.REQUEST, factory)
|
|
134
135
|
|
|
@@ -145,8 +146,8 @@ class TransientProvider(Provider[T]):
|
|
|
145
146
|
|
|
146
147
|
def __init__(
|
|
147
148
|
self,
|
|
148
|
-
service_type:
|
|
149
|
-
factory:
|
|
149
|
+
service_type: type[T],
|
|
150
|
+
factory: Callable[..., T] | None = None,
|
|
150
151
|
):
|
|
151
152
|
super().__init__(service_type, Scope.TRANSIENT, factory)
|
|
152
153
|
|
|
@@ -173,13 +174,13 @@ class FactoryProvider(Provider[T]):
|
|
|
173
174
|
|
|
174
175
|
def __init__(
|
|
175
176
|
self,
|
|
176
|
-
service_type:
|
|
177
|
+
service_type: type[T],
|
|
177
178
|
factory: Callable[["Container"], T],
|
|
178
179
|
scope: Scope,
|
|
179
180
|
):
|
|
180
181
|
super().__init__(service_type, scope, None)
|
|
181
182
|
self._custom_factory = factory
|
|
182
|
-
self._singleton_instance:
|
|
183
|
+
self._singleton_instance: T | None = None
|
|
183
184
|
|
|
184
185
|
def get(self, container: "Container") -> T:
|
|
185
186
|
if self.scope == Scope.SINGLETON:
|
mdb_engine/di/scopes.py
CHANGED
|
@@ -10,12 +10,12 @@ Defines service lifetime scopes following enterprise patterns:
|
|
|
10
10
|
import logging
|
|
11
11
|
from contextvars import ContextVar
|
|
12
12
|
from enum import Enum
|
|
13
|
-
from typing import Any
|
|
13
|
+
from typing import Any
|
|
14
14
|
|
|
15
15
|
logger = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
# Context variable for request-scoped instances
|
|
18
|
-
_request_scope: ContextVar[
|
|
18
|
+
_request_scope: ContextVar[dict[type, Any] | None] = ContextVar("request_scope", default=None)
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class Scope(Enum):
|
|
@@ -51,13 +51,13 @@ class ScopeManager:
|
|
|
51
51
|
"""
|
|
52
52
|
|
|
53
53
|
@classmethod
|
|
54
|
-
def begin_request(cls) ->
|
|
54
|
+
def begin_request(cls) -> dict[type, Any]:
|
|
55
55
|
"""
|
|
56
56
|
Begin a new request scope.
|
|
57
57
|
|
|
58
58
|
Returns the scope dictionary for manual management if needed.
|
|
59
59
|
"""
|
|
60
|
-
scope_dict:
|
|
60
|
+
scope_dict: dict[type, Any] = {}
|
|
61
61
|
_request_scope.set(scope_dict)
|
|
62
62
|
logger.debug("Request scope started")
|
|
63
63
|
return scope_dict
|
|
@@ -82,7 +82,7 @@ class ScopeManager:
|
|
|
82
82
|
logger.debug("Request scope ended")
|
|
83
83
|
|
|
84
84
|
@classmethod
|
|
85
|
-
def get_request_scope(cls) ->
|
|
85
|
+
def get_request_scope(cls) -> dict[type, Any] | None:
|
|
86
86
|
"""Get the current request scope dictionary."""
|
|
87
87
|
return _request_scope.get()
|
|
88
88
|
|
|
@@ -19,7 +19,7 @@ Usage:
|
|
|
19
19
|
service = get_embedding_service_for_app("my_app", engine)
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
from typing import TYPE_CHECKING
|
|
22
|
+
from typing import TYPE_CHECKING
|
|
23
23
|
|
|
24
24
|
if TYPE_CHECKING:
|
|
25
25
|
from ..core.engine import MongoDBEngine
|
|
@@ -29,7 +29,7 @@ from .service import EmbeddingService, get_embedding_service
|
|
|
29
29
|
|
|
30
30
|
def get_embedding_service_for_app(
|
|
31
31
|
app_slug: str, engine: "MongoDBEngine"
|
|
32
|
-
) ->
|
|
32
|
+
) -> EmbeddingService | None:
|
|
33
33
|
"""
|
|
34
34
|
Get embedding service for a specific app using the engine instance.
|
|
35
35
|
|