mdb-engine 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +7 -1
- mdb_engine/auth/README.md +6 -0
- mdb_engine/auth/audit.py +40 -40
- mdb_engine/auth/base.py +3 -3
- mdb_engine/auth/casbin_factory.py +6 -6
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +12 -12
- mdb_engine/auth/cookie_utils.py +9 -9
- mdb_engine/auth/csrf.py +9 -8
- mdb_engine/auth/decorators.py +7 -6
- mdb_engine/auth/dependencies.py +22 -21
- mdb_engine/auth/integration.py +9 -9
- mdb_engine/auth/jwt.py +9 -9
- mdb_engine/auth/middleware.py +4 -3
- mdb_engine/auth/oso_factory.py +6 -6
- mdb_engine/auth/provider.py +4 -4
- mdb_engine/auth/rate_limiter.py +12 -11
- mdb_engine/auth/restrictions.py +16 -15
- mdb_engine/auth/session_manager.py +11 -13
- mdb_engine/auth/shared_middleware.py +344 -132
- mdb_engine/auth/shared_users.py +20 -20
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +4 -5
- mdb_engine/auth/users.py +51 -52
- mdb_engine/auth/utils.py +29 -33
- mdb_engine/cli/commands/generate.py +6 -6
- mdb_engine/cli/utils.py +4 -4
- mdb_engine/config.py +6 -7
- mdb_engine/core/app_registration.py +12 -12
- mdb_engine/core/app_secrets.py +1 -2
- mdb_engine/core/connection.py +3 -4
- mdb_engine/core/encryption.py +1 -2
- mdb_engine/core/engine.py +43 -44
- mdb_engine/core/manifest.py +80 -58
- mdb_engine/core/ray_integration.py +10 -9
- mdb_engine/core/seeding.py +3 -3
- mdb_engine/core/service_initialization.py +10 -9
- mdb_engine/core/types.py +40 -40
- mdb_engine/database/abstraction.py +15 -16
- mdb_engine/database/connection.py +40 -12
- mdb_engine/database/query_validator.py +8 -8
- mdb_engine/database/resource_limiter.py +7 -7
- mdb_engine/database/scoped_wrapper.py +51 -58
- mdb_engine/dependencies.py +14 -13
- mdb_engine/di/container.py +12 -13
- mdb_engine/di/providers.py +14 -13
- mdb_engine/di/scopes.py +5 -5
- mdb_engine/embeddings/dependencies.py +2 -2
- mdb_engine/embeddings/service.py +67 -50
- mdb_engine/exceptions.py +20 -20
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +9 -9
- mdb_engine/memory/README.md +93 -2
- mdb_engine/memory/service.py +361 -1109
- mdb_engine/observability/health.py +10 -9
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +8 -7
- mdb_engine/repositories/base.py +25 -25
- mdb_engine/repositories/mongo.py +17 -17
- mdb_engine/repositories/unit_of_work.py +6 -6
- mdb_engine/routing/websockets.py +19 -18
- mdb_engine/utils/__init__.py +3 -1
- mdb_engine/utils/mongo.py +117 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.4.dist-info}/METADATA +88 -13
- mdb_engine-0.2.4.dist-info/RECORD +97 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.4.dist-info}/WHEEL +1 -1
- mdb_engine-0.2.1.dist-info/RECORD +0 -96
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.4.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.4.dist-info}/top_level.txt +0 -0
|
@@ -6,10 +6,11 @@ Provides health check functions for monitoring system status.
|
|
|
6
6
|
|
|
7
7
|
import asyncio
|
|
8
8
|
import logging
|
|
9
|
+
from collections.abc import Callable
|
|
9
10
|
from dataclasses import dataclass, field
|
|
10
11
|
from datetime import datetime
|
|
11
12
|
from enum import Enum
|
|
12
|
-
from typing import Any
|
|
13
|
+
from typing import Any
|
|
13
14
|
|
|
14
15
|
from pymongo.errors import (
|
|
15
16
|
ConnectionFailure,
|
|
@@ -36,10 +37,10 @@ class HealthCheckResult:
|
|
|
36
37
|
name: str
|
|
37
38
|
status: HealthStatus
|
|
38
39
|
message: str
|
|
39
|
-
details:
|
|
40
|
+
details: dict[str, Any] | None = None
|
|
40
41
|
timestamp: datetime = field(default_factory=datetime.now)
|
|
41
42
|
|
|
42
|
-
def to_dict(self) ->
|
|
43
|
+
def to_dict(self) -> dict[str, Any]:
|
|
43
44
|
"""Convert to dictionary."""
|
|
44
45
|
return {
|
|
45
46
|
"name": self.name,
|
|
@@ -57,7 +58,7 @@ class HealthChecker:
|
|
|
57
58
|
|
|
58
59
|
def __init__(self):
|
|
59
60
|
"""Initialize the health checker."""
|
|
60
|
-
self._checks:
|
|
61
|
+
self._checks: list[callable] = []
|
|
61
62
|
|
|
62
63
|
def register_check(self, check_func: Callable) -> None:
|
|
63
64
|
"""
|
|
@@ -68,14 +69,14 @@ class HealthChecker:
|
|
|
68
69
|
"""
|
|
69
70
|
self._checks.append(check_func)
|
|
70
71
|
|
|
71
|
-
async def check_all(self) ->
|
|
72
|
+
async def check_all(self) -> dict[str, Any]:
|
|
72
73
|
"""
|
|
73
74
|
Run all registered health checks.
|
|
74
75
|
|
|
75
76
|
Returns:
|
|
76
77
|
Dictionary with overall status and individual check results
|
|
77
78
|
"""
|
|
78
|
-
results:
|
|
79
|
+
results: list[HealthCheckResult] = []
|
|
79
80
|
|
|
80
81
|
for check_func in self._checks:
|
|
81
82
|
try:
|
|
@@ -117,7 +118,7 @@ class HealthChecker:
|
|
|
117
118
|
|
|
118
119
|
|
|
119
120
|
async def check_mongodb_health(
|
|
120
|
-
mongo_client:
|
|
121
|
+
mongo_client: Any | None, timeout_seconds: float = 5.0
|
|
121
122
|
) -> HealthCheckResult:
|
|
122
123
|
"""
|
|
123
124
|
Check MongoDB connection health.
|
|
@@ -166,7 +167,7 @@ async def check_mongodb_health(
|
|
|
166
167
|
)
|
|
167
168
|
|
|
168
169
|
|
|
169
|
-
async def check_engine_health(engine:
|
|
170
|
+
async def check_engine_health(engine: Any | None) -> HealthCheckResult:
|
|
170
171
|
"""
|
|
171
172
|
Check MongoDB Engine health.
|
|
172
173
|
|
|
@@ -205,7 +206,7 @@ async def check_engine_health(engine: Optional[Any]) -> HealthCheckResult:
|
|
|
205
206
|
|
|
206
207
|
|
|
207
208
|
async def check_pool_health(
|
|
208
|
-
get_pool_metrics_func:
|
|
209
|
+
get_pool_metrics_func: Callable[[], Any] | None = None,
|
|
209
210
|
) -> HealthCheckResult:
|
|
210
211
|
"""
|
|
211
212
|
Check connection pool health.
|
|
@@ -8,25 +8,25 @@ import contextvars
|
|
|
8
8
|
import logging
|
|
9
9
|
import uuid
|
|
10
10
|
from datetime import datetime
|
|
11
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
12
12
|
|
|
13
13
|
# Context variable for correlation ID
|
|
14
|
-
_correlation_id: contextvars.ContextVar[
|
|
14
|
+
_correlation_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
|
15
15
|
"correlation_id", default=None
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
# Context variable for app context
|
|
19
|
-
_app_context: contextvars.ContextVar[
|
|
19
|
+
_app_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar(
|
|
20
20
|
"app_context", default=None
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def get_correlation_id() ->
|
|
24
|
+
def get_correlation_id() -> str | None:
|
|
25
25
|
"""Get the current correlation ID from context."""
|
|
26
26
|
return _correlation_id.get()
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
def set_correlation_id(correlation_id:
|
|
29
|
+
def set_correlation_id(correlation_id: str | None = None) -> str:
|
|
30
30
|
"""
|
|
31
31
|
Set a correlation ID in the current context.
|
|
32
32
|
|
|
@@ -47,7 +47,7 @@ def clear_correlation_id() -> None:
|
|
|
47
47
|
_correlation_id.set(None)
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
def set_app_context(app_slug:
|
|
50
|
+
def set_app_context(app_slug: str | None = None, **kwargs: Any) -> None:
|
|
51
51
|
"""
|
|
52
52
|
Set app context for logging.
|
|
53
53
|
|
|
@@ -64,14 +64,14 @@ def clear_app_context() -> None:
|
|
|
64
64
|
_app_context.set(None)
|
|
65
65
|
|
|
66
66
|
|
|
67
|
-
def get_logging_context() ->
|
|
67
|
+
def get_logging_context() -> dict[str, Any]:
|
|
68
68
|
"""
|
|
69
69
|
Get current logging context (correlation ID and app context).
|
|
70
70
|
|
|
71
71
|
Returns:
|
|
72
72
|
Dictionary with context information
|
|
73
73
|
"""
|
|
74
|
-
context:
|
|
74
|
+
context: dict[str, Any] = {
|
|
75
75
|
"timestamp": datetime.now().isoformat(),
|
|
76
76
|
}
|
|
77
77
|
|
|
@@ -91,7 +91,7 @@ class ContextualLoggerAdapter(logging.LoggerAdapter):
|
|
|
91
91
|
Logger adapter that automatically adds context to log records.
|
|
92
92
|
"""
|
|
93
93
|
|
|
94
|
-
def process(self, msg: str, kwargs:
|
|
94
|
+
def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
|
95
95
|
"""Add context to log records."""
|
|
96
96
|
# Get base context
|
|
97
97
|
context = get_logging_context()
|
|
@@ -124,7 +124,7 @@ def log_operation(
|
|
|
124
124
|
operation: str,
|
|
125
125
|
level: int = logging.INFO,
|
|
126
126
|
success: bool = True,
|
|
127
|
-
duration_ms:
|
|
127
|
+
duration_ms: float | None = None,
|
|
128
128
|
**context: Any,
|
|
129
129
|
) -> None:
|
|
130
130
|
"""
|
|
@@ -9,9 +9,10 @@ import logging
|
|
|
9
9
|
import threading
|
|
10
10
|
import time
|
|
11
11
|
from collections import OrderedDict
|
|
12
|
+
from collections.abc import Callable
|
|
12
13
|
from dataclasses import dataclass
|
|
13
14
|
from datetime import datetime
|
|
14
|
-
from typing import Any
|
|
15
|
+
from typing import Any
|
|
15
16
|
|
|
16
17
|
logger = logging.getLogger(__name__)
|
|
17
18
|
|
|
@@ -26,7 +27,7 @@ class OperationMetrics:
|
|
|
26
27
|
min_duration_ms: float = float("inf")
|
|
27
28
|
max_duration_ms: float = 0.0
|
|
28
29
|
error_count: int = 0
|
|
29
|
-
last_execution:
|
|
30
|
+
last_execution: datetime | None = None
|
|
30
31
|
|
|
31
32
|
@property
|
|
32
33
|
def avg_duration_ms(self) -> float:
|
|
@@ -48,7 +49,7 @@ class OperationMetrics:
|
|
|
48
49
|
self.error_count += 1
|
|
49
50
|
self.last_execution = datetime.now()
|
|
50
51
|
|
|
51
|
-
def to_dict(self) ->
|
|
52
|
+
def to_dict(self) -> dict[str, Any]:
|
|
52
53
|
"""Convert metrics to dictionary."""
|
|
53
54
|
return {
|
|
54
55
|
"operation": self.operation_name,
|
|
@@ -121,7 +122,7 @@ class MetricsCollector:
|
|
|
121
122
|
|
|
122
123
|
self._metrics[key].record(duration_ms, success)
|
|
123
124
|
|
|
124
|
-
def get_metrics(self, operation_name:
|
|
125
|
+
def get_metrics(self, operation_name: str | None = None) -> dict[str, Any]:
|
|
125
126
|
"""
|
|
126
127
|
Get metrics for operations.
|
|
127
128
|
|
|
@@ -154,7 +155,7 @@ class MetricsCollector:
|
|
|
154
155
|
"total_operations": total_operations,
|
|
155
156
|
}
|
|
156
157
|
|
|
157
|
-
def get_summary(self) ->
|
|
158
|
+
def get_summary(self) -> dict[str, Any]:
|
|
158
159
|
"""
|
|
159
160
|
Get a summary of all metrics.
|
|
160
161
|
|
|
@@ -170,7 +171,7 @@ class MetricsCollector:
|
|
|
170
171
|
}
|
|
171
172
|
|
|
172
173
|
# Aggregate by base operation name (without tags)
|
|
173
|
-
aggregated:
|
|
174
|
+
aggregated: dict[str, OperationMetrics] = {}
|
|
174
175
|
|
|
175
176
|
for _key, metric in self._metrics.items():
|
|
176
177
|
base_name = metric.operation_name
|
|
@@ -217,7 +218,7 @@ class MetricsCollector:
|
|
|
217
218
|
|
|
218
219
|
|
|
219
220
|
# Global metrics collector instance
|
|
220
|
-
_metrics_collector:
|
|
221
|
+
_metrics_collector: MetricsCollector | None = None
|
|
221
222
|
|
|
222
223
|
|
|
223
224
|
def get_metrics_collector() -> MetricsCollector:
|
mdb_engine/repositories/base.py
CHANGED
|
@@ -8,7 +8,7 @@ This allows domain services to work with any data store implementation.
|
|
|
8
8
|
from abc import ABC, abstractmethod
|
|
9
9
|
from dataclasses import dataclass, field
|
|
10
10
|
from datetime import datetime
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, Generic, TypeVar
|
|
12
12
|
|
|
13
13
|
from bson import ObjectId
|
|
14
14
|
|
|
@@ -28,11 +28,11 @@ class Entity:
|
|
|
28
28
|
role: str = "user"
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
id:
|
|
32
|
-
created_at:
|
|
33
|
-
updated_at:
|
|
31
|
+
id: str | None = None
|
|
32
|
+
created_at: datetime | None = field(default=None)
|
|
33
|
+
updated_at: datetime | None = field(default=None)
|
|
34
34
|
|
|
35
|
-
def to_dict(self) ->
|
|
35
|
+
def to_dict(self) -> dict[str, Any]:
|
|
36
36
|
"""Convert entity to dictionary for storage."""
|
|
37
37
|
data = {}
|
|
38
38
|
for key, value in self.__dict__.items():
|
|
@@ -48,7 +48,7 @@ class Entity:
|
|
|
48
48
|
return data
|
|
49
49
|
|
|
50
50
|
@classmethod
|
|
51
|
-
def from_dict(cls, data:
|
|
51
|
+
def from_dict(cls, data: dict[str, Any]) -> "Entity":
|
|
52
52
|
"""Create entity from dictionary (e.g., from database)."""
|
|
53
53
|
if data is None:
|
|
54
54
|
return None
|
|
@@ -88,7 +88,7 @@ class Repository(ABC, Generic[T]):
|
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
90
|
@abstractmethod
|
|
91
|
-
async def get(self, id: str) ->
|
|
91
|
+
async def get(self, id: str) -> T | None:
|
|
92
92
|
"""
|
|
93
93
|
Get a single entity by ID.
|
|
94
94
|
|
|
@@ -103,11 +103,11 @@ class Repository(ABC, Generic[T]):
|
|
|
103
103
|
@abstractmethod
|
|
104
104
|
async def find(
|
|
105
105
|
self,
|
|
106
|
-
filter:
|
|
106
|
+
filter: dict[str, Any] | None = None,
|
|
107
107
|
skip: int = 0,
|
|
108
108
|
limit: int = 100,
|
|
109
|
-
sort:
|
|
110
|
-
) ->
|
|
109
|
+
sort: list[tuple] | None = None,
|
|
110
|
+
) -> list[T]:
|
|
111
111
|
"""
|
|
112
112
|
Find entities matching a filter.
|
|
113
113
|
|
|
@@ -125,8 +125,8 @@ class Repository(ABC, Generic[T]):
|
|
|
125
125
|
@abstractmethod
|
|
126
126
|
async def find_one(
|
|
127
127
|
self,
|
|
128
|
-
filter:
|
|
129
|
-
) ->
|
|
128
|
+
filter: dict[str, Any],
|
|
129
|
+
) -> T | None:
|
|
130
130
|
"""
|
|
131
131
|
Find a single entity matching a filter.
|
|
132
132
|
|
|
@@ -152,7 +152,7 @@ class Repository(ABC, Generic[T]):
|
|
|
152
152
|
pass
|
|
153
153
|
|
|
154
154
|
@abstractmethod
|
|
155
|
-
async def add_many(self, entities:
|
|
155
|
+
async def add_many(self, entities: list[T]) -> list[str]:
|
|
156
156
|
"""
|
|
157
157
|
Add multiple entities.
|
|
158
158
|
|
|
@@ -179,7 +179,7 @@ class Repository(ABC, Generic[T]):
|
|
|
179
179
|
pass
|
|
180
180
|
|
|
181
181
|
@abstractmethod
|
|
182
|
-
async def update_fields(self, id: str, fields:
|
|
182
|
+
async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
|
|
183
183
|
"""
|
|
184
184
|
Update specific fields of an entity.
|
|
185
185
|
|
|
@@ -206,7 +206,7 @@ class Repository(ABC, Generic[T]):
|
|
|
206
206
|
pass
|
|
207
207
|
|
|
208
208
|
@abstractmethod
|
|
209
|
-
async def count(self, filter:
|
|
209
|
+
async def count(self, filter: dict[str, Any] | None = None) -> int:
|
|
210
210
|
"""
|
|
211
211
|
Count entities matching a filter.
|
|
212
212
|
|
|
@@ -242,10 +242,10 @@ class InMemoryRepository(Repository[T]):
|
|
|
242
242
|
|
|
243
243
|
def __init__(self, entity_class: type):
|
|
244
244
|
self._entity_class = entity_class
|
|
245
|
-
self._storage:
|
|
245
|
+
self._storage: dict[str, dict[str, Any]] = {}
|
|
246
246
|
self._counter = 0
|
|
247
247
|
|
|
248
|
-
async def get(self, id: str) ->
|
|
248
|
+
async def get(self, id: str) -> T | None:
|
|
249
249
|
data = self._storage.get(id)
|
|
250
250
|
if data is None:
|
|
251
251
|
return None
|
|
@@ -253,11 +253,11 @@ class InMemoryRepository(Repository[T]):
|
|
|
253
253
|
|
|
254
254
|
async def find(
|
|
255
255
|
self,
|
|
256
|
-
filter:
|
|
256
|
+
filter: dict[str, Any] | None = None,
|
|
257
257
|
skip: int = 0,
|
|
258
258
|
limit: int = 100,
|
|
259
|
-
sort:
|
|
260
|
-
) ->
|
|
259
|
+
sort: list[tuple] | None = None,
|
|
260
|
+
) -> list[T]:
|
|
261
261
|
results = []
|
|
262
262
|
for data in self._storage.values():
|
|
263
263
|
if filter is None or self._matches_filter(data, filter):
|
|
@@ -266,7 +266,7 @@ class InMemoryRepository(Repository[T]):
|
|
|
266
266
|
# Apply skip and limit
|
|
267
267
|
return results[skip : skip + limit]
|
|
268
268
|
|
|
269
|
-
async def find_one(self, filter:
|
|
269
|
+
async def find_one(self, filter: dict[str, Any]) -> T | None:
|
|
270
270
|
results = await self.find(filter, limit=1)
|
|
271
271
|
return results[0] if results else None
|
|
272
272
|
|
|
@@ -278,7 +278,7 @@ class InMemoryRepository(Repository[T]):
|
|
|
278
278
|
self._storage[id] = entity.to_dict()
|
|
279
279
|
return id
|
|
280
280
|
|
|
281
|
-
async def add_many(self, entities:
|
|
281
|
+
async def add_many(self, entities: list[T]) -> list[str]:
|
|
282
282
|
return [await self.add(e) for e in entities]
|
|
283
283
|
|
|
284
284
|
async def update(self, id: str, entity: T) -> bool:
|
|
@@ -289,7 +289,7 @@ class InMemoryRepository(Repository[T]):
|
|
|
289
289
|
self._storage[id] = entity.to_dict()
|
|
290
290
|
return True
|
|
291
291
|
|
|
292
|
-
async def update_fields(self, id: str, fields:
|
|
292
|
+
async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
|
|
293
293
|
if id not in self._storage:
|
|
294
294
|
return False
|
|
295
295
|
self._storage[id].update(fields)
|
|
@@ -302,7 +302,7 @@ class InMemoryRepository(Repository[T]):
|
|
|
302
302
|
del self._storage[id]
|
|
303
303
|
return True
|
|
304
304
|
|
|
305
|
-
async def count(self, filter:
|
|
305
|
+
async def count(self, filter: dict[str, Any] | None = None) -> int:
|
|
306
306
|
if filter is None:
|
|
307
307
|
return len(self._storage)
|
|
308
308
|
return len(await self.find(filter, limit=999999))
|
|
@@ -310,7 +310,7 @@ class InMemoryRepository(Repository[T]):
|
|
|
310
310
|
async def exists(self, id: str) -> bool:
|
|
311
311
|
return id in self._storage
|
|
312
312
|
|
|
313
|
-
def _matches_filter(self, data:
|
|
313
|
+
def _matches_filter(self, data: dict[str, Any], filter: dict[str, Any]) -> bool:
|
|
314
314
|
"""Simple filter matching for testing."""
|
|
315
315
|
for key, value in filter.items():
|
|
316
316
|
if key not in data:
|
mdb_engine/repositories/mongo.py
CHANGED
|
@@ -7,7 +7,7 @@ This provides automatic app scoping and security features.
|
|
|
7
7
|
|
|
8
8
|
import logging
|
|
9
9
|
from datetime import datetime
|
|
10
|
-
from typing import Any,
|
|
10
|
+
from typing import Any, Generic, TypeVar
|
|
11
11
|
|
|
12
12
|
from bson import ObjectId
|
|
13
13
|
|
|
@@ -41,7 +41,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
41
41
|
def __init__(
|
|
42
42
|
self,
|
|
43
43
|
collection: Any, # ScopedCollectionWrapper - avoid import cycle
|
|
44
|
-
entity_class:
|
|
44
|
+
entity_class: type[T],
|
|
45
45
|
):
|
|
46
46
|
"""
|
|
47
47
|
Initialize the MongoDB repository.
|
|
@@ -53,20 +53,20 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
53
53
|
self._collection = collection
|
|
54
54
|
self._entity_class = entity_class
|
|
55
55
|
|
|
56
|
-
def _to_entity(self, doc:
|
|
56
|
+
def _to_entity(self, doc: dict[str, Any] | None) -> T | None:
|
|
57
57
|
"""Convert a MongoDB document to an entity."""
|
|
58
58
|
if doc is None:
|
|
59
59
|
return None
|
|
60
60
|
return self._entity_class.from_dict(doc)
|
|
61
61
|
|
|
62
|
-
def _to_document(self, entity: T, include_id: bool = False) ->
|
|
62
|
+
def _to_document(self, entity: T, include_id: bool = False) -> dict[str, Any]:
|
|
63
63
|
"""Convert an entity to a MongoDB document."""
|
|
64
64
|
doc = entity.to_dict()
|
|
65
65
|
if not include_id and "_id" in doc:
|
|
66
66
|
del doc["_id"]
|
|
67
67
|
return doc
|
|
68
68
|
|
|
69
|
-
async def get(self, id: str) ->
|
|
69
|
+
async def get(self, id: str) -> T | None:
|
|
70
70
|
"""Get entity by ID."""
|
|
71
71
|
try:
|
|
72
72
|
object_id = ObjectId(id) if ObjectId.is_valid(id) else id
|
|
@@ -78,11 +78,11 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
78
78
|
|
|
79
79
|
async def find(
|
|
80
80
|
self,
|
|
81
|
-
filter:
|
|
81
|
+
filter: dict[str, Any] | None = None,
|
|
82
82
|
skip: int = 0,
|
|
83
83
|
limit: int = 100,
|
|
84
|
-
sort:
|
|
85
|
-
) ->
|
|
84
|
+
sort: list[tuple] | None = None,
|
|
85
|
+
) -> list[T]:
|
|
86
86
|
"""Find entities matching a filter."""
|
|
87
87
|
cursor = self._collection.find(filter or {})
|
|
88
88
|
|
|
@@ -96,7 +96,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
96
96
|
docs = await cursor.to_list(length=limit)
|
|
97
97
|
return [self._to_entity(doc) for doc in docs]
|
|
98
98
|
|
|
99
|
-
async def find_one(self, filter:
|
|
99
|
+
async def find_one(self, filter: dict[str, Any]) -> T | None:
|
|
100
100
|
"""Find a single entity matching a filter."""
|
|
101
101
|
doc = await self._collection.find_one(filter)
|
|
102
102
|
return self._to_entity(doc)
|
|
@@ -112,7 +112,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
112
112
|
logger.debug(f"Added {self._entity_class.__name__} with id={entity.id}")
|
|
113
113
|
return entity.id
|
|
114
114
|
|
|
115
|
-
async def add_many(self, entities:
|
|
115
|
+
async def add_many(self, entities: list[T]) -> list[str]:
|
|
116
116
|
"""Add multiple entities and return their IDs."""
|
|
117
117
|
now = datetime.utcnow()
|
|
118
118
|
docs = []
|
|
@@ -124,7 +124,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
124
124
|
result = await self._collection.insert_many(docs)
|
|
125
125
|
ids = [str(id) for id in result.inserted_ids]
|
|
126
126
|
|
|
127
|
-
for entity, id in zip(entities, ids):
|
|
127
|
+
for entity, id in zip(entities, ids, strict=False):
|
|
128
128
|
entity.id = id
|
|
129
129
|
|
|
130
130
|
logger.debug(f"Added {len(ids)} {self._entity_class.__name__} entities")
|
|
@@ -144,7 +144,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
144
144
|
|
|
145
145
|
return result.modified_count > 0
|
|
146
146
|
|
|
147
|
-
async def update_fields(self, id: str, fields:
|
|
147
|
+
async def update_fields(self, id: str, fields: dict[str, Any]) -> bool:
|
|
148
148
|
"""Update specific fields of an entity."""
|
|
149
149
|
try:
|
|
150
150
|
object_id = ObjectId(id) if ObjectId.is_valid(id) else id
|
|
@@ -167,7 +167,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
167
167
|
result = await self._collection.delete_one({"_id": object_id})
|
|
168
168
|
return result.deleted_count > 0
|
|
169
169
|
|
|
170
|
-
async def count(self, filter:
|
|
170
|
+
async def count(self, filter: dict[str, Any] | None = None) -> int:
|
|
171
171
|
"""Count entities matching a filter."""
|
|
172
172
|
return await self._collection.count_documents(filter or {})
|
|
173
173
|
|
|
@@ -183,7 +183,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
183
183
|
|
|
184
184
|
# Additional MongoDB-specific methods
|
|
185
185
|
|
|
186
|
-
async def aggregate(self, pipeline:
|
|
186
|
+
async def aggregate(self, pipeline: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
187
187
|
"""
|
|
188
188
|
Run an aggregation pipeline.
|
|
189
189
|
|
|
@@ -198,8 +198,8 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
198
198
|
|
|
199
199
|
async def update_many(
|
|
200
200
|
self,
|
|
201
|
-
filter:
|
|
202
|
-
update:
|
|
201
|
+
filter: dict[str, Any],
|
|
202
|
+
update: dict[str, Any],
|
|
203
203
|
) -> int:
|
|
204
204
|
"""
|
|
205
205
|
Update multiple documents matching a filter.
|
|
@@ -219,7 +219,7 @@ class MongoRepository(Repository[T], Generic[T]):
|
|
|
219
219
|
result = await self._collection.update_many(filter, update)
|
|
220
220
|
return result.modified_count
|
|
221
221
|
|
|
222
|
-
async def delete_many(self, filter:
|
|
222
|
+
async def delete_many(self, filter: dict[str, Any]) -> int:
|
|
223
223
|
"""
|
|
224
224
|
Delete multiple documents matching a filter.
|
|
225
225
|
|
|
@@ -6,7 +6,7 @@ The UnitOfWork acts as a factory for repositories and manages their lifecycle.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import logging
|
|
9
|
-
from typing import Any,
|
|
9
|
+
from typing import Any, Generic, TypeVar
|
|
10
10
|
|
|
11
11
|
from .base import Entity, Repository
|
|
12
12
|
from .mongo import MongoRepository
|
|
@@ -49,7 +49,7 @@ class UnitOfWork:
|
|
|
49
49
|
def __init__(
|
|
50
50
|
self,
|
|
51
51
|
db: Any, # ScopedMongoWrapper - avoid import cycle
|
|
52
|
-
entity_registry:
|
|
52
|
+
entity_registry: dict[str, type[Entity]] | None = None,
|
|
53
53
|
):
|
|
54
54
|
"""
|
|
55
55
|
Initialize the Unit of Work.
|
|
@@ -59,10 +59,10 @@ class UnitOfWork:
|
|
|
59
59
|
entity_registry: Optional mapping of collection names to entity classes
|
|
60
60
|
"""
|
|
61
61
|
self._db = db
|
|
62
|
-
self._repositories:
|
|
63
|
-
self._entity_registry:
|
|
62
|
+
self._repositories: dict[str, Repository] = {}
|
|
63
|
+
self._entity_registry: dict[str, type[Entity]] = entity_registry or {}
|
|
64
64
|
|
|
65
|
-
def register_entity(self, collection_name: str, entity_class:
|
|
65
|
+
def register_entity(self, collection_name: str, entity_class: type[Entity]) -> None:
|
|
66
66
|
"""
|
|
67
67
|
Register an entity class for a collection.
|
|
68
68
|
|
|
@@ -77,7 +77,7 @@ class UnitOfWork:
|
|
|
77
77
|
def repository(
|
|
78
78
|
self,
|
|
79
79
|
name: str,
|
|
80
|
-
entity_class:
|
|
80
|
+
entity_class: type[T] | None = None,
|
|
81
81
|
) -> Repository[T]:
|
|
82
82
|
"""
|
|
83
83
|
Get or create a repository for a collection.
|