agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
5
|
+
from typing import Any, Optional, Dict
|
|
6
|
+
|
|
7
|
+
import tablestore
|
|
8
|
+
from tablestore_for_agent_memory.knowledge.knowledge_store import KnowledgeStore
|
|
9
|
+
from tablestore_for_agent_memory.base.base_knowledge_store import Document
|
|
10
|
+
from tablestore_for_agent_memory.base.filter import Filters
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
class OutputData:
    """Result record exposing ``id``, ``score`` and ``payload`` for one document.

    The payload is decoded from the JSON string stored in the document metadata
    under ``<metadata_name>_source``; the document text is then attached to it
    under the ``data`` key.
    """

    def __init__(self, document: "Document", score=None, metadata_name='payload'):
        """Build the record from a knowledge-store document.

        Args:
            document: Object providing ``document_id``, ``text`` and ``metadata``.
            score: Optional score reported for the document.
            metadata_name: Base name of the metadata column holding the payload.
        """
        self._metadata_name = metadata_name
        self.id: Optional[str] = document.document_id  # memory id
        self.score: Optional[float] = score  # distance
        restored = self._metadata2payload(document.metadata)
        restored['data'] = document.text
        self.payload: Optional[Dict] = restored  # metadata

    def _metadata2payload(self, metadata):
        """Decode the JSON payload kept under ``<metadata_name>_source``."""
        source_key = f'{self._metadata_name}_source'
        return json.loads(metadata[source_key])
