agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/aliyun_tablestore.py
@@ -0,0 +1,252 @@
+ import json
+ import logging
+ from typing import Any, Dict, Optional
+
+ import tablestore
+ from tablestore_for_agent_memory.base.base_knowledge_store import Document
+ from tablestore_for_agent_memory.base.filter import Filters
+ from tablestore_for_agent_memory.knowledge.knowledge_store import KnowledgeStore
+
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ logger = logging.getLogger(__name__)
+
+ class OutputData:
+     def __init__(self, document: Document, score=None, metadata_name='payload'):
+         self._metadata_name = metadata_name
+         self.id: Optional[str] = document.document_id  # memory id
+         self.score: Optional[float] = score  # distance
+         self.payload: Optional[Dict] = self._metadata2payload(document.metadata)  # metadata
+         self.payload['data'] = document.text
+
+     def _metadata2payload(self, metadata):
+         return json.loads(metadata[f'{self._metadata_name}_source'])
+
+ metric_str2metric_type_dict = {
+     "VM_EUCLIDEAN": tablestore.VectorMetricType.VM_EUCLIDEAN,
+     "VM_COSINE": tablestore.VectorMetricType.VM_COSINE,
+     "VM_DOT_PRODUCT": tablestore.VectorMetricType.VM_DOT_PRODUCT,
+ }
+
+ class AliyunTableStore(VectorStoreBase):
+     def __init__(
+         self,
+         endpoint: str,
+         instance_name: str,
+         access_key_id: str,
+         access_key_secret: str,
+         vector_dimension: int,
+         sts_token: Optional[str] = None,
+         collection_name: str = "mem0",
+         search_index_name: str = "mem0_search_index",
+         text_field: str = "text",
+         embedding_field: str = "embedding",
+         vector_metric_type: str = "VM_COSINE",
+         **kwargs: Any,
+     ):
+         self._tablestore_client = tablestore.OTSClient(
+             end_point=endpoint,
+             access_key_id=access_key_id,
+             access_key_secret=access_key_secret,
+             instance_name=instance_name,
+             sts_token=None if sts_token == "" else sts_token,
+             retry_policy=tablestore.WriteRetryPolicy(),
+         )
+
+         self._vector_dimension = vector_dimension
+         self._collection_name = collection_name
+         self._search_index_name = search_index_name
+         self._metadata_name = 'payload'
+         self._key_value_hyphen = '='
+         self._search_index_schema = [
+             tablestore.FieldSchema(
+                 self._metadata_name,
+                 tablestore.FieldType.KEYWORD,
+                 index=True,
+                 is_array=True,
+                 enable_sort_and_agg=True,
+             ),
+             tablestore.FieldSchema(
+                 f'{self._metadata_name}_source',
+                 tablestore.FieldType.KEYWORD,
+                 index=False,
+                 is_array=False,
+                 enable_sort_and_agg=False,
+             ),
+         ]
+         self._text_field = text_field
+         self._embedding_field = embedding_field
+         self._vector_metric_type = metric_str2metric_type_dict[vector_metric_type]
+
+         self._knowledge_store = KnowledgeStore(
+             tablestore_client=self._tablestore_client,
+             vector_dimension=self._vector_dimension,
+             enable_multi_tenant=False,
+             table_name=self._collection_name,
+             search_index_name=self._search_index_name,
+             search_index_schema=self._search_index_schema,
+             text_field=self._text_field,
+             embedding_field=self._embedding_field,
+             vector_metric_type=self._vector_metric_type,
+             **kwargs,
+         )
+
+         self.create_col(**kwargs)
+
+     def create_col(self, **kwargs: Any):
+         """Create a new collection."""
+         if self._collection_name in self.list_cols():
+             logger.warning(f"tablestore table [{self._collection_name}] already exists")
+             return
+         self._knowledge_store.init_table()
+
+     def _payload2metadata(self, payload: Dict):
+         payload_ = json.dumps([f'{key}{self._key_value_hyphen}{value}' for key, value in payload.items()], ensure_ascii=False)
+         return {
+             self._metadata_name: payload_,
+             f'{self._metadata_name}_source': json.dumps(payload, ensure_ascii=False),
+         }
+
+     def insert(self, vectors: list, payloads: list = None, ids: list = None):
+         """Insert vectors into a collection."""
+         # Pair every vector with a payload; when no payloads are given,
+         # fall back to one empty payload per vector so nothing is skipped.
+         payloads_ = payloads if payloads is not None else [None] * len(vectors)
+         documents = []
+
+         for id, vector, payload in zip(ids, vectors, payloads_):
+             payload_ = payload.copy() if payload is not None else {}
+             documents.append(
+                 Document(
+                     document_id=id,
+                     text=payload_.pop('data', None),
+                     embedding=vector,
+                     metadata=self._payload2metadata(payload_),
+                 )
+             )
+
+         for document in documents:
+             self._knowledge_store.put_document(document)
+
+     def _create_filter(self, filters: dict):
+         """Create filters from a dict (format of mem0 filters)."""
+         if filters is None:
+             return None
+
+         if len(filters.keys()) == 1:
+             meta_key, meta_value = tuple(filters.items())[0]
+             return Filters.eq(self._metadata_name, f'{meta_key}{self._key_value_hyphen}{meta_value}')
+
+         return Filters.logical_and(
+             [
+                 Filters.eq(self._metadata_name, f'{meta_key}{self._key_value_hyphen}{meta_value}')
+                 for meta_key, meta_value in filters.items()
+             ]
+         )
+
+     def search(self, query, vectors, limit=5, filters=None):
+         """Search for similar vectors."""
+         response = self._knowledge_store.vector_search(
+             query_vector=vectors,
+             top_k=limit,
+             metadata_filter=self._create_filter(filters),
+         )
+         return [
+             OutputData(
+                 document=hit.document,
+                 score=hit.score,
+                 metadata_name=self._metadata_name,
+             )
+             for hit in response.hits
+         ]
+
+     def delete(self, vector_id):
+         """Delete a vector by ID."""
+         self._knowledge_store.delete_document(document_id=vector_id)
+
+     def update(self, vector_id, vector=None, payload=None):
+         """Update a vector and its payload."""
+         payload_ = payload.copy() if payload is not None else {}
+         document_for_update = Document(
+             document_id=vector_id,
+             text=payload_.pop('data', None),
+             embedding=vector,
+             metadata=self._payload2metadata(payload_),
+         )
+         self._knowledge_store.update_document(document_for_update)
+
+     def get(self, vector_id):
+         """Retrieve a vector by ID."""
+         document = self._knowledge_store.get_document(document_id=vector_id)
+         return OutputData(
+             document=document,
+             metadata_name=self._metadata_name,
+         )
+
+     def list_cols(self):
+         """List all collections."""
+         return self._tablestore_client.list_table()
+
+     def delete_col(self):
+         """Delete a collection."""
+         self._tablestore_client.delete_search_index(table_name=self._collection_name, index_name=self._search_index_name)
+         self._tablestore_client.delete_table(table_name=self._collection_name)
+
+     def col_info(self):
+         """Get information about a collection."""
+         return self._tablestore_client.describe_table(table_name=self._collection_name)
+
+     def list(self, filters=None, limit=100):
+         """List all memories."""
+         return [
+             [
+                 OutputData(
+                     document=hit.document,
+                     metadata_name=self._metadata_name,
+                 )
+                 for hit in self._knowledge_store.search_documents(metadata_filter=self._create_filter(filters), limit=limit).hits
+             ]
+         ]
+
+     def list_paginated(self, filters=None, limit=100, next_token=None):
+         """List memories with pagination support.
+
+         Args:
+             filters: Optional filters to apply.
+             limit: Maximum number of memories to return (max 1000).
+             next_token: Token for pagination; pass the token from the previous response to get the next page.
+
+         Returns:
+             dict: {
+                 "memories": list of OutputData objects,
+                 "next_token": token for next page (None if no more pages),
+                 "has_more": boolean indicating if there are more pages
+             }
+         """
+         response = self._knowledge_store.search_documents(
+             metadata_filter=self._create_filter(filters),
+             limit=min(limit, 1000),  # cap at 1000
+             next_token=next_token,
+         )
+
+         memories = [
+             OutputData(
+                 document=hit.document,
+                 metadata_name=self._metadata_name,
+             )
+             for hit in response.hits
+         ]
+
+         return {
+             "memories": memories,
+             "next_token": response.next_token,
+             "has_more": response.next_token is not None,
+         }
+
+     def reset(self):
+         """Reset by deleting the collection and recreating it."""
+         logger.warning(f"Resetting table {self._collection_name}...")
+         self.delete_col()
+         self.create_col()
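
For orientation, here is a minimal usage sketch of the AliyunTableStore adapter above. It is not part of the package; the endpoint, instance name, credentials, and the 1536-dimensional embedding size are placeholder assumptions.

# Hypothetical usage sketch; endpoint, credentials, and dimensions are placeholders.
from agentrun_mem0.vector_stores.aliyun_tablestore import AliyunTableStore

store = AliyunTableStore(
    endpoint="https://my-instance.cn-hangzhou.ots.aliyuncs.com",  # placeholder
    instance_name="my-instance",                                  # placeholder
    access_key_id="...",                                          # placeholder
    access_key_secret="...",                                      # placeholder
    vector_dimension=1536,  # must match the embedder's output size
)

# Insert one vector with a mem0-style payload; the 'data' key becomes the document text.
store.insert(
    vectors=[[0.1] * 1536],
    payloads=[{"data": "likes tea", "user_id": "alice"}],
    ids=["mem-1"],
)

# Vector search restricted to one user; filters are flattened into 'key=value'
# keyword terms on the indexed payload field.
hits = store.search(query="beverages", vectors=[0.1] * 1536, limit=5,
                    filters={"user_id": "alice"})

# Page through all memories for a user via the pagination protocol above.
token = None
while True:
    page = store.list_paginated(filters={"user_id": "alice"}, limit=100, next_token=token)
    for memory in page["memories"]:
        print(memory.id, memory.payload)
    if not page["has_more"]:
        break
    token = page["next_token"]

Note the two-column metadata scheme: filters match against the flattened key=value keywords, while the full payload round-trips losslessly through the payload_source JSON column.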
agentrun_mem0/vector_stores/azure_ai_search.py
@@ -0,0 +1,396 @@
+ import json
+ import logging
+ import re
+ from typing import List, Optional
+
+ from pydantic import BaseModel
+
+ from agentrun_mem0.memory.utils import extract_json
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ try:
+     from azure.core.credentials import AzureKeyCredential
+     from azure.core.exceptions import ResourceNotFoundError
+     from azure.identity import DefaultAzureCredential
+     from azure.search.documents import SearchClient
+     from azure.search.documents.indexes import SearchIndexClient
+     from azure.search.documents.indexes.models import (
+         BinaryQuantizationCompression,
+         HnswAlgorithmConfiguration,
+         ScalarQuantizationCompression,
+         SearchField,
+         SearchFieldDataType,
+         SearchIndex,
+         SimpleField,
+         VectorSearch,
+         VectorSearchProfile,
+     )
+     from azure.search.documents.models import VectorizedQuery
+ except ImportError:
+     raise ImportError(
+         "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
+     )
+
+ logger = logging.getLogger(__name__)
+
+
+ class OutputData(BaseModel):
+     id: Optional[str]
+     score: Optional[float]
+     payload: Optional[dict]
+
+
+ class AzureAISearch(VectorStoreBase):
+     def __init__(
+         self,
+         service_name,
+         collection_name,
+         api_key,
+         embedding_model_dims,
+         compression_type: Optional[str] = None,
+         use_float16: bool = False,
+         hybrid_search: bool = False,
+         vector_filter_mode: Optional[str] = None,
+     ):
+ """
56
+ Initialize the Azure AI Search vector store.
57
+
58
+ Args:
59
+ service_name (str): Azure AI Search service name.
60
+ collection_name (str): Index name.
61
+ api_key (str): API key for the Azure AI Search service.
62
+ embedding_model_dims (int): Dimension of the embedding vector.
63
+ compression_type (Optional[str]): Specifies the type of quantization to use.
64
+ Allowed values are None (no quantization), "scalar", or "binary".
65
+ use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
66
+ (Note: This flag is preserved from the initial implementation per feedback.)
67
+ hybrid_search (bool): Whether to use hybrid search. Default is False.
68
+ vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
69
+ """
70
+         self.service_name = service_name
+         self.api_key = api_key
+         self.index_name = collection_name
+         self.collection_name = collection_name
+         self.embedding_model_dims = embedding_model_dims
+         # If compression_type is None, treat it as "none".
+         self.compression_type = (compression_type or "none").lower()
+         self.use_float16 = use_float16
+         self.hybrid_search = hybrid_search
+         self.vector_filter_mode = vector_filter_mode
+
+         # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
+         if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
+             credential = DefaultAzureCredential()
+             self.api_key = None
+         else:
+             credential = AzureKeyCredential(self.api_key)
+
+         self.search_client = SearchClient(
+             endpoint=f"https://{service_name}.search.windows.net",
+             index_name=self.index_name,
+             credential=credential,
+         )
+         self.index_client = SearchIndexClient(
+             endpoint=f"https://{service_name}.search.windows.net",
+             credential=credential,
+         )
+
+         self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
+         self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
+
+         collections = self.list_cols()
+         if collection_name not in collections:
+             self.create_col()
+
+     def create_col(self):
+         """Create a new index in Azure AI Search."""
+         # Determine vector type based on use_float16 setting.
+         if self.use_float16:
+             vector_type = "Collection(Edm.Half)"
+         else:
+             vector_type = "Collection(Edm.Single)"
+
+         # Configure compression settings based on the specified compression_type.
+         compression_configurations = []
+         compression_name = None
+         if self.compression_type == "scalar":
+             compression_name = "myCompression"
+             # For SQ, rescoring defaults to True and oversampling defaults to 4.
+             compression_configurations = [
+                 ScalarQuantizationCompression(compression_name=compression_name)
+             ]
+         elif self.compression_type == "binary":
+             compression_name = "myCompression"
+             # For BQ, rescoring defaults to True and oversampling defaults to 10.
+             compression_configurations = [
+                 BinaryQuantizationCompression(compression_name=compression_name)
+             ]
+         # If no compression is desired, compression_configurations remains empty.
+         fields = [
+             SimpleField(name="id", type=SearchFieldDataType.String, key=True),
+             SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
+             SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
+             SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
+             SearchField(
+                 name="vector",
+                 type=vector_type,
+                 searchable=True,
+                 vector_search_dimensions=self.embedding_model_dims,
+                 vector_search_profile_name="my-vector-config",
+             ),
+             SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
+         ]
+
+         vector_search = VectorSearch(
+             profiles=[
+                 VectorSearchProfile(
+                     name="my-vector-config",
+                     algorithm_configuration_name="my-algorithms-config",
+                     compression_name=compression_name if self.compression_type != "none" else None,
+                 )
+             ],
+             algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
+             compressions=compression_configurations,
+         )
+         index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
+         self.index_client.create_or_update_index(index)
+
+     def _generate_document(self, vector, payload, id):
+         document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
+         # Extract additional fields if they exist.
+         for field in ["user_id", "run_id", "agent_id"]:
+             if field in payload:
+                 document[field] = payload[field]
+         return document
+
+     # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
+     def insert(self, vectors, payloads=None, ids=None):
+         """
+         Insert vectors into the index.
+
+         Args:
+             vectors (List[List[float]]): List of vectors to insert.
+             payloads (List[Dict], optional): List of payloads corresponding to vectors.
+             ids (List[str], optional): List of IDs corresponding to vectors.
+         """
+         logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
+         documents = [
+             self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
+         ]
+         response = self.search_client.upload_documents(documents)
+         for doc in response:
+             # IndexingResult.succeeded is False when the service rejected the document.
+             if not doc.succeeded:
+                 raise Exception(f"Insert failed for document {doc.key}: {doc.error_message}")
+         return response
+
+     def _sanitize_key(self, key: str) -> str:
+         return re.sub(r"[^\w]", "", key)
+
+     def _build_filter_expression(self, filters):
+         filter_conditions = []
+         for key, value in filters.items():
+             safe_key = self._sanitize_key(key)
+             if isinstance(value, str):
+                 safe_value = value.replace("'", "''")
+                 condition = f"{safe_key} eq '{safe_value}'"
+             else:
+                 condition = f"{safe_key} eq {value}"
+             filter_conditions.append(condition)
+         filter_expression = " and ".join(filter_conditions)
+         return filter_expression
+
+     def search(self, query, vectors, limit=5, filters=None):
+         """
+         Search for similar vectors.
+
+         Args:
+             query (str): Query.
+             vectors (List[float]): Query vector.
+             limit (int, optional): Number of results to return. Defaults to 5.
+             filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+         Returns:
+             List[OutputData]: Search results.
+         """
+         filter_expression = None
+         if filters:
+             filter_expression = self._build_filter_expression(filters)
+
+         vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
+         if self.hybrid_search:
+             search_results = self.search_client.search(
+                 search_text=query,
+                 vector_queries=[vector_query],
+                 filter=filter_expression,
+                 top=limit,
+                 vector_filter_mode=self.vector_filter_mode,
+                 search_fields=["payload"],
+             )
+         else:
+             search_results = self.search_client.search(
+                 vector_queries=[vector_query],
+                 filter=filter_expression,
+                 top=limit,
+                 vector_filter_mode=self.vector_filter_mode,
+             )
+
+         results = []
+         for result in search_results:
+             payload = json.loads(extract_json(result["payload"]))
+             results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
+         return results
+
+     def delete(self, vector_id):
+         """
+         Delete a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to delete.
+         """
+         response = self.search_client.delete_documents(documents=[{"id": vector_id}])
+         for doc in response:
+             # IndexingResult.succeeded is False when the deletion was rejected.
+             if not doc.succeeded:
+                 raise Exception(f"Delete failed for document {vector_id}: {doc.error_message}")
+         logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
+         return response
+
+     def update(self, vector_id, vector=None, payload=None):
+         """
+         Update a vector and its payload.
+
+         Args:
+             vector_id (str): ID of the vector to update.
+             vector (List[float], optional): Updated vector.
+             payload (Dict, optional): Updated payload.
+         """
+         document = {"id": vector_id}
+         if vector:
+             document["vector"] = vector
+         if payload:
+             json_payload = json.dumps(payload)
+             document["payload"] = json_payload
+             for field in ["user_id", "run_id", "agent_id"]:
+                 document[field] = payload.get(field)
+         response = self.search_client.merge_or_upload_documents(documents=[document])
+         for doc in response:
+             # IndexingResult.succeeded is False when the merge/upload was rejected.
+             if not doc.succeeded:
+                 raise Exception(f"Update failed for document {vector_id}: {doc.error_message}")
+         return response
+
+ def get(self, vector_id) -> OutputData:
287
+ """
288
+ Retrieve a vector by ID.
289
+
290
+ Args:
291
+ vector_id (str): ID of the vector to retrieve.
292
+
293
+ Returns:
294
+ OutputData: Retrieved vector.
295
+ """
296
+ try:
297
+ result = self.search_client.get_document(key=vector_id)
298
+ except ResourceNotFoundError:
299
+ return None
300
+ payload = json.loads(extract_json(result["payload"]))
301
+ return OutputData(id=result["id"], score=None, payload=payload)
302
+
303
+ def list_cols(self) -> List[str]:
304
+ """
305
+ List all collections (indexes).
306
+
307
+ Returns:
308
+ List[str]: List of index names.
309
+ """
310
+ try:
311
+ names = self.index_client.list_index_names()
312
+ except AttributeError:
313
+ names = [index.name for index in self.index_client.list_indexes()]
314
+ return names
315
+
316
+ def delete_col(self):
317
+ """Delete the index."""
318
+ self.index_client.delete_index(self.index_name)
319
+
320
+ def col_info(self):
321
+ """
322
+ Get information about the index.
323
+
324
+ Returns:
325
+ dict: Index information.
326
+ """
327
+ index = self.index_client.get_index(self.index_name)
328
+ return {"name": index.name, "fields": index.fields}
329
+
330
+ def list(self, filters=None, limit=100):
331
+ """
332
+ List all vectors in the index.
333
+
334
+ Args:
335
+ filters (dict, optional): Filters to apply to the list.
336
+ limit (int, optional): Number of vectors to return. Defaults to 100.
337
+
338
+ Returns:
339
+ List[OutputData]: List of vectors.
340
+ """
341
+ filter_expression = None
342
+ if filters:
343
+ filter_expression = self._build_filter_expression(filters)
344
+
345
+ search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
346
+ results = []
347
+ for result in search_results:
348
+ payload = json.loads(extract_json(result["payload"]))
349
+ results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
350
+ return [results]
351
+
352
+     def __del__(self):
+         """Close the search and index clients when the object is deleted."""
+         self.search_client.close()
+         self.index_client.close()
+
+     def reset(self):
+         """Reset the index by deleting and recreating it."""
+         logger.warning(f"Resetting index {self.index_name}...")
+
+         try:
+             # Close the existing clients
+             self.search_client.close()
+             self.index_client.close()
+
+             # Delete the collection
+             self.delete_col()
+
+             # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
+             if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
+                 credential = DefaultAzureCredential()
+                 self.api_key = None
+             else:
+                 credential = AzureKeyCredential(self.api_key)
+
+             # Reinitialize the clients
+             service_endpoint = f"https://{self.service_name}.search.windows.net"
+             self.search_client = SearchClient(
+                 endpoint=service_endpoint,
+                 index_name=self.index_name,
+                 credential=credential,
+             )
+             self.index_client = SearchIndexClient(
+                 endpoint=service_endpoint,
+                 credential=credential,
+             )
+
+             # Add user agent
+             self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
+             self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
+
+             # Create the collection
+             self.create_col()
+         except Exception as e:
+             logger.error(f"Error resetting index {self.index_name}: {e}")
+             raise
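
Similarly, a minimal sketch of how the AzureAISearch store above might be wired up. The service name, key, and embedding dimension are placeholders, not values from the package; per the constructor, an empty or placeholder api_key falls back to DefaultAzureCredential.

# Hypothetical usage sketch; service name, key, and dimensions are placeholders.
from agentrun_mem0.vector_stores.azure_ai_search import AzureAISearch

store = AzureAISearch(
    service_name="my-search-service",  # placeholder
    collection_name="mem0",
    api_key="...",                     # placeholder; empty/None uses DefaultAzureCredential
    embedding_model_dims=1536,         # must match the embedder's output size
    compression_type="scalar",         # None, "scalar", or "binary" quantization
    hybrid_search=True,                # combine text and vector scoring
    vector_filter_mode="preFilter",    # apply filters before the vector search
)

store.insert(
    vectors=[[0.1] * 1536],
    payloads=[{"data": "likes tea", "user_id": "alice"}],
    ids=["mem-1"],
)

# Filterable payload keys (user_id, run_id, agent_id) are promoted to index
# fields, so this becomes the OData filter "user_id eq 'alice'".
results = store.search(query="beverages", vectors=[0.1] * 1536, limit=5,
                       filters={"user_id": "alice"})
for r in results:
    print(r.id, r.score, r.payload)

Note the design split between the two adapters: Tablestore flattens the whole payload into keyword terms, whereas Azure AI Search promotes only the three well-known mem0 keys to dedicated filterable fields and keeps the rest inside the JSON payload string.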