mdb-engine 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. mdb_engine/__init__.py +1 -1
  2. mdb_engine/auth/audit.py +40 -40
  3. mdb_engine/auth/base.py +3 -3
  4. mdb_engine/auth/casbin_factory.py +6 -6
  5. mdb_engine/auth/config_defaults.py +5 -5
  6. mdb_engine/auth/config_helpers.py +12 -12
  7. mdb_engine/auth/cookie_utils.py +9 -9
  8. mdb_engine/auth/csrf.py +9 -8
  9. mdb_engine/auth/decorators.py +7 -6
  10. mdb_engine/auth/dependencies.py +22 -21
  11. mdb_engine/auth/integration.py +9 -9
  12. mdb_engine/auth/jwt.py +9 -9
  13. mdb_engine/auth/middleware.py +4 -3
  14. mdb_engine/auth/oso_factory.py +6 -6
  15. mdb_engine/auth/provider.py +4 -4
  16. mdb_engine/auth/rate_limiter.py +12 -11
  17. mdb_engine/auth/restrictions.py +16 -15
  18. mdb_engine/auth/session_manager.py +11 -13
  19. mdb_engine/auth/shared_middleware.py +16 -15
  20. mdb_engine/auth/shared_users.py +20 -20
  21. mdb_engine/auth/token_lifecycle.py +10 -12
  22. mdb_engine/auth/token_store.py +4 -5
  23. mdb_engine/auth/users.py +51 -52
  24. mdb_engine/auth/utils.py +29 -33
  25. mdb_engine/cli/commands/generate.py +6 -6
  26. mdb_engine/cli/utils.py +4 -4
  27. mdb_engine/config.py +6 -7
  28. mdb_engine/core/app_registration.py +12 -12
  29. mdb_engine/core/app_secrets.py +1 -2
  30. mdb_engine/core/connection.py +3 -4
  31. mdb_engine/core/encryption.py +1 -2
  32. mdb_engine/core/engine.py +43 -44
  33. mdb_engine/core/manifest.py +59 -58
  34. mdb_engine/core/ray_integration.py +10 -9
  35. mdb_engine/core/seeding.py +3 -3
  36. mdb_engine/core/service_initialization.py +10 -9
  37. mdb_engine/core/types.py +40 -40
  38. mdb_engine/database/abstraction.py +15 -16
  39. mdb_engine/database/connection.py +40 -12
  40. mdb_engine/database/query_validator.py +8 -8
  41. mdb_engine/database/resource_limiter.py +7 -7
  42. mdb_engine/database/scoped_wrapper.py +51 -58
  43. mdb_engine/dependencies.py +14 -13
  44. mdb_engine/di/container.py +12 -13
  45. mdb_engine/di/providers.py +14 -13
  46. mdb_engine/di/scopes.py +5 -5
  47. mdb_engine/embeddings/dependencies.py +2 -2
  48. mdb_engine/embeddings/service.py +31 -43
  49. mdb_engine/exceptions.py +20 -20
  50. mdb_engine/indexes/helpers.py +11 -11
  51. mdb_engine/indexes/manager.py +9 -9
  52. mdb_engine/memory/service.py +30 -30
  53. mdb_engine/observability/health.py +10 -9
  54. mdb_engine/observability/logging.py +10 -10
  55. mdb_engine/observability/metrics.py +8 -7
  56. mdb_engine/repositories/base.py +25 -25
  57. mdb_engine/repositories/mongo.py +17 -17
  58. mdb_engine/repositories/unit_of_work.py +6 -6
  59. mdb_engine/routing/websockets.py +19 -18
  60. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/METADATA +8 -8
  61. mdb_engine-0.2.3.dist-info/RECORD +96 -0
  62. mdb_engine-0.2.0.dist-info/RECORD +0 -96
  63. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/WHEEL +0 -0
  64. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/entry_points.txt +0 -0
  65. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/licenses/LICENSE +0 -0
  66. {mdb_engine-0.2.0.dist-info → mdb_engine-0.2.3.dist-info}/top_level.txt +0 -0
