agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
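
The layout mirrors upstream mem0: provider configs under `configs/`, LLM/embedding/reranker adapters, and one module per vector store backend. As a hypothetical sketch of how the wheel is presumably consumed, assuming this fork keeps upstream mem0's `Memory.from_config` entry point, provider names, and config keys (not verified against this wheel):

```python
# Hypothetical usage sketch; Memory.from_config, the "redis" provider name,
# and the config keys are assumed from upstream mem0, not verified here.
from agentrun_mem0.memory.main import Memory

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "redis_url": "redis://localhost:6379",
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
        },
    }
}

m = Memory.from_config(config)
m.add("I prefer dark mode", user_id="u1")
print(m.search("what theme does the user like?", user_id="u1"))
```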

`agentrun_mem0/vector_stores/redis.py`

@@ -0,0 +1,295 @@

```python
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from agentrun_mem0.memory.utils import extract_json
from agentrun_mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve; these are not the best fields from Redis's perspective. Might do away with them.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: Numeric, although it also accepts strings
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: float = None):
        self.id = id
        self.payload = payload
        self.score = score


class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        self.embedding_model_dims = embedding_model_dims
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection (index) in Redis.

        Args:
            name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
            vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
            distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.

        Returns:
            The created index object.
        """
        # Use provided parameters or fall back to instance attributes
        collection_name = name or self.schema["index"]["name"]
        embedding_dims = vector_size or self.embedding_model_dims
        distance_metric = distance or "cosine"

        # Create a new schema with the specified parameters
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # Copy the default fields and update the vector field with the specified dimensions
        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_dims
        fields[-1]["attrs"]["distance_metric"] = distance_metric

        # Create the schema
        schema = {"index": index_schema, "fields": fields}

        # Create the index
        index = SearchIndex.from_dict(schema)
        index.set_client(self.client)
        index.create(overwrite=True)

        # Update instance attributes if creating a new collection
        if name:
            self.schema = schema
            self.index = index

        return index

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with required fields
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata excluding specific keys
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)

        v = VectorQuery(
            vector=np.array(vectors, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,
            num_results=limit,
        )

        results = self.index.query(v)

        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
                },
            )
            for result in results
        ]

    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }

        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]

        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
        }

        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        return self.index.info()

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        collection_name = self.schema["index"]["name"]
        logger.warning(f"Resetting index {collection_name}...")
        self.delete_col()

        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

        # or use
        # self.create_col(collection_name, self.embedding_model_dims)

        # Recreate the index with the same parameters
        self.create_col(collection_name, self.embedding_model_dims)

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List all recently created memories from the vector store.
        """
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        query = Query(str(filter)).sort_by("created_at", asc=False)
        if limit is not None:
            query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)

        results = self.index.search(query)
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
                    },
                )
                for result in results.docs
            ]
        ]
```
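
For orientation, a minimal usage sketch for this backend (not part of the diff). It assumes a local Redis Stack instance with the RediSearch module, the `redisvl` package installed, and uses toy 4-dimensional vectors in place of real embedding output:

```python
# Hedged sketch, not from the package: exercises RedisDB directly,
# assuming Redis Stack is listening on localhost:6379.
from datetime import datetime, timezone

from agentrun_mem0.vector_stores.redis import RedisDB

store = RedisDB(
    redis_url="redis://localhost:6379",
    collection_name="demo",
    embedding_model_dims=4,  # toy size; use your embedding model's real dimension
)

store.insert(
    vectors=[[0.1, 0.2, 0.3, 0.4]],
    payloads=[
        {
            "hash": "abc123",
            "data": "User prefers dark mode",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "user_id": "u1",
        }
    ],
    ids=["mem-1"],
)

# The query text is unused by this backend; retrieval is purely vector + Tag filters.
# Note that search() reduces the filters with `&`, so at least one non-None filter
# is required or reduce() raises on an empty sequence.
for hit in store.search(query="", vectors=[0.1, 0.2, 0.3, 0.4], limit=5, filters={"user_id": "u1"}):
    print(hit.id, hit.score, hit.payload["data"])
```

Two things worth noticing in the code above: `DEFAULT_FIELDS.copy()` is a shallow copy, so writing `fields[-1]["attrs"]["dims"]` also mutates the module-level default (surprising if two stores with different dimensions share a process), and the constructor calls `index.create(overwrite=True)`, which rebuilds the index definition on every instantiation.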
`agentrun_mem0/vector_stores/s3_vectors.py`

@@ -0,0 +1,176 @@

```python
import json
import logging
from typing import Dict, List, Optional

from pydantic import BaseModel

from agentrun_mem0.vector_stores.base import VectorStoreBase

try:
    import boto3
    from botocore.exceptions import ClientError
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[Dict]


class S3Vectors(VectorStoreBase):
    def __init__(
        self,
        vector_bucket_name: str,
        collection_name: str,
        embedding_model_dims: int,
        distance_metric: str = "cosine",
        region_name: Optional[str] = None,
    ):
        self.client = boto3.client("s3vectors", region_name=region_name)
        self.vector_bucket_name = vector_bucket_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.distance_metric = distance_metric

        self._ensure_bucket_exists()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)

    def _ensure_bucket_exists(self):
        try:
            self.client.get_vector_bucket(vectorBucketName=self.vector_bucket_name)
            logger.info(f"Vector bucket '{self.vector_bucket_name}' already exists.")
        except ClientError as e:
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Vector bucket '{self.vector_bucket_name}' not found. Creating it.")
                self.client.create_vector_bucket(vectorBucketName=self.vector_bucket_name)
                logger.info(f"Vector bucket '{self.vector_bucket_name}' created.")
            else:
                raise

    def create_col(self, name, vector_size, distance="cosine"):
        try:
            self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=name)
            logger.info(f"Index '{name}' already exists in bucket '{self.vector_bucket_name}'.")
        except ClientError as e:
            if e.response["Error"]["Code"] == "NotFoundException":
                logger.info(f"Index '{name}' not found in bucket '{self.vector_bucket_name}'. Creating it.")
                self.client.create_index(
                    vectorBucketName=self.vector_bucket_name,
                    indexName=name,
                    dataType="float32",
                    dimension=vector_size,
                    distanceMetric=distance,
                )
                logger.info(f"Index '{name}' created.")
            else:
                raise

    def _parse_output(self, vectors: List[Dict]) -> List[OutputData]:
        results = []
        for v in vectors:
            payload = v.get("metadata", {})
            # Boto3 might return metadata as a JSON string
            if isinstance(payload, str):
                try:
                    payload = json.loads(payload)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse metadata for key {v.get('key')}")
                    payload = {}
            results.append(OutputData(id=v.get("key"), score=v.get("distance"), payload=payload))
        return results

    def insert(self, vectors, payloads=None, ids=None):
        vectors_to_put = []
        for i, vec in enumerate(vectors):
            vectors_to_put.append(
                {
                    "key": ids[i],
                    "data": {"float32": vec},
                    "metadata": payloads[i] if payloads else {},
                }
            )
        self.client.put_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            vectors=vectors_to_put,
        )

    def search(self, query, vectors, limit=5, filters=None):
        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "queryVector": {"float32": vectors},
            "topK": limit,
            "returnMetadata": True,
            "returnDistance": True,
        }
        if filters:
            params["filter"] = filters

        response = self.client.query_vectors(**params)
        return self._parse_output(response.get("vectors", []))

    def delete(self, vector_id):
        self.client.delete_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
        )

    def update(self, vector_id, vector=None, payload=None):
        # S3 Vectors uses put_vectors for updates (overwrite)
        self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])

    def get(self, vector_id) -> Optional[OutputData]:
        response = self.client.get_vectors(
            vectorBucketName=self.vector_bucket_name,
            indexName=self.collection_name,
            keys=[vector_id],
            returnData=False,
            returnMetadata=True,
        )
        vectors = response.get("vectors", [])
        if not vectors:
            return None
        return self._parse_output(vectors)[0]

    def list_cols(self):
        response = self.client.list_indexes(vectorBucketName=self.vector_bucket_name)
        return [idx["indexName"] for idx in response.get("indexes", [])]

    def delete_col(self):
        self.client.delete_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)

    def col_info(self):
        response = self.client.get_index(vectorBucketName=self.vector_bucket_name, indexName=self.collection_name)
        return response.get("index", {})

    def list(self, filters=None, limit=None):
        # Note: list_vectors does not support metadata filtering.
        if filters:
            logger.warning("S3 Vectors `list` does not support metadata filtering. Ignoring filters.")

        params = {
            "vectorBucketName": self.vector_bucket_name,
            "indexName": self.collection_name,
            "returnData": False,
            "returnMetadata": True,
        }
        if limit:
            params["maxResults"] = limit

        paginator = self.client.get_paginator("list_vectors")
        pages = paginator.paginate(**params)
        all_vectors = []
        for page in pages:
            all_vectors.extend(page.get("vectors", []))
        return [self._parse_output(all_vectors)]

    def reset(self):
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.distance_metric)
```