mdb-engine 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. mdb_engine/__init__.py +1 -1
  2. mdb_engine/auth/audit.py +40 -40
  3. mdb_engine/auth/base.py +3 -3
  4. mdb_engine/auth/casbin_factory.py +6 -6
  5. mdb_engine/auth/config_defaults.py +5 -5
  6. mdb_engine/auth/config_helpers.py +12 -12
  7. mdb_engine/auth/cookie_utils.py +9 -9
  8. mdb_engine/auth/csrf.py +9 -8
  9. mdb_engine/auth/decorators.py +7 -6
  10. mdb_engine/auth/dependencies.py +22 -21
  11. mdb_engine/auth/integration.py +9 -9
  12. mdb_engine/auth/jwt.py +9 -9
  13. mdb_engine/auth/middleware.py +4 -3
  14. mdb_engine/auth/oso_factory.py +6 -6
  15. mdb_engine/auth/provider.py +4 -4
  16. mdb_engine/auth/rate_limiter.py +12 -11
  17. mdb_engine/auth/restrictions.py +16 -15
  18. mdb_engine/auth/session_manager.py +11 -13
  19. mdb_engine/auth/shared_middleware.py +16 -15
  20. mdb_engine/auth/shared_users.py +20 -20
  21. mdb_engine/auth/token_lifecycle.py +10 -12
  22. mdb_engine/auth/token_store.py +4 -5
  23. mdb_engine/auth/users.py +51 -52
  24. mdb_engine/auth/utils.py +29 -33
  25. mdb_engine/cli/commands/generate.py +6 -6
  26. mdb_engine/cli/utils.py +4 -4
  27. mdb_engine/config.py +6 -7
  28. mdb_engine/core/app_registration.py +12 -12
  29. mdb_engine/core/app_secrets.py +1 -2
  30. mdb_engine/core/connection.py +3 -4
  31. mdb_engine/core/encryption.py +1 -2
  32. mdb_engine/core/engine.py +43 -44
  33. mdb_engine/core/manifest.py +59 -58
  34. mdb_engine/core/ray_integration.py +10 -9
  35. mdb_engine/core/seeding.py +3 -3
  36. mdb_engine/core/service_initialization.py +10 -9
  37. mdb_engine/core/types.py +40 -40
  38. mdb_engine/database/abstraction.py +15 -16
  39. mdb_engine/database/connection.py +40 -12
  40. mdb_engine/database/query_validator.py +8 -8
  41. mdb_engine/database/resource_limiter.py +7 -7
  42. mdb_engine/database/scoped_wrapper.py +51 -58
  43. mdb_engine/dependencies.py +14 -13
  44. mdb_engine/di/container.py +12 -13
  45. mdb_engine/di/providers.py +14 -13
  46. mdb_engine/di/scopes.py +5 -5
  47. mdb_engine/embeddings/dependencies.py +2 -2
  48. mdb_engine/embeddings/service.py +31 -43
  49. mdb_engine/exceptions.py +20 -20
  50. mdb_engine/indexes/helpers.py +11 -11
  51. mdb_engine/indexes/manager.py +9 -9
  52. mdb_engine/memory/service.py +30 -30
  53. mdb_engine/observability/health.py +10 -9
  54. mdb_engine/observability/logging.py +10 -10
  55. mdb_engine/observability/metrics.py +8 -7
  56. mdb_engine/repositories/base.py +25 -25
  57. mdb_engine/repositories/mongo.py +17 -17
  58. mdb_engine/repositories/unit_of_work.py +6 -6
  59. mdb_engine/routing/websockets.py +19 -18
  60. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/METADATA +8 -8
  61. mdb_engine-0.2.3.dist-info/RECORD +96 -0
  62. mdb_engine-0.2.0.dist-info/RECORD +0 -96
  63. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/WHEEL +0 -0
  64. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/entry_points.txt +0 -0
  65. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/licenses/LICENSE +0 -0
  66. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/top_level.txt +0 -0
@@ -28,17 +28,12 @@ import asyncio
28
28
  import logging
29
29
  import re
30
30
  import time
31
+ from collections.abc import Coroutine, Mapping
31
32
  from typing import (
32
33
  TYPE_CHECKING,
33
34
  Any,
34
35
  ClassVar,
35
- Coroutine,
36
- Dict,
37
- List,
38
- Mapping,
39
36
  Optional,
40
- Tuple,
41
- Union,
42
37
  )
43
38
 
44
39
  if TYPE_CHECKING:
@@ -98,7 +93,7 @@ GEO2DSPHERE = "2dsphere"
98
93
 