@@ -6,10 +6,11 @@ Provides health check functions for monitoring system status.
6
6
 
7
7
  import asyncio
8
8
  import logging
9
+ from collections.abc import Callable
9
10
  from dataclasses import dataclass, field
10
11
  from datetime import datetime
11
12
  from enum import Enum
12
- from typing import Any, Callable, Dict, List, Optional
13
+ from typing import Any
13
14
 
14
15
  from pymongo.errors import (
15
16
  ConnectionFailure,
@@ -36,10 +37,10 @@ class HealthCheckResult:
36
37
  name: str
37
38
  status: HealthStatus
38
39
  message: str
39
- details: Optional[Dict[str, Any]] = None
40
+ details: dict[str, Any] | None = None
40
41
  timestamp: datetime = field(default_factory=datetime.now)
41
42
 
42
- def to_dict(self) -> Dict[str, Any]:
43
+ def to_dict(self) -> dict[str, Any]:
43
44
  """Convert to dictionary."""
44
45
  return {
45
46
  "name": self.name,
@@ -57,7 +58,7 @@ class HealthChecker:
57
58
 
58
59
  def __init__(self):
59
60
  """Initialize the health checker."""
60
- self._checks: List[callable] = []
61
+ self._checks: list[callable] = []
61
62
 
62
63
  def register_check(self, check_func: Callable) -> None:
63
64
  """
@@ -68,14 +69,14 @@ class HealthChecker:
68
69
  """
69
70
  self._checks.append(check_func)
70
71
 
71
- async def check_all(self) -> Dict[str, Any]:
72
+ async def check_all(self) -> dict[str, Any]:
72
73
  """
73
74
  Run all registered health checks.
74
75
 
75
76
  Returns:
76
77
  Dictionary with overall status and individual check results
77
78
  """
78
- results: List[HealthCheckResult] = []
79
+ results: list[HealthCheckResult] = []
79
80
 
80
81
  for check_func in self._checks:
81
82
  try:
@@ -117,7 +118,7 @@ class HealthChecker:
117
118
 
118
119
 
119
120
  async def check_mongodb_health(
120
- mongo_client: Optional[Any], timeout_seconds: float = 5.0
121
+ mongo_client: Any | None, timeout_seconds: float = 5.0
121
122
  ) -> HealthCheckResult:
122
123
  """
123
124
  Check MongoDB connection health.
@@ -166,7 +167,7 @@ async def check_mongodb_health(
166
167
  )
167
168
 
168
169
 
169
- async def check_engine_health(engine: Optional[Any]) -> HealthCheckResult:
170
+ async def check_engine_health(engine: Any | None) -> HealthCheckResult:
170
171
  """
171
172
  Check MongoDB Engine health.
172
173
 
@@ -205,7 +206,7 @@ async def check_engine_health(engine: Optional[Any]) -> HealthCheckResult:
205
206
 
206
207
 
207
208
  async def check_pool_health(
208
- get_pool_metrics_func: Optional[Callable[[], Any]] = None,
209
+ get_pool_metrics_func: Callable[[], Any] | None = None,
209
210
  ) -> HealthCheckResult:
210
211
  """
211
212
  Check connection pool health.
@@ -8,25 +8,25 @@ import contextvars
8
8
  import logging
9
9
  import uuid
10
10
  from datetime import datetime
11
- from typing import Any, Dict, Optional
11
+ from typing import Any
12
12
 
13
13
  # Context variable for correlation ID
14
- _correlation_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
14
+ _correlation_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
15
15
  "correlation_id", default=None
16
16
  )
17
17
 
18
18
  # Context variable for app context
19
- _app_context: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar(
19
+ _app_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar(
20
20
  "app_context", default=None
21
21
  )
22
22
 
23
23
 
24
- def get_correlation_id() -> Optional[str]:
24
+ def get_correlation_id() -> str | None:
25
25
  """Get the current correlation ID from context."""
26
26
  return _correlation_id.get()
27
27
 
28
28
 
29
- def set_correlation_id(correlation_id: Optional[str] = None) -> str:
29
+ def set_correlation_id(correlation_id: str | None = None) -> str:
30
30
  """
31
31
  Set a correlation ID in the current context.
32
32
 
@@ -47,7 +47,7 @@ def clear_correlation_id() -> None:
47
47
  _correlation_id.set(None)
48
48
 
49
49
 
50
- def set_app_context(app_slug: Optional[str] = None, **kwargs: Any) -> None:
50
+ def set_app_context(app_slug: str | None = None, **kwargs: Any) -> None:
51
51
  """
52
52
  Set app context for logging.
53
53
 
@@ -64,14 +64,14 @@ def clear_app_context() -> None:
64
64
  _app_context.set(None)
65
65
 
66
66
 
67
- def get_logging_context() -> Dict[str, Any]:
67
+ def get_logging_context() -> dict[str, Any]:
68
68
  """
69
69
  Get current logging context (correlation ID and app context).
70
70
 
71
71
  Returns:
72
72
  Dictionary with context information
73
73
  """
74
- context: Dict[str, Any] = {
74
+ context: dict[str, Any] = {
75
75
  "timestamp": datetime.now().isoformat(),
76
76
  }
77
77
 
@@ -91,7 +91,7 @@ class ContextualLoggerAdapter(logging.LoggerAdapter):
91
91
  Logger adapter that automatically adds context to log records.
92
92
  """
93
93
 
94
- def process(self, msg: str, kwargs: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
94
+ def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
95
95
  """Add context to log records."""
96
96
  # Get base context
97
97
  context = get_logging_context()
@@ -124,7 +124,7 @@ def log_operation(
124
124
  operation: str,
125
125
  level: int = logging.INFO,
126
126
  success: bool = True,
127
- duration_ms: Optional[float] = None,
127
+ duration_ms: float | None = None,
128
128
  **context: Any,
129
129
  ) -> None:
130
130
  """
@@ -9,9 +9,10 @@ import logging
9
9
  import threading
10
10
  import time
11
11
  from collections import OrderedDict
12
+ from collections.abc import Callable
12
13
  from dataclasses import dataclass
13
14
  from datetime import datetime
14
- from typing import Any, Callable, Dict, Optional
15
+ from typing import Any
15
16
 
16
17
  logger = logging.getLogger(__name__)
17
18
 
@@ -26,7 +27,7 @@ class OperationMetrics:
26
27
  min_duration_ms: float = float("inf")
27
28
  max_duration_ms: float = 0.0
28
29
  error_count: int = 0
29
- last_execution: Optional[datetime] = None
30
+ last_execution: datetime | None = None
30
31
 
31
32
  @property
32
33
  def avg_duration_ms(self) -> float:
@@ -48,7 +49,7 @@ class OperationMetrics:
48
49
  self.error_count += 1
49
50
  self.last_execution = datetime.now()
50
51
 
51
- def to_dict(self) -> Dict[str, Any]:
52
+ def to_dict(self) -> dict[str, Any]:
52
53
  """Convert metrics to dictionary."""
53
54
  return {
54
55
  "operation": self.operation_name,
@@ -121,7 +122,7 @@ class MetricsCollector:
121
122
 
122
123
  self._metrics[key].record(duration_ms, success)
123
124
 
124
- def get_metrics(self, operation_name: Optional[str] = None) -> Dict[str, Any]:
125
+ def get_metrics(self, operation_name: str | None = None) -> dict[str, Any]:
125
126
  """
126
127
  Get metrics for operations.
127
128
 
@@ -154,7 +155,7 @@ class MetricsCollector:
154
155
  "total_operations": total_operations,
155
156
  }
156
157
 
157
- def get_summary(self) -> Dict[str, Any]:
158
+ def get_summary(self) -> dict[str, Any]:
158
159
  """
159
160
  Get a summary of all metrics.
160
161
 
@@ -170,7 +171,7 @@ class MetricsCollector:
170
171
  }
171
172
 
172
173
  # Aggregate by base operation name (without tags)
173
- aggregated: Dict[str, OperationMetrics] = {}
174
+ aggregated: dict[str, OperationMetrics] = {}
174
175
 
175
176
  for _key, metric in self._metrics.items():
176
177
  base_name = metric.operation_name
@@ -217,7 +218,7 @@ class MetricsCollector:
217
218
 
218
219
 
219
220
  # Global metrics collector instance
220
- _metrics_collector: Optional[MetricsCollector] = None
221
+ _metrics_collector: MetricsCollector | None = None
221
222
 
222
223
 
223
224
  def get_metrics_collector() -> MetricsCollector:
@@ -8,7 +8,7 @@ This allows domain services to work with any data store implementation.
8
8
  from abc import ABC, abstractmethod
9
9
  from dataclasses import dataclass, field
10
10
  from datetime import datetime
11
- from typing import Any, Dict, Generic, List, Optional, TypeVar
11
+ from typing import Any, Generic, TypeVar
12
12
 
13
13
  from bson import ObjectId
14
14
 
@@ -28,11 +28,11 @@ class Entity:
28
28
  role: str = "user"
29
29
  """
30
30
 
31
- id: Optional[str] = None
32
- created_at: Optional[datetime] = field(default=None)
33
- updated_at: Optional[datetime] = field(default=None)
31
+ id: str | None = None
32
+ created_at: datetime | None = field(default=None)
33
+ updated_at: datetime | None = field(default=None)
34
34
 
35
- def to_dict(self) -> Dict[str, Any]:
35
+ def to_dict(self) -> dict[str, Any]:
36
36
  """Convert entity to dictionary for storage."""
37
37
  data = {}
38
38
  for key, value in self.__dict__.items():
@@ -48,7 +48,7 @@ class Entity:
48
48
  return data
49
49
 
50
50
  @classmethod
51
- def from_dict(cls, data: Dict[str, Any]) -> "Entity":
51
+ def from_dict(cls, data: dict[str, Any]) -> "Entity":
52
52
  """Create entity from dictionary (e.g., from database)."""
53
53
  if data is None:
54
54
  return None
@@ -88,7 +88,7 @@ class Repository(ABC, Generic[T]):
88
88
  """
89
89
 
90
90
  @abstractmethod
91
- async def get(self, id: str) -> Optional[T]:
91
+ async def get(self, id: str) -> T | None:
92
92
  """
93
93
  Get a single entity by ID.
94
94
 
@@ -103,11 +103,11 @@ class Repository(ABC, Generic[T]):
103
103
  @abstractmethod
104
104
  async def find(
105
105
  self,
106
- filter: Optional[Dict[str, Any]] = None,
106
+ filter: dict[str, Any] | None = None,
107
107
  skip: int = 0,
108
108
  limit: int = 100,
109
- sort: Optional[List[tuple]] = None,
110
- ) -> List[T]:
109
+ sort: list[tuple] | None = None,
110
+ ) -> list[T]:
111
111
  """
112
112
  Find entities matching a filter.
113
113
 
@@ -125,8 +125,8 @@ class Repository(ABC, Generic[T]):
125
125
  @abstractmethod
126
126
  async def find_one(
127
127
  self,
128
- filter: Dict[str, Any],
129
- ) -> Optional[T]:
128
+ filter: dict[str, Any],
129
+ ) -> T | None:
130
130
  """
131
131
  Find a single entity matching a filter.
132
132
 
@@ -152,7 +152,7 @@ class Repository(ABC, Generic[T]):
152
152
  pass
153
153
 
154
154
  @abstractmethod
155
- async def add_many(self, entities: List[T]) -> List[str]:
155
+ async def add_many(self, entities: list[T]) -> list[str]:
156
156
  """
157
157
  Add multiple entities.
158
158
 
@@ -179,7 +179,7 @@ class Repository(ABC, Generic[T]):
179
179
  pass
180
180
 
181
181
  @abstractmethod
182
- async def update_fields(self, id: str, fields: Dict[str, Any]) -> bool:
182
+ async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
183
183
  """
184
184
  Update specific fields of an entity.
185
185
 
@@ -206,7 +206,7 @@ class Repository(ABC, Generic[T]):
206
206
  pass
207
207
 
208
208
  @abstractmethod
209
- async def count(self, filter: Optional[Dict[str, Any]] = None) -> int:
209
+ async def count(self, filter: dict[str, Any] | None = None) -> int:
210
210
  """
211
211
  Count entities matching a filter.
212
212
 
@@ -242,10 +242,10 @@ class InMemoryRepository(Repository[T]):
242
242
 
243
243
  def __init__(self, entity_class: type):
244
244
  self._entity_class = entity_class
245
- self._storage: Dict[str, Dict[str, Any]] = {}
245
+ self._storage: dict[str, dict[str, Any]] = {}
246
246
  self._counter = 0
247
247
 
248
- async def get(self, id: str) -> Optional[T]:
248
+ async def get(self, id: str) -> T | None:
249
249
  data = self._storage.get(id)
250
250
  if data is None:
251
251
  return None
@@ -253,11 +253,11 @@ class InMemoryRepository(Repository[T]):
253
253
 
254
254
  async def find(
255
255
  self,
256
- filter: Optional[Dict[str, Any]] = None,
256
+ filter: dict[str, Any] | None = None,
257
257
  skip: int = 0,
258
258
  limit: int = 100,
259
- sort: Optional[List[tuple]] = None,
260
- ) -> List[T]:
259
+ sort: list[tuple] | None = None,
260
+ ) -> list[T]:
261
261
  results = []
262
262
  for data in self._storage.values():
263
263
  if filter is None or self._matches_filter(data, filter):
@@ -266,7 +266,7 @@ class InMemoryRepository(Repository[T]):
266
266
  # Apply skip and limit
267
267
  return results[skip : skip + limit]
268
268
 
269
- async def find_one(self, filter: Dict[str, Any]) -> Optional[T]:
269
+ async def find_one(self, filter: dict[str, Any]) -> T | None:
270
270
  results = await self.find(filter, limit=1)
271
271
  return results[0] if results else None
272
272
 
@@ -278,7 +278,7 @@ class InMemoryRepository(Repository[T]):
278
278
  self._storage[id] = entity.to_dict()
279
279
  return id
280
280
 
281
- async def add_many(self, entities: List[T]) -> List[str]:
281
+ async def add_many(self, entities: list[T]) -> list[str]:
282
282
  return [await self.add(e) for e in entities]
283
283
 
284
284
  async def update(self, id: str, entity: T) -> bool:
@@ -289,7 +289,7 @@ class InMemoryRepository(Repository[T]):
289
289
  self._storage[id] = entity.to_dict()
290
290
  return True
291
291
 
292
- async def update_fields(self, id: str, fields: Dict[str, Any]) -> bool:
292
+ async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
293
293
  if id not in self._storage:
294
294
  return False
295
295
  self._storage[id].update(fields)
@@ -302,7 +302,7 @@ class InMemoryRepository(Repository[T]):
302
302
  del self._storage[id]
303
303
  return True
304
304
 
305
- async def count(self, filter: Optional[Dict[str, Any]] = None) -> int:
305
+ async def count(self, filter: dict[str, Any] | None = None) -> int:
306
306
  if filter is None:
307
307
  return len(self._storage)
308
308
  return len(await self.find(filter, limit=999999))
@@ -310,7 +310,7 @@ class InMemoryRepository(Repository[T]):
310
310
  async def exists(self, id: str) -> bool:
311
311
  return id in self._storage
312
312
 
313
- def _matches_filter(self, data: Dict[str, Any], filter: Dict[str, Any]) -> bool:
313
+ def _matches_filter(self, data: dict[str, Any], filter: dict[str, Any]) -> bool:
314
314
  """Simple filter matching for testing."""
315
315
  for key, value in filter.items():
316
316
  if key not in data:
@@ -7,7 +7,7 @@ This provides automatic app scoping and security features.
7
7
 
8
8
  import logging
9
9
  from datetime import datetime
10
- from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
10
+ from typing import Any, Generic, TypeVar
11
11
 
12
12
  from bson import ObjectId
13
13
 
@@ -41,7 +41,7 @@ class MongoRepository(Repository[T], Generic[T]):
41
41
  def __init__(
42
42
  self,
43
43
  collection: Any, # ScopedCollectionWrapper - avoid import cycle
44
- entity_class: Type[T],
44
+ entity_class: type[T],
45
45
  ):
46
46
  """
47
47
  Initialize the MongoDB repository.
@@ -53,20 +53,20 @@ class MongoRepository(Repository[T], Generic[T]):
53
53
  self._collection = collection
54
54
  self._entity_class = entity_class
55
55
 
56
- def _to_entity(self, doc: Optional[Dict[str, Any]]) -> Optional[T]:
56
+ def _to_entity(self, doc: dict[str, Any] | None) -> T | None:
57
57
  """Convert a MongoDB document to an entity."""
58
58
  if doc is None:
59
59
  return None
60
60
  return self._entity_class.from_dict(doc)
61
61
 
62
- def _to_document(self, entity: T, include_id: bool = False) -> Dict[str, Any]:
62
+ def _to_document(self, entity: T, include_id: bool = False) -> dict[str, Any]:
63
63
  """Convert an entity to a MongoDB document."""
64
64
  doc = entity.to_dict()
65
65
  if not include_id and "_id" in doc:
66
66
  del doc["_id"]
67
67
  return doc
68
68
 
69
- async def get(self, id: str) -> Optional[T]:
69
+ async def get(self, id: str) -> T | None:
70
70
  """Get entity by ID."""
71
71
  try:
72
72
  object_id = ObjectId(id) if ObjectId.is_valid(id) else id
@@ -78,11 +78,11 @@ class MongoRepository(Repository[T], Generic[T]):
78
78
 
79
79
  async def find(
80
80
  self,
81
- filter: Optional[Dict[str, Any]] = None,
81
+ filter: dict[str, Any] | None = None,
82
82
  skip: int = 0,
83
83
  limit: int = 100,
84
- sort: Optional[List[tuple]] = None,
85
- ) -> List[T]:
84
+ sort: list[tuple] | None = None,
85
+ ) -> list[T]:
86
86
  """Find entities matching a filter."""
