agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/opensearch.py
@@ -0,0 +1,292 @@
+import logging
+import time
+from typing import Any, Dict, List, Optional
+
+try:
+    from opensearchpy import OpenSearch, RequestsHttpConnection
+except ImportError:
+    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
+
+from pydantic import BaseModel
+
+from agentrun_mem0.configs.vector_stores.opensearch import OpenSearchConfig
+from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: str
+    score: float
+    payload: Dict
+
+
+class OpenSearchDB(VectorStoreBase):
+    def __init__(self, **kwargs):
+        config = OpenSearchConfig(**kwargs)
+
+        # Initialize OpenSearch client
+        self.client = OpenSearch(
+            hosts=[{"host": config.host, "port": config.port or 9200}],
+            http_auth=config.http_auth
+            if config.http_auth
+            else ((config.user, config.password) if (config.user and config.password) else None),
+            use_ssl=config.use_ssl,
+            verify_certs=config.verify_certs,
+            connection_class=RequestsHttpConnection,
+            pool_maxsize=20,
+        )
+
+        self.collection_name = config.collection_name
+        self.embedding_model_dims = config.embedding_model_dims
+        self.create_col(self.collection_name, self.embedding_model_dims)
+
+    def create_index(self) -> None:
+        """Create OpenSearch index with proper mappings if it doesn't exist."""
+        index_settings = {
+            "settings": {
+                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
+            },
+            "mappings": {
+                "properties": {
+                    "text": {"type": "text"},
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": self.embedding_model_dims,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=self.collection_name):
+            self.client.indices.create(index=self.collection_name, body=index_settings)
+            logger.info(f"Created index {self.collection_name}")
+        else:
+            logger.info(f"Index {self.collection_name} already exists")
+
+    def create_col(self, name: str, vector_size: int) -> None:
+        """Create a new collection (index in OpenSearch)."""
+        index_settings = {
+            "settings": {"index.knn": True},
+            "mappings": {
+                "properties": {
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": vector_size,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "payload": {"type": "object"},
+                    "id": {"type": "keyword"},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=name):
+            logger.warning(f"Creating index {name}, it might take 1-2 minutes...")
+            self.client.indices.create(index=name, body=index_settings)
+
+            # Wait for index to be ready
+            max_retries = 180  # 3 minutes timeout
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    # Check if index is ready by attempting a simple search
+                    self.client.search(index=name, body={"query": {"match_all": {}}})
+                    time.sleep(1)
+                    logger.info(f"Index {name} is ready")
+                    return
+                except Exception:
+                    retry_count += 1
+                    if retry_count == max_retries:
+                        raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
+                    time.sleep(0.5)
+
+    def insert(
+        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
+    ) -> List[OutputData]:
+        """Insert vectors into the index."""
+        if not ids:
+            ids = [str(i) for i in range(len(vectors))]
+
+        if payloads is None:
+            payloads = [{} for _ in range(len(vectors))]
+
+        results = []
+        for i, (vec, id_) in enumerate(zip(vectors, ids)):
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
+            }
+            try:
+                self.client.index(index=self.collection_name, body=body)
+                # Force refresh to make documents immediately searchable for tests
+                self.client.indices.refresh(index=self.collection_name)
+
+                results.append(OutputData(
+                    id=id_,
+                    score=1.0,  # No score for inserts
+                    payload=payloads[i]
+                ))
+            except Exception as e:
+                logger.error(f"Error inserting vector {id_}: {e}")
+                raise
+
+        return results
+
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
+        """Search for similar vectors using OpenSearch k-NN search with optional filters."""
+
+        # Base KNN query
+        knn_query = {
+            "knn": {
+                "vector_field": {
+                    "vector": vectors,
+                    "k": limit * 2,
+                }
+            }
+        }
+
+        # Start building the full query
+        query_body = {"size": limit * 2, "query": None}
+
+        # Prepare filter conditions if applicable
+        filter_clauses = []
+        if filters:
+            for key in ["user_id", "run_id", "agent_id"]:
+                value = filters.get(key)
+                if value:
+                    filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+        # Combine knn with filters if needed
+        if filter_clauses:
+            query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
+        else:
+            query_body["query"] = knn_query
+
+        try:
+            # Execute search
+            response = self.client.search(index=self.collection_name, body=query_body)
+
+            hits = response["hits"]["hits"]
+            results = [
+                OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
+                for hit in hits[:limit]  # Ensure we don't exceed limit
+            ]
+            return results
+        except Exception as e:
+            logger.error(f"Error during search: {e}")
+            return []
+
+    def delete(self, vector_id: str) -> None:
+        """Delete a vector by custom ID."""
+        # First, find the document by custom ID
+        search_query = {"query": {"term": {"id": vector_id}}}
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]
+
+        # Delete using the actual document ID
+        self.client.delete(index=self.collection_name, id=opensearch_id)
+
+    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
+        """Update a vector and its payload using the custom 'id' field."""
+
+        # First, find the document by custom ID
+        search_query = {"query": {"term": {"id": vector_id}}}
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch
+
+        # Prepare updated fields
+        doc = {}
+        if vector is not None:
+            doc["vector_field"] = vector
+        if payload is not None:
+            doc["payload"] = payload
+
+        if doc:
+            try:
+                response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+            except Exception:
+                pass
+
+    def get(self, vector_id: str) -> Optional[OutputData]:
+        """Retrieve a vector by ID."""
+        try:
+            search_query = {"query": {"term": {"id": vector_id}}}
+            response = self.client.search(index=self.collection_name, body=search_query)
+
+            hits = response["hits"]["hits"]
+
+            if not hits:
+                return None
+
+            return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
+        except Exception as e:
+            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
+            return None
+
+    def list_cols(self) -> List[str]:
+        """List all collections (indices)."""
+        return list(self.client.indices.get_alias().keys())
+
+    def delete_col(self) -> None:
+        """Delete a collection (index)."""
+        self.client.indices.delete(index=self.collection_name)
+
+    def col_info(self, name: str) -> Any:
+        """Get information about a collection (index)."""
+        return self.client.indices.get(index=name)
+
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
+        try:
+            """List all memories with optional filters."""
+            query: Dict = {"query": {"match_all": {}}}
+
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+            if filter_clauses:
+                query["query"] = {"bool": {"filter": filter_clauses}}
+
+            if limit:
+                query["size"] = limit
+
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+
+            # Return a flat list, not a nested array
+            results = [
+                OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
+                for hit in hits
+            ]
+            return [results]  # VectorStore expects tuple/list format
+        except Exception as e:
+            logger.error(f"Error listing vectors: {e}")
+            return []
+
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims)
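
For orientation, here is a minimal usage sketch of the OpenSearchDB class above. It is not taken from the package; it assumes a local OpenSearch node on port 9200 and that OpenSearchConfig accepts the fields the constructor reads (host, port, user, password, use_ssl, verify_certs, collection_name, embedding_model_dims). All values shown are placeholders.

# Hypothetical usage sketch; connection details and config field coverage are assumptions.
from agentrun_mem0.vector_stores.opensearch import OpenSearchDB

store = OpenSearchDB(
    host="localhost",        # assumed local node
    port=9200,
    user="admin",
    password="admin",
    use_ssl=False,
    verify_certs=False,
    collection_name="mem0_demo",
    embedding_model_dims=8,  # tiny dimension purely for illustration
)

vec = [0.1] * 8
store.insert(vectors=[vec], payloads=[{"user_id": "alice", "data": "example"}], ids=["vec-1"])
hits = store.search(query="", vectors=vec, limit=5, filters={"user_id": "alice"})
for h in hits:
    print(h.id, h.score, h.payload)

Note that search filters only on the user_id, run_id, and agent_id payload keys, and insert refreshes the index after every document, which favors test determinism over bulk throughput.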
agentrun_mem0/vector_stores/pgvector.py
@@ -0,0 +1,404 @@
+import json
+import logging
+from contextlib import contextmanager
+from typing import Any, List, Optional
+
+from pydantic import BaseModel
+
+# Try to import psycopg (psycopg3) first, then fall back to psycopg2
+try:
+    from psycopg.types.json import Json
+    from psycopg_pool import ConnectionPool
+    PSYCOPG_VERSION = 3
+    logger = logging.getLogger(__name__)
+    logger.info("Using psycopg (psycopg3) with ConnectionPool for PostgreSQL connections")
+except ImportError:
+    try:
+        from psycopg2.extras import Json, execute_values
+        from psycopg2.pool import ThreadedConnectionPool as ConnectionPool
+        PSYCOPG_VERSION = 2
+        logger = logging.getLogger(__name__)
+        logger.info("Using psycopg2 with ThreadedConnectionPool for PostgreSQL connections")
+    except ImportError:
+        raise ImportError(
+            "Neither 'psycopg' nor 'psycopg2' library is available. "
+            "Please install one of them using 'pip install psycopg[pool]' or 'pip install psycopg2'"
+        )
+
+from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]
+    score: Optional[float]
+    payload: Optional[dict]
+
+
+class PGVector(VectorStoreBase):
+    def __init__(
+        self,
+        dbname,
+        collection_name,
+        embedding_model_dims,
+        user,
+        password,
+        host,
+        port,
+        diskann,
+        hnsw,
+        minconn=1,
+        maxconn=5,
+        sslmode=None,
+        connection_string=None,
+        connection_pool=None,
+    ):
+        """
+        Initialize the PGVector database.
+
+        Args:
+            dbname (str): Database name
+            collection_name (str): Collection name
+            embedding_model_dims (int): Dimension of the embedding vector
+            user (str): Database user
+            password (str): Database password
+            host (str, optional): Database host
+            port (int, optional): Database port
+            diskann (bool, optional): Use DiskANN for faster search
+            hnsw (bool, optional): Use HNSW for faster search
+            minconn (int): Minimum number of connections to keep in the connection pool
+            maxconn (int): Maximum number of connections allowed in the connection pool
+            sslmode (str, optional): SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')
+            connection_string (str, optional): PostgreSQL connection string (overrides individual connection parameters)
+            connection_pool (Any, optional): psycopg2 connection pool object (overrides connection string and individual parameters)
+        """
+        self.collection_name = collection_name
+        self.use_diskann = diskann
+        self.use_hnsw = hnsw
+        self.embedding_model_dims = embedding_model_dims
+        self.connection_pool = None
+
+        # Connection setup with priority: connection_pool > connection_string > individual parameters
+        if connection_pool is not None:
+            # Use provided connection pool
+            self.connection_pool = connection_pool
+        elif connection_string:
+            if sslmode:
+                # Append sslmode to connection string if provided
+                if 'sslmode=' in connection_string:
+                    # Replace existing sslmode
+                    import re
+                    connection_string = re.sub(r'sslmode=[^ ]*', f'sslmode={sslmode}', connection_string)
+                else:
+                    # Add sslmode to connection string
+                    connection_string = f"{connection_string} sslmode={sslmode}"
+        else:
+            connection_string = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
+            if sslmode:
+                connection_string = f"{connection_string} sslmode={sslmode}"
+
+        if self.connection_pool is None:
+            if PSYCOPG_VERSION == 3:
+                # psycopg3 ConnectionPool
+                self.connection_pool = ConnectionPool(conninfo=connection_string, min_size=minconn, max_size=maxconn, open=True)
+            else:
+                # psycopg2 ThreadedConnectionPool
+                self.connection_pool = ConnectionPool(minconn=minconn, maxconn=maxconn, dsn=connection_string)
+
+        collections = self.list_cols()
+        if collection_name not in collections:
+            self.create_col()
+
+    @contextmanager
+    def _get_cursor(self, commit: bool = False):
+        """
+        Unified context manager to get a cursor from the appropriate pool.
+        Auto-commits or rolls back based on exception, and returns the connection to the pool.
+        """
+        if PSYCOPG_VERSION == 3:
+            # psycopg3 auto-manages commit/rollback and pool return
+            with self.connection_pool.connection() as conn:
+                with conn.cursor() as cur:
+                    try:
+                        yield cur
+                        if commit:
+                            conn.commit()
+                    except Exception:
+                        conn.rollback()
+                        logger.error("Error in cursor context (psycopg3)", exc_info=True)
+                        raise
+        else:
+            # psycopg2 manual getconn/putconn
+            conn = self.connection_pool.getconn()
+            cur = conn.cursor()
+            try:
+                yield cur
+                if commit:
+                    conn.commit()
+            except Exception as exc:
+                conn.rollback()
+                logger.error(f"Error occurred: {exc}")
+                raise exc
+            finally:
+                cur.close()
+                self.connection_pool.putconn(conn)
+
+    def create_col(self) -> None:
+        """
+        Create a new collection (table in PostgreSQL).
+        Will also initialize vector search index if specified.
+        """
+        with self._get_cursor(commit=True) as cur:
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            cur.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.collection_name} (
+                    id UUID PRIMARY KEY,
+                    vector vector({self.embedding_model_dims}),
+                    payload JSONB
+                );
+                """
+            )
+            if self.use_diskann and self.embedding_model_dims < 2000:
+                cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
+                if cur.fetchone():
+                    # Create DiskANN index if extension is installed for faster search
+                    cur.execute(
+                        f"""
+                        CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
+                        ON {self.collection_name}
+                        USING diskann (vector);
+                        """
+                    )
+            elif self.use_hnsw:
+                cur.execute(
+                    f"""
+                    CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
+                    ON {self.collection_name}
+                    USING hnsw (vector vector_cosine_ops)
+                    """
+                )
+
+    def insert(self, vectors: list[list[float]], payloads=None, ids=None) -> None:
+        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+        json_payloads = [json.dumps(payload) for payload in payloads]
+
+        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
+        if PSYCOPG_VERSION == 3:
+            with self._get_cursor(commit=True) as cur:
+                cur.executemany(
+                    f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES (%s, %s, %s)",
+                    data,
+                )
+        else:
+            with self._get_cursor(commit=True) as cur:
+                execute_values(
+                    cur,
+                    f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
+                    data,
+                )
+
+    def search(
+        self,
+        query: str,
+        vectors: list[float],
+        limit: Optional[int] = 5,
+        filters: Optional[dict] = None,
+    ) -> List[OutputData]:
+        """
+        Search for similar vectors.
+
+        Args:
+            query (str): Query.
+            vectors (List[float]): Query vector.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            list: Search results.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"""
+                SELECT id, vector <=> %s::vector AS distance, payload
+                FROM {self.collection_name}
+                {filter_clause}
+                ORDER BY distance
+                LIMIT %s
+                """,
+                (vectors, *filter_params, limit),
+            )
+
+            results = cur.fetchall()
+            return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
+
+    def delete(self, vector_id: str) -> None:
+        """
+        Delete a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to delete.
+        """
+        with self._get_cursor(commit=True) as cur:
+            cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
+
+    def update(
+        self,
+        vector_id: str,
+        vector: Optional[list[float]] = None,
+        payload: Optional[dict] = None,
+    ) -> None:
+        """
+        Update a vector and its payload.
+
+        Args:
+            vector_id (str): ID of the vector to update.
+            vector (List[float], optional): Updated vector.
+            payload (Dict, optional): Updated payload.
+        """
+        with self._get_cursor(commit=True) as cur:
+            if vector:
+                cur.execute(
+                    f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
+                    (vector, vector_id),
+                )
+            if payload:
+                # Handle JSON serialization based on psycopg version
+                if PSYCOPG_VERSION == 3:
+                    # psycopg3 uses psycopg.types.json.Json
+                    cur.execute(
+                        f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                        (Json(payload), vector_id),
+                    )
+                else:
+                    # psycopg2 uses psycopg2.extras.Json
+                    cur.execute(
+                        f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                        (Json(payload), vector_id),
+                    )
+
+
+    def get(self, vector_id: str) -> OutputData:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to retrieve.
+
+        Returns:
+            OutputData: Retrieved vector.
+        """
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
+                (vector_id,),
+            )
+            result = cur.fetchone()
+            if not result:
+                return None
+            return OutputData(id=str(result[0]), score=None, payload=result[2])
+
+    def list_cols(self) -> List[str]:
+        """
+        List all collections.
+
+        Returns:
+            List[str]: List of collection names.
+        """
+        with self._get_cursor() as cur:
+            cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
+            return [row[0] for row in cur.fetchall()]
+
+    def delete_col(self) -> None:
+        """Delete a collection."""
+        with self._get_cursor(commit=True) as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
+
+    def col_info(self) -> dict[str, Any]:
+        """
+        Get information about a collection.
+
+        Returns:
+            Dict[str, Any]: Collection information.
+        """
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"""
+                SELECT
+                    table_name,
+                    (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
+                    (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
+                FROM information_schema.tables
+                WHERE table_schema = 'public' AND table_name = %s
+                """,
+                (self.collection_name,),
+            )
+            result = cur.fetchone()
+            return {"name": result[0], "count": result[1], "size": result[2]}
+
+    def list(
+        self,
+        filters: Optional[dict] = None,
+        limit: Optional[int] = 100
+    ) -> List[OutputData]:
+        """
+        List all vectors in a collection.
+
+        Args:
+            filters (Dict, optional): Filters to apply to the list.
+            limit (int, optional): Number of vectors to return. Defaults to 100.
+
+        Returns:
+            List[OutputData]: List of vectors.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        query = f"""
+            SELECT id, vector, payload
+            FROM {self.collection_name}
+            {filter_clause}
+            LIMIT %s
+        """
+
+        with self._get_cursor() as cur:
+            cur.execute(query, (*filter_params, limit))
+            results = cur.fetchall()
+            return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]
+
+    def __del__(self) -> None:
+        """
+        Close the database connection pool when the object is deleted.
+        """
+        try:
+            # Close pool appropriately
+            if PSYCOPG_VERSION == 3:
+                self.connection_pool.close()
+            else:
+                self.connection_pool.closeall()
+        except Exception:
+            pass
+
+    def reset(self) -> None:
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col()
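
Likewise, a minimal usage sketch for PGVector, not taken from the package: it assumes a local PostgreSQL instance with the pgvector extension available and uses placeholder credentials; the keyword arguments simply mirror the constructor signature above.

# Hypothetical usage sketch; database name, credentials, and dimensions are placeholders.
from uuid import uuid4

from agentrun_mem0.vector_stores.pgvector import PGVector

store = PGVector(
    dbname="postgres",
    collection_name="mem0_demo",
    embedding_model_dims=8,  # tiny dimension purely for illustration
    user="postgres",
    password="postgres",
    host="localhost",
    port=5432,
    diskann=False,
    hnsw=True,                # ask create_col() to build an HNSW cosine index
)

vec = [0.1] * 8
store.insert(vectors=[vec], payloads=[{"user_id": "alice"}], ids=[str(uuid4())])
results = store.search(query="", vectors=vec, limit=5, filters={"user_id": "alice"})
for r in results:
    print(r.id, r.score, r.payload)

Connection setup prefers an injected connection_pool, then a connection_string, then the individual parameters shown here, and the class transparently uses psycopg3 or psycopg2 depending on which is installed.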