mem0ai-azure-mysql 0.1.115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/utils/factory.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
|
5
|
+
from mem0.configs.llms.base import BaseLlmConfig
|
|
6
|
+
from mem0.configs.dbs.mysql import MySQLConfig
|
|
7
|
+
from mem0.embeddings.mock import MockEmbeddings
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_class(class_type):
    """Resolve a dotted path such as ``pkg.mod.ClassName`` to the named object.

    Args:
        class_type (str): Fully qualified dotted path of the class.

    Returns:
        The attribute named by the last path segment of ``class_type``.
    """
    module_name, attribute_name = class_type.rsplit(".", 1)
    target_module = importlib.import_module(module_name)
    return getattr(target_module, attribute_name)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LlmFactory:
    """Factory that maps LLM provider keys to their implementation classes."""

    # Registry of provider key -> dotted path of the LLM class.
    provider_to_class = {
        "ollama": "mem0.llms.ollama.OllamaLLM",
        "openai": "mem0.llms.openai.OpenAILLM",
        "groq": "mem0.llms.groq.GroqLLM",
        "together": "mem0.llms.together.TogetherLLM",
        "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
        "litellm": "mem0.llms.litellm.LiteLLM",
        "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
        "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
        "anthropic": "mem0.llms.anthropic.AnthropicLLM",
        "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM",
        "gemini": "mem0.llms.gemini.GeminiLLM",
        "deepseek": "mem0.llms.deepseek.DeepSeekLLM",
        "xai": "mem0.llms.xai.XAILLM",
        "sarvam": "mem0.llms.sarvam.SarvamLLM",
        "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
        "vllm": "mem0.llms.vllm.VllmLLM",
        "langchain": "mem0.llms.langchain.LangchainLLM",
    }

    @classmethod
    def create(cls, provider_name, config):
        """Instantiate the LLM registered under ``provider_name``.

        Args:
            provider_name (str): Key into :attr:`provider_to_class`.
            config (dict): Keyword arguments forwarded to ``BaseLlmConfig``.

        Returns:
            The constructed LLM instance.

        Raises:
            ValueError: If ``provider_name`` is not registered.
        """
        class_path = cls.provider_to_class.get(provider_name)
        if not class_path:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")
        llm_class = load_class(class_path)
        return llm_class(BaseLlmConfig(**config))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class EmbedderFactory:
    """Factory that maps embedder provider keys to their implementation classes."""

    # Registry of provider key -> dotted path of the embedder class.
    provider_to_class = {
        "openai": "mem0.embeddings.openai.OpenAIEmbedding",
        "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
        "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
        "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
        "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
        "together": "mem0.embeddings.together.TogetherEmbedding",
        "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
        "langchain": "mem0.embeddings.langchain.LangchainEmbedding",
        "aws_bedrock": "mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
    }

    @staticmethod
    def _embeddings_enabled(vector_config) -> bool:
        """Return True when the vector store config enables built-in embeddings.

        Accepts either a plain dict or a config object. The previous code
        unconditionally used attribute access (``vector_config.enable_embeddings``)
        even though the parameter is annotated ``Optional[dict]``, which would
        raise AttributeError for an actual dict.
        """
        if isinstance(vector_config, dict):
            return bool(vector_config.get("enable_embeddings"))
        return bool(getattr(vector_config, "enable_embeddings", False))

    @classmethod
    def create(cls, provider_name, config, vector_config: Optional[dict] = None):
        """Instantiate the embedder registered under ``provider_name``.

        Args:
            provider_name (str): Key into :attr:`provider_to_class`.
            config (dict): Keyword arguments forwarded to ``BaseEmbedderConfig``.
            vector_config (Optional[dict]): Vector store configuration; when it
                enables built-in embeddings for ``upstash_vector``, a mock
                embedder is returned so no real embedding model is called.

        Returns:
            The constructed embedder instance.

        Raises:
            ValueError: If ``provider_name`` is not registered.
        """
        # Upstash can embed server-side; in that case the client-side embedder
        # is a no-op mock.
        if provider_name == "upstash_vector" and vector_config and cls._embeddings_enabled(vector_config):
            return MockEmbeddings()
        class_type = cls.provider_to_class.get(provider_name)
        if class_type:
            embedder_instance = load_class(class_type)
            base_config = BaseEmbedderConfig(**config)
            return embedder_instance(base_config)
        else:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class VectorStoreFactory:
    """Factory that maps vector-store provider keys to their implementation classes."""

    # Registry of provider key -> dotted path of the vector store class.
    provider_to_class = {
        "qdrant": "mem0.vector_stores.qdrant.Qdrant",
        "chroma": "mem0.vector_stores.chroma.ChromaDB",
        "pgvector": "mem0.vector_stores.pgvector.PGVector",
        "milvus": "mem0.vector_stores.milvus.MilvusDB",
        "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
        "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
        "pinecone": "mem0.vector_stores.pinecone.PineconeDB",
        "mongodb": "mem0.vector_stores.mongodb.MongoDB",
        "redis": "mem0.vector_stores.redis.RedisDB",
        "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
        "vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
        "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
        "supabase": "mem0.vector_stores.supabase.Supabase",
        "weaviate": "mem0.vector_stores.weaviate.Weaviate",
        "faiss": "mem0.vector_stores.faiss.FAISS",
        "langchain": "mem0.vector_stores.langchain.Langchain",
    }

    @classmethod
    def create(cls, provider_name, config):
        """Instantiate the vector store registered under ``provider_name``.

        Args:
            provider_name (str): Key into :attr:`provider_to_class`.
            config (dict or pydantic model): Store settings; a model is
                converted via ``model_dump()`` before being splatted into
                the store constructor.

        Raises:
            ValueError: If ``provider_name`` is not registered.
        """
        class_path = cls.provider_to_class.get(provider_name)
        if not class_path:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
        store_class = load_class(class_path)
        kwargs = config if isinstance(config, dict) else config.model_dump()
        return store_class(**kwargs)

    @classmethod
    def reset(cls, instance):
        """Reset the given vector store in place and return it for chaining."""
        instance.reset()
        return instance
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class DBFactory:
    """Factory that maps database provider keys to manager and config classes."""

    # Registry of provider key -> dotted path of the DB manager class.
    provider_to_class = {
        "mysql": "mem0.dbs.mysql.MySQLManager",
    }

    # Per-provider config class used to validate the raw settings.
    provider_to_config = {
        "mysql": MySQLConfig,
    }

    @classmethod
    def create(cls, provider_name, config):
        """Instantiate the DB manager registered under ``provider_name``.

        Args:
            provider_name (str): Key into both provider registries.
            config (dict or pydantic model): DB settings; a model is
                converted via ``model_dump()`` before validation.

        Raises:
            ValueError: If ``provider_name`` is not registered.
        """
        manager_path = cls.provider_to_class.get(provider_name)
        config_cls = cls.provider_to_config.get(provider_name)
        if not (manager_path and config_cls):
            raise ValueError(f"Unsupported DB provider: {provider_name}")
        manager_cls = load_class(manager_path)
        raw_settings = config if isinstance(config, dict) else config.model_dump()
        return manager_cls(config_cls(**raw_settings))