87
87
  cursor = self._collection.find(filter or {})
88
88
 
@@ -96,7 +96,7 @@ class MongoRepository(Repository[T], Generic[T]):
96
96
  docs = await cursor.to_list(length=limit)
97
97
  return [self._to_entity(doc) for doc in docs]
98
98
 
99
- async def find_one(self, filter: Dict[str, Any]) -> Optional[T]:
99
+ async def find_one(self, filter: dict[str, Any]) -> T | None:
100
100
  """Find a single entity matching a filter."""
101
101
  doc = await self._collection.find_one(filter)
102
102
  return self._to_entity(doc)
@@ -112,7 +112,7 @@ class MongoRepository(Repository[T], Generic[T]):
112
112
  logger.debug(f"Added {self._entity_class.__name__} with id={entity.id}")
113
113
  return entity.id
114
114
 
115
- async def add_many(self, entities: List[T]) -> List[str]:
115
+ async def add_many(self, entities: list[T]) -> list[str]:
116
116
  """Add multiple entities and return their IDs."""
117
117
  now = datetime.utcnow()
118
118
  docs = []
@@ -124,7 +124,7 @@ class MongoRepository(Repository[T], Generic[T]):
124
124
  result = await self._collection.insert_many(docs)
125
125
  ids = [str(id) for id in result.inserted_ids]
