mdb-engine 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/auth/audit.py +40 -40
- mdb_engine/auth/base.py +3 -3
- mdb_engine/auth/casbin_factory.py +6 -6
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +12 -12
- mdb_engine/auth/cookie_utils.py +9 -9
- mdb_engine/auth/csrf.py +9 -8
- mdb_engine/auth/decorators.py +7 -6
- mdb_engine/auth/dependencies.py +22 -21
- mdb_engine/auth/integration.py +9 -9
- mdb_engine/auth/jwt.py +9 -9
- mdb_engine/auth/middleware.py +4 -3
- mdb_engine/auth/oso_factory.py +6 -6
- mdb_engine/auth/provider.py +4 -4
- mdb_engine/auth/rate_limiter.py +12 -11
- mdb_engine/auth/restrictions.py +16 -15
- mdb_engine/auth/session_manager.py +11 -13
- mdb_engine/auth/shared_middleware.py +16 -15
- mdb_engine/auth/shared_users.py +20 -20
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +4 -5
- mdb_engine/auth/users.py +51 -52
- mdb_engine/auth/utils.py +29 -33
- mdb_engine/cli/commands/generate.py +6 -6
- mdb_engine/cli/utils.py +4 -4
- mdb_engine/config.py +6 -7
- mdb_engine/core/app_registration.py +12 -12
- mdb_engine/core/app_secrets.py +1 -2
- mdb_engine/core/connection.py +3 -4
- mdb_engine/core/encryption.py +1 -2
- mdb_engine/core/engine.py +43 -44
- mdb_engine/core/manifest.py +59 -58
- mdb_engine/core/ray_integration.py +10 -9
- mdb_engine/core/seeding.py +3 -3
- mdb_engine/core/service_initialization.py +10 -9
- mdb_engine/core/types.py +40 -40
- mdb_engine/database/abstraction.py +15 -16
- mdb_engine/database/connection.py +40 -12
- mdb_engine/database/query_validator.py +8 -8
- mdb_engine/database/resource_limiter.py +7 -7
- mdb_engine/database/scoped_wrapper.py +51 -58
- mdb_engine/dependencies.py +14 -13
- mdb_engine/di/container.py +12 -13
- mdb_engine/di/providers.py +14 -13
- mdb_engine/di/scopes.py +5 -5
- mdb_engine/embeddings/dependencies.py +2 -2
- mdb_engine/embeddings/service.py +31 -43
- mdb_engine/exceptions.py +20 -20
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +9 -9
- mdb_engine/memory/service.py +30 -30
- mdb_engine/observability/health.py +10 -9
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +8 -7
- mdb_engine/repositories/base.py +25 -25
- mdb_engine/repositories/mongo.py +17 -17
- mdb_engine/repositories/unit_of_work.py +6 -6
- mdb_engine/routing/websockets.py +19 -18
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.3.dist-info}/METADATA +8 -8
- mdb_engine-0.2.3.dist-info/RECORD +96 -0
- mdb_engine-0.2.1.dist-info/RECORD +0 -96
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.3.dist-info}/WHEEL +0 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.3.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.2.1.dist-info → mdb_engine-0.2.3.dist-info}/top_level.txt +0 -0
mdb_engine/embeddings/service.py
CHANGED
|
@@ -23,7 +23,7 @@ import os
|
|
|
23
23
|
import time
|
|
24
24
|
from abc import ABC, abstractmethod
|
|
25
25
|
from datetime import datetime
|
|
26
|
-
from typing import Any
|
|
26
|
+
from typing import Any
|
|
27
27
|
|
|
28
28
|
# Optional OpenAI SDK import
|
|
29
29
|
try:
|
|
@@ -59,9 +59,7 @@ class BaseEmbeddingProvider(ABC):
|
|
|
59
59
|
"""
|
|
60
60
|
|
|
61
61
|
@abstractmethod
|
|
62
|
-
async def embed(
|
|
63
|
-
self, text: Union[str, List[str]], model: Optional[str] = None
|
|
64
|
-
) -> List[List[float]]:
|
|
62
|
+
async def embed(self, text: str | list[str], model: str | None = None) -> list[list[float]]:
|
|
65
63
|
"""
|
|
66
64
|
Generate embeddings for text.
|
|
67
65
|
|
|
@@ -84,7 +82,7 @@ class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
|
|
|
84
82
|
|
|
85
83
|
def __init__(
|
|
86
84
|
self,
|
|
87
|
-
api_key:
|
|
85
|
+
api_key: str | None = None,
|
|
88
86
|
default_model: str = "text-embedding-3-small",
|
|
89
87
|
):
|
|
90
88
|
"""
|
|
@@ -108,9 +106,7 @@ class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
|
|
|
108
106
|
self.client = AsyncOpenAI(api_key=api_key)
|
|
109
107
|
self.default_model = default_model
|
|
110
108
|
|
|
111
|
-
async def embed(
|
|
112
|
-
self, text: Union[str, List[str]], model: Optional[str] = None
|
|
113
|
-
) -> List[List[float]]:
|
|
109
|
+
async def embed(self, text: str | list[str], model: str | None = None) -> list[list[float]]:
|
|
114
110
|
"""Generate embeddings using OpenAI."""
|
|
115
111
|
model = model or self.default_model
|
|
116
112
|
|
|
@@ -149,9 +145,9 @@ class AzureOpenAIEmbeddingProvider(BaseEmbeddingProvider):
|
|
|
149
145
|
|
|
150
146
|
def __init__(
|
|
151
147
|
self,
|
|
152
|
-
api_key:
|
|
153
|
-
endpoint:
|
|
154
|
-
api_version:
|
|
148
|
+
api_key: str | None = None,
|
|
149
|
+
endpoint: str | None = None,
|
|
150
|
+
api_version: str | None = None,
|
|
155
151
|
default_model: str = "text-embedding-3-small",
|
|
156
152
|
):
|
|
157
153
|
"""
|
|
@@ -191,9 +187,7 @@ class AzureOpenAIEmbeddingProvider(BaseEmbeddingProvider):
|
|
|
191
187
|
)
|
|
192
188
|
self.default_model = default_model
|
|
193
189
|
|
|
194
|
-
async def embed(
|
|
195
|
-
self, text: Union[str, List[str]], model: Optional[str] = None
|
|
196
|
-
) -> List[List[float]]:
|
|
190
|
+
async def embed(self, text: str | list[str], model: str | None = None) -> list[list[float]]:
|
|
197
191
|
"""Generate embeddings using Azure OpenAI."""
|
|
198
192
|
model = model or self.default_model
|
|
199
193
|
|
|
@@ -255,8 +249,8 @@ class EmbeddingProvider:
|
|
|
255
249
|
|
|
256
250
|
def __init__(
|
|
257
251
|
self,
|
|
258
|
-
embedding_provider:
|
|
259
|
-
config:
|
|
252
|
+
embedding_provider: BaseEmbeddingProvider | None = None,
|
|
253
|
+
config: dict[str, Any] | None = None,
|
|
260
254
|
):
|
|
261
255
|
"""
|
|
262
256
|
Initialize Embedding Provider.
|
|
@@ -293,9 +287,7 @@ class EmbeddingProvider:
|
|
|
293
287
|
# Store config for potential future use
|
|
294
288
|
self.config = config or {}
|
|
295
289
|
|
|
296
|
-
async def embed(
|
|
297
|
-
self, text: Union[str, List[str]], model: Optional[str] = None
|
|
298
|
-
) -> List[List[float]]:
|
|
290
|
+
async def embed(self, text: str | list[str], model: str | None = None) -> list[list[float]]:
|
|
299
291
|
"""
|
|
300
292
|
Generates vector embeddings for a string or list of strings.
|
|
301
293
|
|
|
@@ -361,10 +353,10 @@ class EmbeddingService:
|
|
|
361
353
|
|
|
362
354
|
def __init__(
|
|
363
355
|
self,
|
|
364
|
-
embedding_provider:
|
|
356
|
+
embedding_provider: EmbeddingProvider | None = None,
|
|
365
357
|
default_max_tokens: int = 1000,
|
|
366
358
|
default_tokenizer_model: str = "gpt-3.5-turbo",
|
|
367
|
-
config:
|
|
359
|
+
config: dict[str, Any] | None = None,
|
|
368
360
|
):
|
|
369
361
|
"""
|
|
370
362
|
Initialize Embedding Service.
|
|
@@ -397,9 +389,7 @@ class EmbeddingService:
|
|
|
397
389
|
self.default_max_tokens = default_max_tokens
|
|
398
390
|
self.default_tokenizer_model = default_tokenizer_model
|
|
399
391
|
|
|
400
|
-
def _create_splitter(
|
|
401
|
-
self, max_tokens: int, tokenizer_model: Optional[str] = None
|
|
402
|
-
) -> TextSplitter:
|
|
392
|
+
def _create_splitter(self, max_tokens: int, tokenizer_model: str | None = None) -> TextSplitter:
|
|
403
393
|
"""
|
|
404
394
|
Create a TextSplitter instance.
|
|
405
395
|
|
|
@@ -419,9 +409,9 @@ class EmbeddingService:
|
|
|
419
409
|
async def chunk_text(
|
|
420
410
|
self,
|
|
421
411
|
text_content: str,
|
|
422
|
-
max_tokens:
|
|
423
|
-
tokenizer_model:
|
|
424
|
-
) ->
|
|
412
|
+
max_tokens: int | None = None,
|
|
413
|
+
tokenizer_model: str | None = None,
|
|
414
|
+
) -> list[str]:
|
|
425
415
|
"""
|
|
426
416
|
Split text into semantic chunks.
|
|
427
417
|
|
|
@@ -455,9 +445,7 @@ class EmbeddingService:
|
|
|
455
445
|
logger.error(f"Error chunking text: {e}", exc_info=True)
|
|
456
446
|
raise EmbeddingServiceError(f"Chunking failed: {str(e)}") from e
|
|
457
447
|
|
|
458
|
-
async def embed_chunks(
|
|
459
|
-
self, chunks: List[str], model: Optional[str] = None
|
|
460
|
-
) -> List[List[float]]:
|
|
448
|
+
async def embed_chunks(self, chunks: list[str], model: str | None = None) -> list[list[float]]:
|
|
461
449
|
"""
|
|
462
450
|
Generate embeddings for text chunks.
|
|
463
451
|
|
|
@@ -498,11 +486,11 @@ class EmbeddingService:
|
|
|
498
486
|
text_content: str,
|
|
499
487
|
source_id: str,
|
|
500
488
|
collection: Any, # MongoDB collection (AppDB Collection or Motor collection)
|
|
501
|
-
max_tokens:
|
|
502
|
-
tokenizer_model:
|
|
503
|
-
embedding_model:
|
|
504
|
-
metadata:
|
|
505
|
-
) ->
|
|
489
|
+
max_tokens: int | None = None,
|
|
490
|
+
tokenizer_model: str | None = None,
|
|
491
|
+
embedding_model: str | None = None,
|
|
492
|
+
metadata: dict[str, Any] | None = None,
|
|
493
|
+
) -> dict[str, Any]:
|
|
506
494
|
"""
|
|
507
495
|
Process text and store chunks with embeddings in MongoDB.
|
|
508
496
|
|
|
@@ -573,7 +561,7 @@ class EmbeddingService:
|
|
|
573
561
|
|
|
574
562
|
# Step 3: Prepare documents for insertion
|
|
575
563
|
documents_to_insert = []
|
|
576
|
-
for i, (chunk_text, vector) in enumerate(zip(chunks, vectors)):
|
|
564
|
+
for i, (chunk_text, vector) in enumerate(zip(chunks, vectors, strict=False)):
|
|
577
565
|
doc = {
|
|
578
566
|
"source_id": source_id,
|
|
579
567
|
"chunk_index": i,
|
|
@@ -626,10 +614,10 @@ class EmbeddingService:
|
|
|
626
614
|
async def process_text(
|
|
627
615
|
self,
|
|
628
616
|
text_content: str,
|
|
629
|
-
max_tokens:
|
|
630
|
-
tokenizer_model:
|
|
631
|
-
embedding_model:
|
|
632
|
-
) ->
|
|
617
|
+
max_tokens: int | None = None,
|
|
618
|
+
tokenizer_model: str | None = None,
|
|
619
|
+
embedding_model: str | None = None,
|
|
620
|
+
) -> list[dict[str, Any]]:
|
|
633
621
|
"""
|
|
634
622
|
Process text and return chunks with embeddings (without storing).
|
|
635
623
|
|
|
@@ -673,7 +661,7 @@ class EmbeddingService:
|
|
|
673
661
|
|
|
674
662
|
# Prepare results
|
|
675
663
|
results = []
|
|
676
|
-
for i, (chunk_text, vector) in enumerate(zip(chunks, vectors)):
|
|
664
|
+
for i, (chunk_text, vector) in enumerate(zip(chunks, vectors, strict=False)):
|
|
677
665
|
results.append(
|
|
678
666
|
{
|
|
679
667
|
"chunk_index": i,
|
|
@@ -692,8 +680,8 @@ class EmbeddingService:
|
|
|
692
680
|
|
|
693
681
|
# Dependency injection helper
|
|
694
682
|
def get_embedding_service(
|
|
695
|
-
embedding_provider:
|
|
696
|
-
config:
|
|
683
|
+
embedding_provider: BaseEmbeddingProvider | None = None,
|
|
684
|
+
config: dict[str, Any] | None = None,
|
|
697
685
|
) -> EmbeddingService:
|
|
698
686
|
"""
|
|
699
687
|
Create EmbeddingService instance with auto-detected or provided embedding provider.
|
mdb_engine/exceptions.py
CHANGED
|
@@ -5,7 +5,7 @@ These exceptions provide more specific error types while maintaining
|
|
|
5
5
|
backward compatibility with RuntimeError.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class MongoDBEngineError(RuntimeError):
|
|
@@ -21,7 +21,7 @@ class MongoDBEngineError(RuntimeError):
|
|
|
21
21
|
collection_name, etc.)
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
|
-
def __init__(self, message: str, context:
|
|
24
|
+
def __init__(self, message: str, context: dict[str, Any] | None = None) -> None:
|
|
25
25
|
"""
|
|
26
26
|
Initialize the exception.
|
|
27
27
|
|
|
@@ -58,9 +58,9 @@ class InitializationError(MongoDBEngineError):
|
|
|
58
58
|
def __init__(
|
|
59
59
|
self,
|
|
60
60
|
message: str,
|
|
61
|
-
mongo_uri:
|
|
62
|
-
db_name:
|
|
63
|
-
context:
|
|
61
|
+
mongo_uri: str | None = None,
|
|
62
|
+
db_name: str | None = None,
|
|
63
|
+
context: dict[str, Any] | None = None,
|
|
64
64
|
) -> None:
|
|
65
65
|
"""
|
|
66
66
|
Initialize the initialization error.
|
|
@@ -99,10 +99,10 @@ class ManifestValidationError(MongoDBEngineError):
|
|
|
99
99
|
def __init__(
|
|
100
100
|
self,
|
|
101
101
|
message: str,
|
|
102
|
-
error_paths:
|
|
103
|
-
manifest_slug:
|
|
104
|
-
schema_version:
|
|
105
|
-
context:
|
|
102
|
+
error_paths: list[str] | None = None,
|
|
103
|
+
manifest_slug: str | None = None,
|
|
104
|
+
schema_version: str | None = None,
|
|
105
|
+
context: dict[str, Any] | None = None,
|
|
106
106
|
) -> None:
|
|
107
107
|
"""
|
|
108
108
|
Initialize the manifest validation error.
|
|
@@ -144,9 +144,9 @@ class ConfigurationError(MongoDBEngineError):
|
|
|
144
144
|
def __init__(
|
|
145
145
|
self,
|
|
146
146
|
message: str,
|
|
147
|
-
config_key:
|
|
148
|
-
config_value:
|
|
149
|
-
context:
|
|
147
|
+
config_key: str | None = None,
|
|
148
|
+
config_value: Any | None = None,
|
|
149
|
+
context: dict[str, Any] | None = None,
|
|
150
150
|
) -> None:
|
|
151
151
|
"""
|
|
152
152
|
Initialize the configuration error.
|
|
@@ -185,10 +185,10 @@ class QueryValidationError(MongoDBEngineError):
|
|
|
185
185
|
def __init__(
|
|
186
186
|
self,
|
|
187
187
|
message: str,
|
|
188
|
-
query_type:
|
|
189
|
-
operator:
|
|
190
|
-
path:
|
|
191
|
-
context:
|
|
188
|
+
query_type: str | None = None,
|
|
189
|
+
operator: str | None = None,
|
|
190
|
+
path: str | None = None,
|
|
191
|
+
context: dict[str, Any] | None = None,
|
|
192
192
|
) -> None:
|
|
193
193
|
"""
|
|
194
194
|
Initialize the query validation error.
|
|
@@ -231,10 +231,10 @@ class ResourceLimitExceeded(MongoDBEngineError):
|
|
|
231
231
|
def __init__(
|
|
232
232
|
self,
|
|
233
233
|
message: str,
|
|
234
|
-
limit_type:
|
|
235
|
-
limit_value:
|
|
236
|
-
actual_value:
|
|
237
|
-
context:
|
|
234
|
+
limit_type: str | None = None,
|
|
235
|
+
limit_value: Any | None = None,
|
|
236
|
+
actual_value: Any | None = None,
|
|
237
|
+
context: dict[str, Any] | None = None,
|
|
238
238
|
) -> None:
|
|
239
239
|
"""
|
|
240
240
|
Initialize the resource limit exceeded error.
|
mdb_engine/indexes/helpers.py
CHANGED
|
@@ -6,14 +6,14 @@ in index creation and management.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import logging
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
logger = logging.getLogger(__name__)
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def normalize_keys(
|
|
15
|
-
keys:
|
|
16
|
-
) ->
|
|
15
|
+
keys: dict[str, Any] | list[tuple[str, Any]],
|
|
16
|
+
) -> list[tuple[str, Any]]:
|
|
17
17
|
"""
|
|
18
18
|
Normalize index keys to a consistent format.
|
|
19
19
|
|
|
@@ -28,7 +28,7 @@ def normalize_keys(
|
|
|
28
28
|
return keys
|
|
29
29
|
|
|
30
30
|
|
|
31
|
-
def keys_to_dict(keys: Union[Dict[str, Any], List[Tuple[str, Any]]]) -> Dict[str, Any]:
|
|
31
|
+
def keys_to_dict(keys: dict[str, Any] | list[tuple[str, Any]]) -> dict[str, Any]:
|
|
32
32
|
"""
|
|
33
33
|
Convert index keys to dictionary format for comparison.
|
|
34
34
|
|
|
@@ -43,7 +43,7 @@ def keys_to_dict(keys: Union[Dict[str, Any], List[Tuple[str, Any]]]) -> Dict[str
|
|
|
43
43
|
return {k: v for k, v in keys}
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def is_id_index(keys: Union[Dict[str, Any], List[Tuple[str, Any]]]) -> bool:
|
|
46
|
+
def is_id_index(keys: dict[str, Any] | list[tuple[str, Any]]) -> bool:
|
|
47
47
|
"""
|
|
48
48
|
Check if index keys target the _id field (which MongoDB creates automatically).
|
|
49
49
|
|
|
@@ -63,10 +63,10 @@ def is_id_index(keys: Union[Dict[str, Any], List[Tuple[str, Any]]]) -> bool:
|
|
|
63
63
|
async def check_and_update_index(
|
|
64
64
|
index_manager: Any,
|
|
65
65
|
index_name: str,
|
|
66
|
-
expected_keys:
|
|
67
|
-
expected_options:
|
|
66
|
+
expected_keys: dict[str, Any] | list[tuple[str, Any]],
|
|
67
|
+
expected_options: dict[str, Any] | None = None,
|
|
68
68
|
log_prefix: str = "",
|
|
69
|
-
) ->
|
|
69
|
+
) -> tuple[bool, dict[str, Any] | None]:
|
|
70
70
|
"""
|
|
71
71
|
Check if an index exists and matches the expected definition.
|
|
72
72
|
|
|
@@ -118,11 +118,11 @@ async def check_and_update_index(
|
|
|
118
118
|
|
|
119
119
|
|
|
120
120
|
def validate_index_definition_basic(
|
|
121
|
-
index_def:
|
|
121
|
+
index_def: dict[str, Any],
|
|
122
122
|
index_name: str,
|
|
123
|
-
required_fields:
|
|
123
|
+
required_fields: list[str],
|
|
124
124
|
log_prefix: str = "",
|
|
125
|
-
) ->
|
|
125
|
+
) -> tuple[bool, str | None]:
|
|
126
126
|
"""
|
|
127
127
|
Basic validation for index definitions.
|
|
128
128
|
|
mdb_engine/indexes/manager.py
CHANGED
|
@@ -8,7 +8,7 @@ This module is part of MDB_ENGINE - MongoDB Engine.
|
|
|
8
8
|
|
|
9
9
|
import json
|
|
10
10
|
import logging
|
|
11
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
12
12
|
|
|
13
13
|
from motor.motor_asyncio import AsyncIOMotorDatabase
|
|
14
14
|
from pymongo.errors import (
|
|
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
|
|
44
44
|
|
|
45
45
|
async def _handle_regular_index(
|
|
46
46
|
index_manager: AsyncAtlasIndexManager,
|
|
47
|
-
index_def:
|
|
47
|
+
index_def: dict[str, Any],
|
|
48
48
|
index_name: str,
|
|
49
49
|
log_prefix: str,
|
|
50
50
|
) -> None:
|
|
@@ -156,7 +156,7 @@ async def _handle_regular_index(
|
|
|
156
156
|
|
|
157
157
|
async def _handle_ttl_index(
|
|
158
158
|
index_manager: AsyncAtlasIndexManager,
|
|
159
|
-
index_def:
|
|
159
|
+
index_def: dict[str, Any],
|
|
160
160
|
index_name: str,
|
|
161
161
|
log_prefix: str,
|
|
162
162
|
) -> None:
|
|
@@ -203,7 +203,7 @@ async def _handle_ttl_index(
|
|
|
203
203
|
|
|
204
204
|
async def _handle_partial_index(
|
|
205
205
|
index_manager: AsyncAtlasIndexManager,
|
|
206
|
-
index_def:
|
|
206
|
+
index_def: dict[str, Any],
|
|
207
207
|
index_name: str,
|
|
208
208
|
log_prefix: str,
|
|
209
209
|
) -> None:
|
|
@@ -269,7 +269,7 @@ async def _handle_partial_index(
|
|
|
269
269
|
|
|
270
270
|
async def _handle_text_index(
|
|
271
271
|
index_manager: AsyncAtlasIndexManager,
|
|
272
|
-
index_def:
|
|
272
|
+
index_def: dict[str, Any],
|
|
273
273
|
index_name: str,
|
|
274
274
|
log_prefix: str,
|
|
275
275
|
) -> None:
|
|
@@ -335,7 +335,7 @@ async def _handle_text_index(
|
|
|
335
335
|
|
|
336
336
|
async def _handle_geospatial_index(
|
|
337
337
|
index_manager: AsyncAtlasIndexManager,
|
|
338
|
-
index_def:
|
|
338
|
+
index_def: dict[str, Any],
|
|
339
339
|
index_name: str,
|
|
340
340
|
log_prefix: str,
|
|
341
341
|
) -> None:
|
|
@@ -400,7 +400,7 @@ async def _handle_geospatial_index(
|
|
|
400
400
|
|
|
401
401
|
async def _handle_search_index(
|
|
402
402
|
index_manager: AsyncAtlasIndexManager,
|
|
403
|
-
index_def:
|
|
403
|
+
index_def: dict[str, Any],
|
|
404
404
|
index_name: str,
|
|
405
405
|
index_type: str,
|
|
406
406
|
slug: str,
|
|
@@ -502,7 +502,7 @@ async def _handle_search_index(
|
|
|
502
502
|
|
|
503
503
|
async def _handle_hybrid_index(
|
|
504
504
|
index_manager: AsyncAtlasIndexManager,
|
|
505
|
-
index_def:
|
|
505
|
+
index_def: dict[str, Any],
|
|
506
506
|
index_name: str,
|
|
507
507
|
slug: str,
|
|
508
508
|
log_prefix: str,
|
|
@@ -692,7 +692,7 @@ async def run_index_creation_for_collection(
|
|
|
692
692
|
db: AsyncIOMotorDatabase,
|
|
693
693
|
slug: str,
|
|
694
694
|
collection_name: str,
|
|
695
|
-
index_definitions:
|
|
695
|
+
index_definitions: list[dict[str, Any]],
|
|
696
696
|
):
|
|
697
697
|
"""Create or update indexes for a collection based on index definitions."""
|
|
698
698
|
log_prefix = f"[{slug} -> {collection_name}]"
|
mdb_engine/memory/service.py
CHANGED
|
@@ -9,7 +9,7 @@ mem0 handles embeddings and LLM via environment variables (.env).
|
|
|
9
9
|
import logging
|
|
10
10
|
import os
|
|
11
11
|
import tempfile
|
|
12
|
-
from typing import Any
|
|
12
|
+
from typing import Any
|
|
13
13
|
|
|
14
14
|
# Set MEM0_DIR environment variable early to avoid permission issues
|
|
15
15
|
# mem0 tries to create .mem0 directory at import time, so we set this before any import
|
|
@@ -74,7 +74,7 @@ def _detect_provider_from_env() -> str:
|
|
|
74
74
|
return "openai"
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def _detect_embedding_dimensions(model_name: str) ->
|
|
77
|
+
def _detect_embedding_dimensions(model_name: str) -> int | None:
|
|
78
78
|
"""
|
|
79
79
|
Auto-detect embedding dimensions from model name.
|
|
80
80
|
|
|
@@ -123,7 +123,7 @@ class Mem0MemoryServiceError(Exception):
|
|
|
123
123
|
|
|
124
124
|
def _build_vector_store_config(
|
|
125
125
|
db_name: str, collection_name: str, mongo_uri: str, embedding_model_dims: int
|
|
126
|
-
) ->
|
|
126
|
+
) -> dict[str, Any]:
|
|
127
127
|
"""Build vector store configuration for mem0."""
|
|
128
128
|
return {
|
|
129
129
|
"vector_store": {
|
|
@@ -138,7 +138,7 @@ def _build_vector_store_config(
|
|
|
138
138
|
}
|
|
139
139
|
|
|
140
140
|
|
|
141
|
-
def _build_embedder_config(provider: str, embedding_model: str, app_slug: str) ->
|
|
141
|
+
def _build_embedder_config(provider: str, embedding_model: str, app_slug: str) -> dict[str, Any]:
|
|
142
142
|
"""Build embedder configuration for mem0."""
|
|
143
143
|
clean_embedding_model = embedding_model.replace("azure/", "").replace("openai/", "")
|
|
144
144
|
if provider == "azure":
|
|
@@ -190,7 +190,7 @@ def _build_embedder_config(provider: str, embedding_model: str, app_slug: str) -
|
|
|
190
190
|
|
|
191
191
|
def _build_llm_config(
|
|
192
192
|
provider: str, chat_model: str, temperature: float, app_slug: str
|
|
193
|
-
) ->
|
|
193
|
+
) -> dict[str, Any]:
|
|
194
194
|
"""Build LLM configuration for mem0."""
|
|
195
195
|
clean_chat_model = chat_model.replace("azure/", "").replace("openai/", "")
|
|
196
196
|
if provider == "azure":
|
|
@@ -245,7 +245,7 @@ def _build_llm_config(
|
|
|
245
245
|
return config
|
|
246
246
|
|
|
247
247
|
|
|
248
|
-
def _initialize_memory_instance(mem0_config:
|
|
248
|
+
def _initialize_memory_instance(mem0_config: dict[str, Any], app_slug: str) -> tuple:
|
|
249
249
|
"""Initialize Mem0 Memory instance and return (instance, init_method)."""
|
|
250
250
|
logger.debug(
|
|
251
251
|
"Initializing Mem0 Memory with config structure",
|
|
@@ -330,7 +330,7 @@ class Mem0MemoryService:
|
|
|
330
330
|
mongo_uri: str,
|
|
331
331
|
db_name: str,
|
|
332
332
|
app_slug: str,
|
|
333
|
-
config:
|
|
333
|
+
config: dict[str, Any] | None = None,
|
|
334
334
|
):
|
|
335
335
|
"""
|
|
336
336
|
Initialize Mem0 Memory Service.
|
|
@@ -498,11 +498,11 @@ class Mem0MemoryService:
|
|
|
498
498
|
|
|
499
499
|
def add(
|
|
500
500
|
self,
|
|
501
|
-
messages:
|
|
502
|
-
user_id:
|
|
503
|
-
metadata:
|
|
501
|
+
messages: str | list[dict[str, str]],
|
|
502
|
+
user_id: str | None = None,
|
|
503
|
+
metadata: dict[str, Any] | None = None,
|
|
504
504
|
**kwargs,
|
|
505
|
-
) ->
|
|
505
|
+
) -> list[dict[str, Any]]:
|
|
506
506
|
"""
|
|
507
507
|
Add memories from messages or text.
|
|
508
508
|
|
|
@@ -677,13 +677,13 @@ class Mem0MemoryService:
|
|
|
677
677
|
|
|
678
678
|
def get_all(
|
|
679
679
|
self,
|
|
680
|
-
user_id:
|
|
681
|
-
limit:
|
|
680
|
+
user_id: str | None = None,
|
|
681
|
+
limit: int | None = None,
|
|
682
682
|
retry_on_empty: bool = True,
|
|
683
683
|
max_retries: int = 2,
|
|
684
684
|
retry_delay: float = 0.5,
|
|
685
685
|
**kwargs,
|
|
686
|
-
) ->
|
|
686
|
+
) -> list[dict[str, Any]]:
|
|
687
687
|
"""
|
|
688
688
|
Get all memories for a user.
|
|
689
689
|
|
|
@@ -755,7 +755,7 @@ class Mem0MemoryService:
|
|
|
755
755
|
result = self.memory.get_all(
|
|
756
756
|
user_id=str(user_id), limit=limit, **kwargs
|
|
757
757
|
) # Ensure string
|
|
758
|
-
result_length = len(result) if isinstance(result,
|
|
758
|
+
result_length = len(result) if isinstance(result, list | dict) else "N/A"
|
|
759
759
|
logger.debug(
|
|
760
760
|
f"🟢 RESULT RECEIVED: type={type(result).__name__}, "
|
|
761
761
|
f"length={result_length}",
|
|
@@ -764,7 +764,7 @@ class Mem0MemoryService:
|
|
|
764
764
|
"user_id": user_id,
|
|
765
765
|
"result_type": type(result).__name__,
|
|
766
766
|
"result_length": (
|
|
767
|
-
len(result) if isinstance(result,
|
|
767
|
+
len(result) if isinstance(result, list | dict) else 0
|
|
768
768
|
),
|
|
769
769
|
"attempt": attempt + 1,
|
|
770
770
|
},
|
|
@@ -791,7 +791,7 @@ class Mem0MemoryService:
|
|
|
791
791
|
"result_type": str(type(result)),
|
|
792
792
|
"is_dict": isinstance(result, dict),
|
|
793
793
|
"is_list": isinstance(result, list),
|
|
794
|
-
"result_length": (len(result) if isinstance(result,
|
|
794
|
+
"result_length": (len(result) if isinstance(result, list | dict) else 0),
|
|
795
795
|
},
|
|
796
796
|
)
|
|
797
797
|
|
|
@@ -857,12 +857,12 @@ class Mem0MemoryService:
|
|
|
857
857
|
def search(
|
|
858
858
|
self,
|
|
859
859
|
query: str,
|
|
860
|
-
user_id:
|
|
861
|
-
limit:
|
|
862
|
-
metadata:
|
|
863
|
-
filters:
|
|
860
|
+
user_id: str | None = None,
|
|
861
|
+
limit: int | None = None,
|
|
862
|
+
metadata: dict[str, Any] | None = None,
|
|
863
|
+
filters: dict[str, Any] | None = None,
|
|
864
864
|
**kwargs,
|
|
865
|
-
) ->
|
|
865
|
+
) -> list[dict[str, Any]]:
|
|
866
866
|
"""
|
|
867
867
|
Search for relevant memories using semantic search.
|
|
868
868
|
|
|
@@ -970,7 +970,7 @@ class Mem0MemoryService:
|
|
|
970
970
|
)
|
|
971
971
|
raise Mem0MemoryServiceError(f"Failed to search memories: {e}") from e
|
|
972
972
|
|
|
973
|
-
def get(self, memory_id: str, user_id:
|
|
973
|
+
def get(self, memory_id: str, user_id: str | None = None, **kwargs) -> dict[str, Any]:
|
|
974
974
|
"""
|
|
975
975
|
Get a single memory by ID.
|
|
976
976
|
|
|
@@ -1037,11 +1037,11 @@ class Mem0MemoryService:
|
|
|
1037
1037
|
def update(
|
|
1038
1038
|
self,
|
|
1039
1039
|
memory_id: str,
|
|
1040
|
-
data:
|
|
1041
|
-
user_id:
|
|
1042
|
-
metadata:
|
|
1040
|
+
data: str | list[dict[str, str]],
|
|
1041
|
+
user_id: str | None = None,
|
|
1042
|
+
metadata: dict[str, Any] | None = None,
|
|
1043
1043
|
**kwargs,
|
|
1044
|
-
) ->
|
|
1044
|
+
) -> dict[str, Any]:
|
|
1045
1045
|
"""
|
|
1046
1046
|
Update a memory by ID with new data.
|
|
1047
1047
|
|
|
@@ -1120,7 +1120,7 @@ class Mem0MemoryService:
|
|
|
1120
1120
|
)
|
|
1121
1121
|
raise Mem0MemoryServiceError(f"Failed to update memory: {e}") from e
|
|
1122
1122
|
|
|
1123
|
-
def delete(self, memory_id: str, user_id:
|
|
1123
|
+
def delete(self, memory_id: str, user_id: str | None = None, **kwargs) -> bool:
|
|
1124
1124
|
"""
|
|
1125
1125
|
Delete a memory by ID.
|
|
1126
1126
|
|
|
@@ -1169,7 +1169,7 @@ class Mem0MemoryService:
|
|
|
1169
1169
|
)
|
|
1170
1170
|
raise Mem0MemoryServiceError(f"Failed to delete memory: {e}") from e
|
|
1171
1171
|
|
|
1172
|
-
def delete_all(self, user_id:
|
|
1172
|
+
def delete_all(self, user_id: str | None = None, **kwargs) -> bool:
|
|
1173
1173
|
"""
|
|
1174
1174
|
Delete all memories for a user.
|
|
1175
1175
|
|
|
@@ -1205,7 +1205,7 @@ class Mem0MemoryService:
|
|
|
1205
1205
|
|
|
1206
1206
|
|
|
1207
1207
|
def get_memory_service(
|
|
1208
|
-
mongo_uri: str, db_name: str, app_slug: str, config:
|
|
1208
|
+
mongo_uri: str, db_name: str, app_slug: str, config: dict[str, Any] | None = None
|
|
1209
1209
|
) -> Mem0MemoryService:
|
|
1210
1210
|
"""
|
|
1211
1211
|
Get or create a Mem0MemoryService instance (cached).
|