|
|
24
|
+
|
|
25
|
+
# Maps the metric names accepted by AliyunTableStore to tablestore enum members.
metric_str2metric_type_dict = {
    metric_name: getattr(tablestore.VectorMetricType, metric_name)
    for metric_name in ("VM_EUCLIDEAN", "VM_COSINE", "VM_DOT_PRODUCT")
}
|
|
30
|
+
|
|
31
|
+
class AliyunTableStore(VectorStoreBase):
    """Vector store backed by Alibaba Cloud Tablestore through ``KnowledgeStore``.

    Each payload is written to two metadata columns: ``payload`` holds a JSON
    array of ``key=value`` strings that equality filters are matched against,
    and ``payload_source`` holds the whole payload as JSON so it can be
    restored when a document is read back.
    """

    def __init__(
        self,
        endpoint: str,
        instance_name: str,
        access_key_id: str,
        access_key_secret: str,
        vector_dimension: int,
        sts_token: Optional[str] = None,
        collection_name: str = "mem0",
        search_index_name: str = "mem0_search_index",
        text_field: str = "text",
        embedding_field: str = "embedding",
        vector_metric_type: str = "VM_COSINE",
        **kwargs: Any,
    ):
        """Create the Tablestore client and knowledge store, then ensure the table exists.

        Args:
            endpoint: Tablestore endpoint.
            instance_name: Tablestore instance name.
            access_key_id: Access key id.
            access_key_secret: Access key secret.
            vector_dimension: Dimension of the stored embeddings.
            sts_token: Optional STS token; an empty string is treated as None.
            collection_name: Table name.
            search_index_name: Search index name.
            text_field: Column name for the document text.
            embedding_field: Column name for the embedding.
            vector_metric_type: One of the keys of ``metric_str2metric_type_dict``.
            **kwargs: Forwarded to ``KnowledgeStore``.

        Raises:
            KeyError: If ``vector_metric_type`` is not a known metric name.
        """
        self._tablestore_client = tablestore.OTSClient(
            end_point=endpoint,
            access_key_id=access_key_id,
            access_key_secret=access_key_secret,
            instance_name=instance_name,
            sts_token=None if sts_token == "" else sts_token,
            retry_policy=tablestore.WriteRetryPolicy(),
        )

        self._vector_dimension = vector_dimension
        self._collection_name = collection_name
        self._search_index_name = search_index_name
        self._metadata_name = 'payload'
        self._key_value_hyphen = '='
        self._search_index_schema = [
            # Filterable "key=value" strings.
            tablestore.FieldSchema(
                self._metadata_name,
                tablestore.FieldType.KEYWORD,
                index=True,
                is_array=True,
                enable_sort_and_agg=True,
            ),
            # Raw JSON payload, stored but not indexed.
            tablestore.FieldSchema(
                f'{self._metadata_name}_source',
                tablestore.FieldType.KEYWORD,
                index=False,
                is_array=False,
                enable_sort_and_agg=False,
            ),
        ]
        self._text_field = text_field
        self._embedding_field = embedding_field
        self._vector_metric_type = metric_str2metric_type_dict[vector_metric_type]

        self._knowledge_store = KnowledgeStore(
            tablestore_client=self._tablestore_client,
            vector_dimension=self._vector_dimension,
            enable_multi_tenant=False,
            table_name=self._collection_name,
            search_index_name=self._search_index_name,
            search_index_schema=self._search_index_schema,
            text_field=self._text_field,
            embedding_field=self._embedding_field,
            vector_metric_type=self._vector_metric_type,
            **kwargs,
        )

        self.create_col(**kwargs)

    def create_col(self, **kwargs: Any):
        """Create the table unless a table with the configured name already exists."""
        if self._collection_name in self.list_cols():
            logger.warning(f"tablestore table:[{self._collection_name}] already exists")
            return
        self._knowledge_store.init_table()

    def _payload2metadata(self, payload: Dict):
        """Convert a payload dict into the two metadata columns described on the class."""
        key_value_pairs = [f'{key}{self._key_value_hyphen}{value}' for key, value in payload.items()]
        return {
            self._metadata_name: json.dumps(key_value_pairs, ensure_ascii=False),
            f'{self._metadata_name}_source': json.dumps(payload, ensure_ascii=False),
        }

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """Insert vectors into the collection.

        Args:
            vectors: Embeddings to store.
            payloads: Optional payload dicts, one per vector. A ``data`` entry is
                stored as the document text instead of in the metadata.
            ids: Optional document ids, one per vector. Random UUID strings are
                generated when omitted.
        """
        import uuid  # local import: only needed when the caller omits ids

        # Without padding, a missing payload list would make zip() yield nothing
        # and the vectors would be dropped silently.
        if payloads is None:
            payloads = [None] * len(vectors)
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        documents = []
        for doc_id, vector, payload in zip(ids, vectors, payloads):
            metadata_source = payload.copy() if payload is not None else {}
            text = metadata_source.pop('data', None)
            documents.append(
                Document(
                    document_id=doc_id,
                    text=text,
                    embedding=vector,
                    metadata=self._payload2metadata(metadata_source),
                )
            )

        for document in documents:
            self._knowledge_store.put_document(document)

    def _create_filter(self, filters: dict):
        """Translate a mem0 filter dict into a Tablestore filter.

        Every ``key: value`` pair becomes an equality match on the ``key=value``
        string stored in the metadata column; several pairs are AND-ed.

        Returns:
            The filter object, or None when ``filters`` is None or empty.
        """
        if not filters:
            return None

        conditions = [
            Filters.eq(self._metadata_name, f'{meta_key}{self._key_value_hyphen}{meta_value}')
            for meta_key, meta_value in filters.items()
        ]
        if len(conditions) == 1:
            return conditions[0]
        return Filters.logical_and(conditions)

    def search(self, query, vectors, limit=5, filters=None):
        """Search for similar vectors.

        Args:
            query: Unused by this implementation.
            vectors: Query embedding.
            limit: Maximum number of hits.
            filters: Optional mem0 filter dict.

        Returns:
            list[OutputData]: One record per hit, carrying the hit score.
        """
        response = self._knowledge_store.vector_search(
            query_vector=vectors,
            top_k=limit,
            metadata_filter=self._create_filter(filters),
        )
        return [
            OutputData(
                document=hit.document,
                score=hit.score,
                metadata_name=self._metadata_name,
            )
            for hit in response.hits
        ]

    def delete(self, vector_id):
        """Delete a vector by ID."""
        self._knowledge_store.delete_document(document_id=vector_id)

    def update(self, vector_id, vector=None, payload=None):
        """Update a vector and its payload.

        Args:
            vector_id: Id of the document to update.
            vector: Optional new embedding.
            payload: Optional new payload; a ``data`` entry becomes the text.
        """
        # NOTE(review): with payload=None an empty payload is still converted to
        # metadata and sent; confirm how update_document treats that.
        metadata_source = payload.copy() if payload is not None else {}
        text = metadata_source.pop('data', None)
        document_for_update = Document(
            document_id=vector_id,
            text=text,
            embedding=vector,
            metadata=self._payload2metadata(metadata_source),
        )
        self._knowledge_store.update_document(document_for_update)

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Returns:
            OutputData for the document, or None when the store returns no document.
        """
        document = self._knowledge_store.get_document(document_id=vector_id)
        if document is None:
            return None
        return OutputData(
            document=document,
            metadata_name=self._metadata_name,
        )

    def list_cols(self):
        """List all collections (the tables reported by the Tablestore client)."""
        return self._tablestore_client.list_table()

    def delete_col(self):
        """Delete the collection: first its search index, then the table."""
        self._tablestore_client.delete_search_index(
            table_name=self._collection_name,
            index_name=self._search_index_name,
        )
        self._tablestore_client.delete_table(table_name=self._collection_name)

    def col_info(self):
        """Return the ``describe_table`` result for the collection."""
        return self._tablestore_client.describe_table(table_name=self._collection_name)

    def list(self, filters=None, limit=100):
        """List memories.

        Returns:
            A one-element list whose single item is the list of OutputData records.
        """
        response = self._knowledge_store.search_documents(
            metadata_filter=self._create_filter(filters),
            limit=limit,
        )
        memories = [
            OutputData(
                document=hit.document,
                metadata_name=self._metadata_name,
            )
            for hit in response.hits
        ]
        return [memories]

    def list_paginated(self, filters=None, limit=100, next_token=None):
        """List memories with pagination support.

        Args:
            filters: Optional filters to apply
            limit: Maximum number of memories to return (max 1000)
            next_token: Token for pagination, pass the token from previous response to get next page

        Returns:
            dict: {
                "memories": list of OutputData objects,
                "next_token": token for next page (None if no more pages),
                "has_more": boolean indicating if there are more pages
            }
        """
        response = self._knowledge_store.search_documents(
            metadata_filter=self._create_filter(filters),
            limit=min(limit, 1000),  # at most 1000
            next_token=next_token,
        )

        memories = [
            OutputData(
                document=hit.document,
                metadata_name=self._metadata_name,
            )
            for hit in response.hits
        ]

        return {
            "memories": memories,
            "next_token": response.next_token,
            "has_more": response.next_token is not None,
        }

    def reset(self):
        """Reset by deleting the collection and recreating it."""
        logger.warning(f"Resetting table {self._collection_name}...")
        self.delete_col()
        self.create_col()
|
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from agentrun_mem0.memory.utils import extract_json
|
|
9
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from azure.core.credentials import AzureKeyCredential
|
|
13
|
+
from azure.core.exceptions import ResourceNotFoundError
|
|
14
|
+
from azure.identity import DefaultAzureCredential
|
|
15
|
+
from azure.search.documents import SearchClient
|
|
16
|
+
from azure.search.documents.indexes import SearchIndexClient
|
|
17
|
+
from azure.search.documents.indexes.models import (
|
|
18
|
+
BinaryQuantizationCompression,
|
|
19
|
+
HnswAlgorithmConfiguration,
|
|
20
|
+
ScalarQuantizationCompression,
|
|
21
|
+
SearchField,
|
|
22
|
+
SearchFieldDataType,
|
|
23
|
+
SearchIndex,
|
|
24
|
+
SimpleField,
|
|
25
|
+
VectorSearch,
|
|
26
|
+
VectorSearchProfile,
|
|
27
|
+
)
|
|
28
|
+
from azure.search.documents.models import VectorizedQuery
|
|
29
|
+
except ImportError:
|
|
30
|
+
raise ImportError(
|
|
31
|
+
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OutputData(BaseModel):
    """Single result record returned by the Azure AI Search vector store."""

    id: Optional[str]  # document key
    score: Optional[float]  # search score of the hit; None when fetched by key
    payload: Optional[dict]  # payload decoded from the stored JSON string
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AzureAISearch(VectorStoreBase):
    """Vector store backed by an Azure AI Search index.

    Each document stores the embedding in ``vector``, the JSON-encoded payload in
    ``payload`` and, when present in the payload, ``user_id`` / ``run_id`` /
    ``agent_id`` as separate filterable fields.
    """

    def __init__(
        self,
        service_name,
        collection_name,
        api_key,
        embedding_model_dims,
        compression_type: Optional[str] = None,
        use_float16: bool = False,
        hybrid_search: bool = False,
        vector_filter_mode: Optional[str] = None,
    ):
        """
        Initialize the Azure AI Search vector store.

        Args:
            service_name (str): Azure AI Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure AI Search service. When None, empty or the
                placeholder "your-api-key", DefaultAzureCredential is used instead.
            embedding_model_dims (int): Dimension of the embedding vector.
            compression_type (Optional[str]): Specifies the type of quantization to use.
                Allowed values are None (no quantization), "scalar", or "binary".
            use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering, passed through to
                the search call unchanged (None by default).
        """
        self.service_name = service_name
        self.api_key = api_key
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        # If compression_type is None, treat it as "none".
        self.compression_type = (compression_type or "none").lower()
        self.use_float16 = use_float16
        self.hybrid_search = hybrid_search
        self.vector_filter_mode = vector_filter_mode

        self._create_clients()

        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col()

    def _create_clients(self):
        """Create the search and index clients and tag them with the "mem0" user agent."""
        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
            credential = DefaultAzureCredential()
            self.api_key = None
        else:
            credential = AzureKeyCredential(self.api_key)

        service_endpoint = f"https://{self.service_name}.search.windows.net"
        self.search_client = SearchClient(
            endpoint=service_endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=service_endpoint,
            credential=credential,
        )

        self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
        self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

    def create_col(self):
        """Create a new index in Azure AI Search."""
        # Determine vector type based on use_float16 setting.
        vector_type = "Collection(Edm.Half)" if self.use_float16 else "Collection(Edm.Single)"

        # Configure compression settings based on the specified compression_type.
        # If no compression is desired, the list stays empty and the name stays None.
        compression_configurations = []
        compression_name = None
        if self.compression_type == "scalar":
            compression_name = "myCompression"
            # For SQ, rescoring defaults to True and oversampling defaults to 4.
            compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
        elif self.compression_type == "binary":
            compression_name = "myCompression"
            # For BQ, rescoring defaults to True and oversampling defaults to 10.
            compression_configurations = [BinaryQuantizationCompression(compression_name=compression_name)]

        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                    # None unless a scalar/binary compression was configured above.
                    compression_name=compression_name,
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    def _generate_document(self, vector, payload, id):
        """Build the index document for one vector; a None payload is stored as an empty dict."""
        payload = payload if payload is not None else {}
        document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
        # Extract additional fields if they exist.
        for field in ["user_id", "run_id", "agent_id"]:
            if field in payload:
                document[field] = payload[field]
        return document

    @staticmethod
    def _raise_on_failed_results(results, action, ok_status_codes=(200, 201), doc_id=None):
        """Raise if any per-document indexing result reports a failure.

        Results may be objects (attributes ``succeeded`` / ``status_code`` / ``key``)
        or dicts with the same keys. ``succeeded`` is used when available; otherwise
        the status code must be one of ``ok_status_codes``.

        Raises:
            Exception: For the first failed result.
        """
        for result in results:
            if isinstance(result, dict):
                succeeded = result.get("succeeded")
                status_code = result.get("status_code")
                key = result.get("key", result.get("id"))
            else:
                succeeded = getattr(result, "succeeded", None)
                status_code = getattr(result, "status_code", None)
                key = getattr(result, "key", None)

            if succeeded is not None:
                failed = not succeeded
            else:
                failed = status_code not in ok_status_codes
            if failed:
                raise Exception(f"{action} failed for document {doc_id if doc_id is not None else key}: {result}")

    # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Random UUID strings are generated when omitted.

        Returns:
            The per-document results returned by the upload call.

        Raises:
            Exception: If any document failed to upload.
        """
        import uuid  # local import: only needed when the caller omits ids

        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        documents = [
            self._generate_document(vector, payload, doc_id) for doc_id, vector, payload in zip(ids, vectors, payloads)
        ]
        response = self.search_client.upload_documents(documents)
        self._raise_on_failed_results(response, "Insert")
        return response

    def _sanitize_key(self, key: str) -> str:
        """Strip every character that is not a word character from a filter field name."""
        return re.sub(r"[^\w]", "", key)

    @staticmethod
    def _format_filter_value(value):
        """Render a Python value as an OData literal for a filter expression."""
        if isinstance(value, str):
            # Single quotes are escaped by doubling them.
            return "'" + value.replace("'", "''") + "'"
        if isinstance(value, bool):
            # Must come before any numeric handling: bool is a subclass of int.
            return "true" if value else "false"
        if value is None:
            return "null"
        return f"{value}"

    def _build_filter_expression(self, filters):
        """Join ``field eq value`` conditions for every filter entry with ``and``."""
        filter_conditions = [
            f"{self._sanitize_key(key)} eq {self._format_filter_value(value)}" for key, value in filters.items()
        ]
        return " and ".join(filter_conditions)

    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            query (str): Query text; only used when hybrid search is enabled.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
        if self.hybrid_search:
            search_results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
                search_fields=["payload"],
            )
        else:
            search_results = self.search_client.search(
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
            )

        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.

        Raises:
            Exception: If the service reports the delete as failed.
        """
        response = self.search_client.delete_documents(documents=[{"id": vector_id}])
        self._raise_on_failed_results(response, "Delete", ok_status_codes=(200,), doc_id=vector_id)
        logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
        return response

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.

        Raises:
            Exception: If the service reports the update as failed.
        """
        document = {"id": vector_id}
        if vector:
            document["vector"] = vector
        if payload:
            document["payload"] = json.dumps(payload)
            # Fields missing from the new payload are sent as None.
            for field in ["user_id", "run_id", "agent_id"]:
                document[field] = payload.get(field)
        response = self.search_client.merge_or_upload_documents(documents=[document])
        self._raise_on_failed_results(response, "Update", doc_id=vector_id)
        return response

    def get(self, vector_id) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector, or None when the document does not exist.
        """
        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        payload = json.loads(extract_json(result["payload"]))
        return OutputData(id=result["id"], score=None, payload=payload)

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        try:
            names = self.index_client.list_index_names()
        except AttributeError:
            # Fallback for clients that do not provide list_index_names.
            names = [index.name for index in self.index_client.list_indexes()]
        # Materialize so callers always get a real, re-iterable list.
        return list(names)

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """
        Get information about the index.

        Returns:
            dict: Index information.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """
        List all vectors in the index.

        Args:
            filters (dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A one-element list containing the list of vectors.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return [results]

    def __del__(self):
        """Close the clients when the object is deleted."""
        # The attributes may be missing if __init__ failed before creating them.
        for attr_name in ("search_client", "index_client"):
            client = getattr(self, attr_name, None)
            if client is None:
                continue
            try:
                client.close()
            except Exception:
                # Deliberate best effort: a finalizer must never raise.
                pass

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.index_name}...")

        try:
            # Delete the collection while the clients are still open.
            self.delete_col()

            # Close the existing clients
            self.search_client.close()
            self.index_client.close()

            # Reinitialize the clients
            self._create_clients()

            # Create the collection
            self.create_col()
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise
|