126
126
 
127
- for entity, id in zip(entities, ids):
127
+ for entity, id in zip(entities, ids, strict=False):
128
128
  entity.id = id
129
129
 
130
130
  logger.debug(f"Added {len(ids)} {self._entity_class.__name__} entities")
@@ -144,7 +144,7 @@ class MongoRepository(Repository[T], Generic[T]):
144
144
 
145
145
  return result.modified_count > 0
146
146
 
147
- async def update_fields(self, id: str, fields: Dict[str, Any]) -> bool:
147
+ async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
148
148
  """Update specific fields of an entity."""
149
149
  try:
150
150
  object_id = ObjectId(id) if ObjectId.is_valid(id) else id
@@ -167,7 +167,7 @@ class MongoRepository(Repository[T], Generic[T]):
167
167
  result = await self._collection.delete_one({"_id": object_id})
168
168
  return result.deleted_count > 0
169
169
 
170
- async def count(self, filter: Optional[Dict[str, Any]] = None) -> int:
170
+ async def count(self, filter: dict[str, Any] | None = None) -> int:
171
171
  """Count entities matching a filter."""
172
172
  return await self._collection.count_documents(filter or {})
173
173
 
@@ -183,7 +183,7 @@ class MongoRepository(Repository[T], Generic[T]):
183
183
 