99
94
 
100
95
  # --- HELPER FUNCTION FOR MANAGED TASK CREATION ---
101
- def _create_managed_task(coro: Coroutine[Any, Any, Any], task_name: Optional[str] = None) -> None:
96
+ def _create_managed_task(coro: Coroutine[Any, Any, Any], task_name: str | None = None) -> None:
102
97
  """
103
98
  Creates a background task using asyncio.create_task().
104
99
 
@@ -205,7 +200,7 @@ def _validate_collection_name(name: str, allow_prefixed: bool = False) -> None:
205
200
  )
206
201
 
207
202
 
208
- def _extract_app_slug_from_prefixed_name(prefixed_name: str) -> Optional[str]:
203
+ def _extract_app_slug_from_prefixed_name(prefixed_name: str) -> str | None:
209
204
  """
210
205
  Extract app slug from a prefixed collection name.
211
206
 
@@ -335,11 +330,11 @@ class AsyncAtlasIndexManager:
335
330
 
336
331
  def _check_definition_changed(
337
332
  self,
338
- definition: Dict[str, Any],
339
- latest_def: Dict[str, Any],
333
+ definition: dict[str, Any],
334
+ latest_def: dict[str, Any],
340
335
  index_type: str,
341
336
  name: str,
342
- ) -> Tuple[bool, str]:
337
+ ) -> tuple[bool, str]:
343
338
  """Check if index definition has changed."""
344
339
  definition_changed = False
345
340
  change_reason = ""
@@ -362,8 +357,8 @@ class AsyncAtlasIndexManager:
362
357
 
363
358
  async def _handle_existing_index(
364
359
  self,
365
- existing_index: Dict[str, Any],
366
- definition: Dict[str, Any],
360
+ existing_index: dict[str, Any],
361
+ definition: dict[str, Any],
367
362
  index_type: str,
368
363
  name: str,
369
364
  ) -> bool:
@@ -402,7 +397,7 @@ class AsyncAtlasIndexManager:
402
397
  return False # Will wait below
403
398
 
404
399
  async def _create_new_search_index(
405
- self, name: str, definition: Dict[str, Any], index_type: str
400
+ self, name: str, definition: dict[str, Any], index_type: str
406
401
  ) -> None:
407
402
  """Create a new search index."""
408
403
  try:
@@ -422,7 +417,7 @@ class AsyncAtlasIndexManager:
422
417
  async def create_search_index(
423
418
  self,
424
419
  name: str,
425
- definition: Dict[str, Any],
420
+ definition: dict[str, Any],
426
421
  index_type: str = "search",
427
422
  wait_for_ready: bool = True,
428
423
  timeout: int = DEFAULT_SEARCH_TIMEOUT,
@@ -472,7 +467,7 @@ class AsyncAtlasIndexManager:
472
467
  context={"index_name": name, "operation": "create_search_index"},
473
468
  ) from e
474
469
 
475
- async def get_search_index(self, name: str) -> Optional[Dict[str, Any]]:
470
+ async def get_search_index(self, name: str) -> dict[str, Any] | None:
476
471
  """
477
472
  Retrieves the definition and status of a single search index by name
478
473
  using the $listSearchIndexes aggregation stage.
@@ -496,7 +491,7 @@ class AsyncAtlasIndexManager:
496
491
  context={"index_name": name, "operation": "get_search_index"},
497
492
  ) from e
498
493
 
499
- async def list_search_indexes(self) -> List[Dict[str, Any]]:
494
+ async def list_search_indexes(self) -> list[dict[str, Any]]:
500
495
  """Lists all Atlas Search indexes for the collection."""
501
496
  try:
502
497
  return await self._collection.list_search_indexes().to_list(None)
@@ -552,7 +547,7 @@ class AsyncAtlasIndexManager:
552
547
  async def update_search_index(
553
548
  self,
554
549
  name: str,
555
- definition: Dict[str, Any],
550
+ definition: dict[str, Any],
556
551
  wait_for_ready: bool = True,
557
552
  timeout: int = DEFAULT_SEARCH_TIMEOUT,
558
553
  ) -> bool:
@@ -675,7 +670,7 @@ class AsyncAtlasIndexManager:
675
670
  # consistent async API with the search index methods.
676
671
 
677
672
  async def create_index( # noqa: C901
678
- self, keys: Union[str, List[Tuple[str, Union[int, str]]]], **kwargs: Any
673
+ self, keys: str | list[tuple[str, int | str]], **kwargs: Any
679
674
  ) -> str:
680
675
  """