|
|
File without changes
|
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
from azure.identity import DefaultAzureCredential
|
|
8
|
+
|
|
9
|
+
from mem0.memory.utils import extract_json
|
|
10
|
+
from mem0.vector_stores.base import VectorStoreBase
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from azure.core.exceptions import ResourceNotFoundError
|
|
14
|
+
from azure.search.documents import SearchClient
|
|
15
|
+
from azure.search.documents.indexes import SearchIndexClient
|
|
16
|
+
from azure.search.documents.indexes.models import (
|
|
17
|
+
BinaryQuantizationCompression,
|
|
18
|
+
HnswAlgorithmConfiguration,
|
|
19
|
+
ScalarQuantizationCompression,
|
|
20
|
+
SearchField,
|
|
21
|
+
SearchFieldDataType,
|
|
22
|
+
SearchIndex,
|
|
23
|
+
SimpleField,
|
|
24
|
+
VectorSearch,
|
|
25
|
+
VectorSearchProfile,
|
|
26
|
+
)
|
|
27
|
+
from azure.search.documents.models import VectorizedQuery
|
|
28
|
+
except ImportError:
|
|
29
|
+
raise ImportError(
|
|
30
|
+
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OutputData(BaseModel):
    """Normalized record returned by AzureAISearch query/get operations."""

    # Document key in the search index (None when not available).
    id: Optional[str]
    # Relevance score reported by the search service; None for direct gets.
    score: Optional[float]
    # JSON payload stored alongside the vector, deserialized to a dict.
    payload: Optional[dict]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AzureAISearch(VectorStoreBase):
    """Vector store backed by an Azure AI Search index.

    Each document stores the embedding in a ``vector`` field and the payload
    as a JSON string in ``payload``; ``user_id``/``run_id``/``agent_id`` are
    mirrored into filterable top-level fields. Authentication uses
    ``DefaultAzureCredential``; the ``api_key`` parameter is retained for
    interface compatibility but is not used for authentication here.
    """

    def __init__(
        self,
        service_name,
        collection_name,
        api_key,
        embedding_model_dims,
        compression_type: Optional[str] = None,
        use_float16: bool = False,
        hybrid_search: bool = False,
        vector_filter_mode: Optional[str] = None,
    ):
        """
        Initialize the Azure AI Search vector store.

        Args:
            service_name (str): Azure AI Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure AI Search service
                (kept for compatibility; auth uses DefaultAzureCredential).
            embedding_model_dims (int): Dimension of the embedding vector.
            compression_type (Optional[str]): Specifies the type of quantization to use.
                Allowed values are None (no quantization), "scalar", or "binary".
            use_float16 (bool): Whether to store vectors in half precision (Edm.Half)
                or full precision (Edm.Single).
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering,
                e.g. "preFilter" or "postFilter".
        """
        self.service_name = service_name
        self.api_key = api_key
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        # If compression_type is None, treat it as "none".
        self.compression_type = (compression_type or "none").lower()
        self.use_float16 = use_float16
        self.hybrid_search = hybrid_search
        self.vector_filter_mode = vector_filter_mode

        self._init_clients()

        # Create the index on first use.
        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col()

    def _init_clients(self):
        """(Re)create the search/index clients and tag the mem0 user agent.

        Shared by ``__init__`` and ``reset`` so client construction stays in
        one place.
        """
        credential = DefaultAzureCredential()
        endpoint = f"https://{self.service_name}.search.windows.net"
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=endpoint,
            credential=credential,
        )
        # Private SDK internals; no public hook for a custom user agent here.
        self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
        self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

    @staticmethod
    def _check_response(response, operation):
        """Raise if any per-document indexing result reports failure.

        Args:
            response: Iterable of SDK ``IndexingResult`` objects (which expose
                ``succeeded`` and ``key``) or plain dicts as a fallback.
            operation (str): Verb used in the error message, e.g. "Insert".

        Raises:
            Exception: For the first document whose result is not successful.

        Note: the previous per-method check
        (``if not hasattr(doc, "status_code") and doc.get("status_code") != ...``)
        could never fire for SDK result objects — ``hasattr`` is True for
        them — so indexing failures were silently ignored.
        """
        for doc in response:
            if isinstance(doc, dict):
                succeeded = doc.get("succeeded", True)
                key = doc.get("key") or doc.get("id")
            else:
                succeeded = getattr(doc, "succeeded", True)
                key = getattr(doc, "key", None)
            if not succeeded:
                raise Exception(f"{operation} failed for document {key}: {doc}")

    def create_col(self):
        """Create (or update) the index with its vector-search configuration."""
        # Determine vector type based on use_float16 setting.
        if self.use_float16:
            vector_type = "Collection(Edm.Half)"
        else:
            vector_type = "Collection(Edm.Single)"

        # Configure compression settings based on the specified compression_type.
        compression_configurations = []
        compression_name = None
        if self.compression_type == "scalar":
            compression_name = "myCompression"
            # For SQ, rescoring defaults to True and oversampling defaults to 4.
            compression_configurations = [
                ScalarQuantizationCompression(compression_name=compression_name)
            ]
        elif self.compression_type == "binary":
            compression_name = "myCompression"
            # For BQ, rescoring defaults to True and oversampling defaults to 10.
            compression_configurations = [
                BinaryQuantizationCompression(compression_name=compression_name)
            ]
        # If no compression is desired, compression_configurations remains empty.

        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                    compression_name=compression_name if self.compression_type != "none" else None,
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    def _generate_document(self, vector, payload, id):
        """Build the index document for one vector + payload pair."""
        document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
        # Mirror filterable fields out of the payload when present.
        for field in ["user_id", "run_id", "agent_id"]:
            if field in payload:
                document[field] = payload[field]
        return document

    # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.

        Raises:
            Exception: If the service reports a failed upload for any document.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
        documents = [
            self._generate_document(vector, payload, id)
            for id, vector, payload in zip(ids, vectors, payloads)
        ]
        response = self.search_client.upload_documents(documents)
        self._check_response(response, "Insert")
        return response

    def _sanitize_key(self, key: str) -> str:
        """Strip all non-word characters so the key is safe in an OData filter."""
        return re.sub(r"[^\w]", "", key)

    def _build_filter_expression(self, filters):
        """Build an OData ``and``-joined equality filter from a filters dict.

        String values are single-quote-escaped per OData rules.
        """
        filter_conditions = []
        for key, value in filters.items():
            safe_key = self._sanitize_key(key)
            if isinstance(value, str):
                safe_value = value.replace("'", "''")
                condition = f"{safe_key} eq '{safe_value}'"
            else:
                condition = f"{safe_key} eq {value}"
            filter_conditions.append(condition)
        filter_expression = " and ".join(filter_conditions)
        return filter_expression

    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            query (str): Query text (used only when hybrid search is enabled).
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
        if self.hybrid_search:
            # Hybrid: combine full-text search over the payload with the vector query.
            search_results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
                search_fields=["payload"],
            )
        else:
            search_results = self.search_client.search(
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
            )

        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.

        Raises:
            Exception: If the service reports the deletion failed.
        """
        response = self.search_client.delete_documents(documents=[{"id": vector_id}])
        self._check_response(response, "Delete")
        logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
        return response

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.

        Raises:
            Exception: If the service reports the update failed.
        """
        document = {"id": vector_id}
        if vector:
            document["vector"] = vector
        if payload:
            document["payload"] = json.dumps(payload)
            # Missing keys are written as None so stale filterable fields are
            # cleared by the merge rather than left out of sync with payload.
            for field in ["user_id", "run_id", "agent_id"]:
                document[field] = payload.get(field)
        response = self.search_client.merge_or_upload_documents(documents=[document])
        self._check_response(response, "Update")
        return response

    def get(self, vector_id) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector, or None if the document does not exist.
        """
        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        payload = json.loads(extract_json(result["payload"]))
        return OutputData(id=result["id"], score=None, payload=payload)

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        try:
            names = self.index_client.list_index_names()
        except AttributeError:
            # Older SDK versions lack list_index_names(); fall back to full objects.
            names = [index.name for index in self.index_client.list_indexes()]
        return names

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """
        Get information about the index.

        Returns:
            dict: Index information.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """
        List all vectors in the index.

        Args:
            filters (dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: Vectors, wrapped in an outer list for
            parity with the other mem0 vector stores (callers index [0]).
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return [results]

    def __del__(self):
        """Close the clients when the object is garbage-collected.

        Tolerates partially initialized objects (e.g. when ``__init__`` raised
        before the clients were assigned) — the previous version raised
        AttributeError from ``__del__`` in that case.
        """
        for client in (getattr(self, "search_client", None), getattr(self, "index_client", None)):
            if client is not None:
                try:
                    client.close()
                except Exception:
                    # Never propagate exceptions out of __del__.
                    pass

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.index_name}...")

        try:
            # Delete the index before closing the client that performs the
            # call (the previous version closed the clients first and then
            # used the closed index client to delete).
            self.delete_col()

            # Close the old clients, then rebuild them and recreate the index.
            self.search_client.close()
            self.index_client.close()
            self._init_clients()
            self.create_col()
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise
|