184
184
  # Additional MongoDB-specific methods
185
185
 
186
- async def aggregate(self, pipeline: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
186
+ async def aggregate(self, pipeline: list[dict[str, Any]]) -> list[dict[str, Any]]:
187
187
  """
188
188
  Run an aggregation pipeline.
189
189
 
@@ -198,8 +198,8 @@ class MongoRepository(Repository[T], Generic[T]):
198
198
 
199
199
  async def update_many(
200
200
  self,
201
- filter: Dict[str, Any],
202
- update: Dict[str, Any],
201
+ filter: dict[str, Any],
202
+ update: dict[str, Any],
203
203
  ) -> int:
204
204
  """
205
205
  Update multiple documents matching a filter.
@@ -219,7 +219,7 @@ class MongoRepository(Repository[T], Generic[T]):
219
219
  result = await self._collection.update_many(filter, update)
220
220
  return result.modified_count
221
221
 
222
- async def delete_many(self, filter: Dict[str, Any]) -> int:
222
+ async def delete_many(self, filter: dict[str, Any]) -> int:
223
223
  """
224
224
  Delete multiple documents matching a filter.
225
225
 
@@ -6,7 +6,7 @@ The UnitOfWork acts as a factory for repositories and manages their lifecycle.
6
6
  """
7
7
 
8
8
  import logging
9
- from typing import Any, Dict, Generic, Optional, Type, TypeVar
9
+ from typing import Any, Generic, TypeVar
10
10
 
11
11
  from .base import Entity, Repository
12
12
  from .mongo import MongoRepository
@@ -49,7 +49,7 @@ class UnitOfWork:
49
49
  def __init__(
50
50
  self,
51
51
  db: Any, # ScopedMongoWrapper - avoid import cycle
52
- entity_registry: Optional[Dict[str, Type[Entity]]] = None,
52
+ entity_registry: dict[str, type[Entity]] | None = None,
53
53
  ):
54
54
  """
55
55
  Initialize the Unit of Work.
@@ -59,10 +59,10 @@ class UnitOfWork:
59
59
  entity_registry: Optional mapping of collection names to entity classes
60
60
  """
61
61
  self._db = db
62
- self._repositories: Dict[str, Repository] = {}
63
- self._entity_registry: Dict[str, Type[Entity]] = entity_registry or {}
62
+ self._repositories: dict[str, Repository] = {}
63
+ self._entity_registry: dict[str, type[Entity]] = entity_registry or {}
64
64
 
65
- def register_entity(self, collection_name: str, entity_class: Type[Entity]) -> None:
65
+ def register_entity(self, collection_name: str, entity_class: type[Entity]) -> None:
66
66
  """
67
67
  Register an entity class for a collection.
68
68
 
@@ -77,7 +77,7 @@ class UnitOfWork:
77
77
  def repository(
78
78
  self,
79
79
  name: str,
80
- entity_class: Optional[Type[T]] = None,
80
+ entity_class: type[T] | None = None,
81
81
  ) -> Repository[T]:
82
82
  """
83
83
  Get or create a repository for a collection.