681
676
  Creates a standard (non-search) database index.
@@ -784,8 +779,8 @@ class AsyncAtlasIndexManager:
784
779
 
785
780
  async def create_text_index(
786
781
  self,
787
- fields: List[str],
788
- weights: Optional[Dict[str, int]] = None,
782
+ fields: list[str],
783
+ weights: dict[str, int] | None = None,
789
784
  name: str = "text_index",
790
785
  **kwargs: Any,
791
786
  ) -> str:
@@ -797,7 +792,7 @@ class AsyncAtlasIndexManager:
797
792
  kwargs["name"] = name
798
793
  return await self.create_index(keys, **kwargs)
799
794
 
800
- async def create_geo_index(self, field: str, name: Optional[str] = None, **kwargs: Any) -> str:
795
+ async def create_geo_index(self, field: str, name: str | None = None, **kwargs: Any) -> str:
801
796
  """Helper to create a standard 2dsphere index."""
802
797
  keys = [(field, GEO2DSPHERE)]
803
798
  if name:
@@ -832,7 +827,7 @@ class AsyncAtlasIndexManager:
832
827
  context={"index_name": name, "operation": "drop_index"},
833
828
  ) from e
834
829
 
835
- async def list_indexes(self) -> List[Dict[str, Any]]:
830
+ async def list_indexes(self) -> list[dict[str, Any]]:
836
831
  """Lists all standard (non-search) indexes on the collection."""
837
832
  try:
838
833
  return await self._collection.list_indexes().to_list(None)
@@ -844,7 +839,7 @@ class AsyncAtlasIndexManager:
844
839
  logger.debug("Skipping list_indexes: MongoDB client is closed (likely during shutdown)")
845
840
  return []
846
841
 
847
- async def get_index(self, name: str) -> Optional[Dict[str, Any]]:
842
+ async def get_index(self, name: str) -> dict[str, Any] | None:
848
843
  """Gets a single standard index by name."""
849
844
  indexes = await self.list_indexes()
850
845
  return next((index for index in indexes if index.get("name") == name), None)
@@ -919,17 +914,17 @@ class AutoIndexManager:
919
914
  self._collection = collection
920
915
  self._index_manager = index_manager
921
916
  # Cache of index creation decisions (index_name -> bool)
922
- self._creation_cache: Dict[str, bool] = {}
917
+ self._creation_cache: dict[str, bool] = {}
923
918
  # Async lock to prevent race conditions during index creation
924
919
  self._lock = asyncio.Lock()
925
920
  # Track query patterns to determine which indexes to create
926
- self._query_counts: Dict[str, int] = {}
921
+ self._query_counts: dict[str, int] = {}
927
922
  # Track in-flight index creation tasks to prevent duplicates
928
- self._pending_tasks: Dict[str, asyncio.Task] = {}
923
+ self._pending_tasks: dict[str, asyncio.Task] = {}
929
924
 
930
925
  def _extract_index_fields_from_filter(
931
- self, filter: Optional[Mapping[str, Any]]
932
- ) -> List[Tuple[str, int]]:
926
+ self, filter: Mapping[str, Any] | None
927
+ ) -> list[tuple[str, int]]:
933
928
  """
934
929
  Extracts potential index fields from a MongoDB query filter.
935
930
 
@@ -944,7 +939,7 @@ class AutoIndexManager:
944
939
  if not filter:
945
940
  return []
946
941
 
947
- index_fields: List[Tuple[str, int]] = []
942
+ index_fields: list[tuple[str, int]] = []
948
943
 
949
944
  def analyze_value(value: Any, field_name: str) -> None:
950
945
  """Recursively analyze filter values to extract index candidates."""
@@ -976,8 +971,8 @@ class AutoIndexManager:
976
971
  return list(set(index_fields)) # Remove duplicates
977
972
 
978
973
  def _extract_sort_fields(
979
- self, sort: Optional[Union[List[Tuple[str, int]], Dict[str, int]]]
980
- ) -> List[Tuple[str, int]]:
974
+ self, sort: list[tuple[str, int]] | dict[str, int] | None
975
+ ) -> list[tuple[str, int]]:
981
976
  """
982
977
  Extracts index fields from sort specification.
983
978
 
@@ -993,7 +988,7 @@ class AutoIndexManager:
993
988
  else:
994
989
  return []
995
990
 
996
- def _generate_index_name(self, fields: List[Tuple[str, int]]) -> str:
991
+ def _generate_index_name(self, fields: list[tuple[str, int]]) -> str:
997
992
  """Generate a human-readable index name from field list."""
998
993
  if not fields:
999
994
  return "auto_idx_empty"
@@ -1006,7 +1001,7 @@ class AutoIndexManager:
1006
1001
  return f"auto_{'_'.join(parts)}"
1007
1002
 
1008
1003
  async def _create_index_safely(
1009
- self, index_name: str, all_fields: List[Tuple[str, int]]
1004
+ self, index_name: str, all_fields: list[tuple[str, int]]
1010
1005
  ) -> None:
1011
1006
  """
1012
1007
  Safely create an index, handling errors gracefully.
@@ -1052,8 +1047,8 @@ class AutoIndexManager:
1052
1047
 
1053
1048
  async def ensure_index_for_query(
1054
1049
  self,
1055
- filter: Optional[Mapping[str, Any]] = None,
1056
- sort: Optional[Union[List[Tuple[str, int]], Dict[str, int]]] = None,
1050
+ filter: Mapping[str, Any] | None = None,
1051
+ sort: list[tuple[str, int]] | dict[str, int] | None = None,
1057
1052
  hint_threshold: int = AUTO_INDEX_HINT_THRESHOLD,
1058
1053
  ) -> None:
1059
1054
  """
@@ -1170,11 +1165,11 @@ class ScopedCollectionWrapper:
1170
1165
  def __init__(
1171
1166
  self,
1172
1167
  real_collection: AsyncIOMotorCollection,
1173
- read_scopes: List[str],
1168
+ read_scopes: list[str],
1174
1169
  write_scope: str,
1175
1170
  auto_index: bool = True,
1176
- query_validator: Optional[QueryValidator] = None,
1177
- resource_limiter: Optional[ResourceLimiter] = None,
1171
+ query_validator: QueryValidator | None = None,
1172
+ resource_limiter: ResourceLimiter | None = None,
1178
1173
  parent_wrapper: Optional["ScopedMongoWrapper"] = None,
1179
1174
  ):
1180
1175
  self._collection = real_collection
@@ -1182,8 +1177,8 @@ class ScopedCollectionWrapper:
1182
1177
  self._write_scope = write_scope
1183
1178
  self._auto_index_enabled = auto_index
1184
1179
  # Lazily instantiated and cached
1185
- self._index_manager: Optional[AsyncAtlasIndexManager] = None
1186
- self._auto_index_manager: Optional[AutoIndexManager] = None
1180
+ self._index_manager: AsyncAtlasIndexManager | None = None
1181
+ self._auto_index_manager: AutoIndexManager | None = None
1187
1182
  # Query security and resource limits
1188
1183
  self._query_validator = query_validator or QueryValidator()
1189
1184
  self._resource_limiter = resource_limiter or ResourceLimiter()
@@ -1211,7 +1206,7 @@ class ScopedCollectionWrapper:
1211
1206
  return self._index_manager
1212
1207
 
1213
1208
  @property
1214
- def auto_index_manager(self) -> Optional[AutoIndexManager]:
1209
+ def auto_index_manager(self) -> AutoIndexManager | None:
1215
1210
  """
1216
1211
  Gets the AutoIndexManager for magical automatic index creation.
1217
1212
 
@@ -1267,7 +1262,7 @@ class ScopedCollectionWrapper:
1267
1262
  )
1268
1263
  super().__setattr__(name, value)
1269
1264
 
1270
- def _inject_read_filter(self, filter: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
1265
+ def _inject_read_filter(self, filter: Mapping[str, Any] | None = None) -> dict[str, Any]:
1271
1266
  """
1272
1267
  Combines the user's filter with our mandatory scope filter.
1273
1268
 
@@ -1357,7 +1352,7 @@ class ScopedCollectionWrapper:
1357
1352
  ) from e
1358
1353
 
1359
1354
  async def insert_many(
1360
- self, documents: List[Mapping[str, Any]], *args, **kwargs
1355
+ self, documents: list[Mapping[str, Any]], *args, **kwargs
1361
1356
  ) -> InsertManyResult:
1362
1357
  """
1363
1358
  Injects the app_id into all documents before writing.
@@ -1378,8 +1373,8 @@ class ScopedCollectionWrapper:
1378
1373
  return await self._collection.insert_many(docs_to_insert, *args, **kwargs_for_insert)
1379
1374
 
1380
1375
  async def find_one(
1381
- self, filter: Optional[Mapping[str, Any]] = None, *args, **kwargs
1382
- ) -> Optional[Dict[str, Any]]:
1376
+ self, filter: Mapping[str, Any] | None = None, *args, **kwargs
1377
+ ) -> dict[str, Any] | None:
1383
1378
  """
1384
1379
  Applies the read scope to the filter.
1385
1380
  Automatically ensures appropriate indexes exist for the query.
@@ -1437,9 +1432,7 @@ class ScopedCollectionWrapper:
1437
1432
  )
1438
1433
  raise
1439
1434
 
1440
- def find(
1441
- self, filter: Optional[Mapping[str, Any]] = None, *args, **kwargs
1442
- ) -> AsyncIOMotorCursor:
1435
+ def find(self, filter: Mapping[str, Any] | None = None, *args, **kwargs) -> AsyncIOMotorCursor:
1443
1436
  """
1444
1437
  Applies the read scope to the filter.
1445
1438
  Returns an async cursor, just like motor.
@@ -1550,7 +1543,7 @@ class ScopedCollectionWrapper:
1550
1543
  return await self._collection.delete_many(scoped_filter, *args, **kwargs_for_delete)
1551
1544
 
1552
1545
  async def count_documents(
1553
- self, filter: Optional[Mapping[str, Any]] = None, *args, **kwargs
1546
+ self, filter: Mapping[str, Any] | None = None, *args, **kwargs
1554
1547
  ) -> int:
1555
1548
  """
1556
1549
  Applies the read scope to the filter for counting.
@@ -1571,7 +1564,7 @@ class ScopedCollectionWrapper:
1571
1564
  scoped_filter = self._inject_read_filter(filter)
1572
1565
  return await self._collection.count_documents(scoped_filter, *args, **kwargs_for_count)
1573
1566
 
1574
- def aggregate(self, pipeline: List[Dict[str, Any]], *args, **kwargs) -> AsyncIOMotorCursor:
1567
+ def aggregate(self, pipeline: list[dict[str, Any]], *args, **kwargs) -> AsyncIOMotorCursor:
1575
1568
  """
1576
1569
  Injects a scope filter into the pipeline. For normal pipelines, we prepend
1577
1570
  a $match stage. However, if the first stage is $vectorSearch, we embed
@@ -1639,7 +1632,7 @@ class ScopedMongoWrapper:
1639
1632
 
1640
1633
  # Class-level cache for collections that have app_id index checked
1641
1634
  # Key: collection name, Value: boolean (True if index exists, False if check is pending)
1642
- _app_id_index_cache: ClassVar[Dict[str, bool]] = {}
1635
+ _app_id_index_cache: ClassVar[dict[str, bool]] = {}
1643
1636
  # Lock to prevent race conditions when multiple requests try to create the same index
1644
1637
  _app_id_index_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
1645
1638
 
@@ -1661,13 +1654,13 @@ class ScopedMongoWrapper:
1661
1654
  def __init__(
1662
1655
  self,
1663
1656
  real_db: AsyncIOMotorDatabase,
1664
- read_scopes: List[str],
1657
+ read_scopes: list[str],
1665
1658
  write_scope: str,
1666
1659
  auto_index: bool = True,
1667
- query_validator: Optional[QueryValidator] = None,
1668
- resource_limiter: Optional[ResourceLimiter] = None,
1669
- app_slug: Optional[str] = None,
1670
- app_token: Optional[str] = None,
1660
+ query_validator: QueryValidator | None = None,
1661
+ resource_limiter: ResourceLimiter | None = None,
1662
+ app_slug: str | None = None,
1663
+ app_token: str | None = None,
1671
1664
  app_secrets_manager: Optional["AppSecretsManager"] = None,
1672
1665
  ):
1673
1666
  self._db = real_db
@@ -1687,7 +1680,7 @@ class ScopedMongoWrapper:
1687
1680
  self._token_verification_lock = asyncio.Lock()
1688
1681
 
1689
1682
  # Cache for created collection wrappers.
1690
- self._wrapper_cache: Dict[str, ScopedCollectionWrapper] = {}
1683
+ self._wrapper_cache: dict[str, ScopedCollectionWrapper] = {}
1691
1684
 
1692
1685
  async def _verify_token_if_needed(self) -> None:
1693
1686
  """
@@ -18,7 +18,8 @@ Usage:
18
18
 
19
19
  import logging
20
20
  import os
21
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, TypeVar, Union
21
+ from collections.abc import Callable
22
+ from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
22
23
 
23
24
  from fastapi import HTTPException, Request
24
25
 
@@ -62,7 +63,7 @@ async def get_app_slug(request: Request) -> str:
62
63
  return slug
63
64
 
64
65
 
65
- async def get_app_config(request: Request) -> Dict[str, Any]:
66
+ async def get_app_config(request: Request) -> dict[str, Any]:
66
67
  """Get the app's manifest configuration."""
67
68
  manifest = getattr(request.app.state, "manifest", None)
68
69
  if manifest is None:
@@ -169,12 +170,12 @@ async def get_authz_provider(request: Request) -> Optional["AuthorizationProvide
169
170
  return getattr(request.app.state, "authz_provider", None)
170
171
 
171
172
 
172
- async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
173
+ async def get_current_user(request: Request) -> dict[str, Any] | None:
173
174
  """Get the current authenticated user."""
174
175
  return getattr(request.state, "user", None)
175
176
 
176
177
 
177
- async def get_user_roles(request: Request) -> List[str]:
178
+ async def get_user_roles(request: Request) -> list[str]:
178
179
  """Get the current user's roles."""
179
180
  return getattr(request.state, "user_roles", [])
180
181
 
@@ -182,7 +183,7 @@ async def get_user_roles(request: Request) -> List[str]:
182
183
  def require_user():
183
184
  """Dependency that requires authentication."""
184
185
 
185
- async def _require_user(request: Request) -> Dict[str, Any]:
186
+ async def _require_user(request: Request) -> dict[str, Any]:
186
187
  user = await get_current_user(request)
187
188
  if not user:
188
189
  raise HTTPException(401, "Authentication required")
@@ -194,7 +195,7 @@ def require_user():
194
195
  def require_role(*roles: str):
195
196
  """Dependency that requires specific roles."""
196
197
 
197
- async def _require_role(request: Request) -> Dict[str, Any]:
198
+ async def _require_role(request: Request) -> dict[str, Any]:
198
199
  user = await get_current_user(request)
199
200
  if not user:
200
201
  raise HTTPException(401, "Authentication required")
@@ -272,7 +273,7 @@ class RequestContext:
272
273
  return self._uow
273
274
 
274
275
  @property
275
- def config(self) -> Dict[str, Any]:
276
+ def config(self) -> dict[str, Any]:
276
277
  """Get the app's manifest configuration."""
277
278
  if self._config is None:
278
279
  self._config = getattr(self.request.app.state, "manifest", None)
@@ -331,14 +332,14 @@ class RequestContext:
331
332
  return get_llm_model_name()
332
333
 
333
334
  @property
334
- def user(self) -> Optional[Dict[str, Any]]:
335
+ def user(self) -> dict[str, Any] | None:
335
336
  """Get the current authenticated user."""
336
337
  if self._user is None:
337
338
  self._user = getattr(self.request.state, "user", None)
338
339
  return self._user
339
340
 
340
341
  @property
341
- def user_roles(self) -> List[str]:
342
+ def user_roles(self) -> list[str]:
342
343
  """Get the current user's roles."""
343
344
  return getattr(self.request.state, "user_roles", [])
344
345
 
@@ -349,13 +350,13 @@ class RequestContext:
349
350
  self._authz = getattr(self.request.app.state, "authz_provider", None)
350
351
  return self._authz
351
352
 
352
- def require_user(self) -> Dict[str, Any]:
353
+ def require_user(self) -> dict[str, Any]:
353
354
  """Require authentication, raising 401 if not authenticated."""
354
355
  if not self.user:
355
356
  raise HTTPException(401, "Authentication required")
356
357
  return self.user
357
358
 
358
- def require_role(self, *roles: str) -> Dict[str, Any]:
359
+ def require_role(self, *roles: str) -> dict[str, Any]:
359
360
  """Require specific roles, raising 403 if not authorized."""
360
361
  user = self.require_user()
361
362
  user_roles = set(self.user_roles)
@@ -365,7 +366,7 @@ class RequestContext:
365
366
  return user
366
367
 
367
368
  async def check_permission(
368
- self, resource: str, action: str, subject: Optional[str] = None
369
+ self, resource: str, action: str, subject: str | None = None
369
370
  ) -> bool:
370
371
  """Check if current user has permission for an action."""
371
372
  if not self.authz:
@@ -390,7 +391,7 @@ RequestContext.__call__ = staticmethod(_get_request_context)
390
391
  # =============================================================================
391
392
 
392
393
 
393
- def inject(service_type: Type[T]) -> Callable[..., T]:
394
+ def inject(service_type: type[T]) -> Callable[..., T]:
394
395
  """Create a dependency that resolves a service from the DI container."""
395
396
 
396
397
  async def _resolve(request: Request) -> T:
@@ -5,13 +5,12 @@ A lightweight, FastAPI-native DI container with proper service lifetimes.
5
5
  """
6
6
 
7
7
  import logging
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar
8
+ from collections.abc import Callable
9
+ from typing import Any, Optional, TypeVar
9
10
 
11
+ from .providers import Provider
10
12
  from .scopes import Scope
11
13
 
12
- if TYPE_CHECKING:
13
- from .providers import Provider
14
-
15
14
  logger = logging.getLogger(__name__)
16
15
 
17
16
  T = TypeVar("T")
@@ -50,8 +49,8 @@ class Container:
50
49
  _global_instance: Optional["Container"] = None
51
50
 
52
51
  def __init__(self):
53
- self._providers: Dict[type, "Provider"] = {}
54
- self._instances: Dict[type, Any] = {} # For register_instance
52
+ self._providers: dict[type, Provider] = {}
53
+ self._instances: dict[type, Any] = {} # For register_instance
55
54
 
56
55
  @classmethod
57
56
  def get_global(cls) -> "Container":
@@ -72,8 +71,8 @@ class Container:
72
71
 
73
72
  def register(
74
73
  self,
75
- service_type: Type[T],
76
- implementation: Optional[Type[T]] = None,
74
+ service_type: type[T],
75
+ implementation: type[T] | None = None,
77
76
  scope: Scope = Scope.SINGLETON,
78
77
  ) -> "Container":
79
78
  """
@@ -107,7 +106,7 @@ class Container:
107
106
 
108
107
  def register_factory(
109
108
  self,
110
- service_type: Type[T],
109
+ service_type: type[T],
111
110
  factory: Callable[["Container"], T],
112
111
  scope: Scope = Scope.SINGLETON,
113
112
  ) -> "Container":
@@ -137,7 +136,7 @@ class Container:
137
136
  logger.debug(f"Registered factory for {service_type.__name__} as {scope.value}")
138
137
  return self
139
138
 
140
- def register_instance(self, service_type: Type[T], instance: T) -> "Container":
139
+ def register_instance(self, service_type: type[T], instance: T) -> "Container":
141
140
  """
142
141
  Register an existing instance as a singleton.
143
142
 
@@ -154,7 +153,7 @@ class Container:
154
153
  logger.debug(f"Registered instance for {service_type.__name__}")
155
154
  return self
156
155
 
157
- def resolve(self, service_type: Type[T]) -> T:
156
+ def resolve(self, service_type: type[T]) -> T:
158
157
  """
159
158
  Resolve a service instance.
160
159
 
@@ -180,7 +179,7 @@ class Container:
180
179
 
181
180
  return self._providers[service_type].get(self)
182
181
 
183
- def try_resolve(self, service_type: Type[T]) -> Optional[T]:
182
+ def try_resolve(self, service_type: type[T]) -> T | None:
184
183
  """
185
184
  Try to resolve a service, returning None if not registered.
186
185
 
@@ -218,7 +217,7 @@ class Container:
218
217
 
219
218
 
220
219
  # FastAPI integration helpers
221
- def inject(service_type: Type[T]) -> T:
220
+ def inject(service_type: type[T]) -> T:
222
221
  """
223
222
  FastAPI dependency that resolves a service from the global container.
224
223
 
@@ -8,7 +8,8 @@ according to their configured scope.
8
8
  import inspect
9
9
  import logging
10
10
  from abc import ABC, abstractmethod
11
- from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar
11
+ from collections.abc import Callable
12
+ from typing import TYPE_CHECKING, Any, Generic, TypeVar
12
13
 
13
14
  from .scopes import Scope, ScopeManager
14
15
 
@@ -30,9 +31,9 @@ class Provider(ABC, Generic[T]):
30
31
 
31
32
  def __init__(
32
33
  self,
33
- service_type: Type[T],
34
+ service_type: type[T],
34
35
  scope: Scope,
35
- factory: Optional[Callable[..., T]] = None,
36
+ factory: Callable[..., T] | None = None,
36
37
  ):
37
38
  self.service_type = service_type
38
39
  self.scope = scope
@@ -60,7 +61,7 @@ class Provider(ABC, Generic[T]):
60
61
  """
61
62
  # Get constructor signature
62
63
  sig = inspect.signature(self._factory)
63
- kwargs: Dict[str, Any] = {}
64
+ kwargs: dict[str, Any] = {}
64
65
 
65
66
  for param_name, param in sig.parameters.items():
66
67
  if param_name == "self":
@@ -101,11 +102,11 @@ class SingletonProvider(Provider[T]):
101
102
 
102
103
  def __init__(
103
104
  self,
104
- service_type: Type[T],
105
- factory: Optional[Callable[..., T]] = None,
105
+ service_type: type[T],
106
+ factory: Callable[..., T] | None = None,
106
107
  ):
107
108
  super().__init__(service_type, Scope.SINGLETON, factory)
108
- self._instance: Optional[T] = None
109
+ self._instance: T | None = None
109
110
 
110
111
  def get(self, container: "Container") -> T:
111
112
  if self._instance is None:
@@ -127,8 +128,8 @@ class RequestProvider(Provider[T]):
127
128
 
128
129
  def __init__(
129
130
  self,
130
- service_type: Type[T],
131
- factory: Optional[Callable[..., T]] = None,
131
+ service_type: type[T],
132
+ factory: Callable[..., T] | None = None,
132
133
  ):
133
134
  super().__init__(service_type, Scope.REQUEST, factory)
134
135
 
@@ -145,8 +146,8 @@ class TransientProvider(Provider[T]):
145
146
 
146
147
  def __init__(
147
148
  self,
148
- service_type: Type[T],
149
- factory: Optional[Callable[..., T]] = None,
149
+ service_type: type[T],
150
+ factory: Callable[..., T] | None = None,
150
151
  ):
151
152
  super().__init__(service_type, Scope.TRANSIENT, factory)
152
153
 
@@ -173,13 +174,13 @@ class FactoryProvider(Provider[T]):
173
174
 
174
175
  def __init__(
175
176
  self,
176
- service_type: Type[T],
177
+ service_type: type[T],
177
178
  factory: Callable[["Container"], T],
178
179
  scope: Scope,
179
180
  ):
180
181
  super().__init__(service_type, scope, None)
181
182
  self._custom_factory = factory
182
- self._singleton_instance: Optional[T] = None
183
+ self._singleton_instance: T | None = None
183
184
 
184
185
  def get(self, container: "Container") -> T:
185
186
  if self.scope == Scope.SINGLETON:
mdb_engine/di/scopes.py CHANGED
@@ -10,12 +10,12 @@ Defines service lifetime scopes following enterprise patterns:
10
10
  import logging
11
11
  from contextvars import ContextVar
12
12
  from enum import Enum
13
- from typing import Any, Dict, Optional
13
+ from typing import Any
14
14
 
15
15
  logger = logging.getLogger(__name__)
16
16
 
17
17
  # Context variable for request-scoped instances
18
- _request_scope: ContextVar[Optional[Dict[type, Any]]] = ContextVar("request_scope", default=None)
18
+ _request_scope: ContextVar[dict[type, Any] | None] = ContextVar("request_scope", default=None)
19
19
 
20
20
 
21
21
  class Scope(Enum):
@@ -51,13 +51,13 @@ class ScopeManager:
51
51
  """
52
52
 
53
53
  @classmethod
54
- def begin_request(cls) -> Dict[type, Any]:
54
+ def begin_request(cls) -> dict[type, Any]:
55
55
  """
56
56
  Begin a new request scope.
57
57
 
58
58
  Returns the scope dictionary for manual management if needed.
59
59
  """
60
- scope_dict: Dict[type, Any] = {}
60
+ scope_dict: dict[type, Any] = {}
61
61
  _request_scope.set(scope_dict)
62
62
  logger.debug("Request scope started")
63
63
  return scope_dict
@@ -82,7 +82,7 @@ class ScopeManager:
82
82
  logger.debug("Request scope ended")
83
83
 
84
84
  @classmethod
85
- def get_request_scope(cls) -> Optional[Dict[type, Any]]:
85
+ def get_request_scope(cls) -> dict[type, Any] | None:
86
86
  """Get the current request scope dictionary."""
87
87
  return _request_scope.get()
88
88
 
@@ -19,7 +19,7 @@ Usage:
19
19
  service = get_embedding_service_for_app("my_app", engine)
20
20
  """
21
21
 
22
- from typing import TYPE_CHECKING, Optional
22
+ from typing import TYPE_CHECKING
23
23
 
24
24
  if TYPE_CHECKING:
25
25
  from ..core.engine import MongoDBEngine
@@ -29,7 +29,7 @@ from .service import EmbeddingService, get_embedding_service
29
29
 
30
30
  def get_embedding_service_for_app(
31
31
  app_slug: str, engine: "MongoDBEngine"
32
- ) -> Optional[EmbeddingService]:
32
+ ) -> EmbeddingService | None:
33
33
  """
34
34
  Get embedding service for a specific app using the engine instance.
35
35