agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector
|
|
9
|
+
except ImportError:
|
|
10
|
+
raise ImportError(
|
|
11
|
+
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
|
|
12
|
+
) from None
|
|
13
|
+
|
|
14
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OutputData(BaseModel):
    """Normalized record returned by the Pinecone vector store methods."""

    # Identifier of the stored memory/vector.
    id: Optional[str]
    # Similarity score (distance) reported for the match.
    score: Optional[float]
    # Metadata dictionary stored alongside the vector.
    payload: Optional[Dict]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PineconeDB(VectorStoreBase):
    """Vector store implementation backed by a Pinecone index.

    The index named ``collection_name`` is created on construction when it
    does not already exist. All data operations are scoped to ``namespace``.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: Optional["Pinecone"] = None,
        api_key: Optional[str] = None,
        environment: Optional[str] = None,
        serverless_config: Optional[Dict[str, Any]] = None,
        pod_config: Optional[Dict[str, Any]] = None,
        hybrid_search: bool = False,
        metric: str = "cosine",
        batch_size: int = 100,
        extra_params: Optional[Dict[str, Any]] = None,
        namespace: Optional[str] = None,
    ):
        """
        Initialize the Pinecone vector store.

        Args:
            collection_name (str): Name of the index/collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (Pinecone, optional): Existing Pinecone client instance. Defaults to None.
            api_key (str, optional): API key for Pinecone. Falls back to the
                ``PINECONE_API_KEY`` environment variable. Defaults to None.
            environment (str, optional): Pinecone environment. Defaults to None.
            serverless_config (Dict, optional): Configuration for serverless deployment. Defaults to None.
            pod_config (Dict, optional): Configuration for pod-based deployment. Defaults to None.
            hybrid_search (bool, optional): Whether to enable hybrid search. Defaults to False.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
            batch_size (int, optional): Batch size for operations. Defaults to 100.
            extra_params (Dict, optional): Additional parameters for Pinecone client. Defaults to None.
            namespace (str, optional): Namespace for the collection. Defaults to None.

        Raises:
            ValueError: If no client is given and no API key can be found.
        """
        if client:
            self.client = client
        else:
            api_key = api_key or os.environ.get("PINECONE_API_KEY")
            if not api_key:
                raise ValueError(
                    "Pinecone API key must be provided either as a parameter or as an environment variable"
                )

            params = extra_params or {}
            self.client = Pinecone(api_key=api_key, **params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.environment = environment
        self.serverless_config = serverless_config
        self.pod_config = pod_config
        self.hybrid_search = hybrid_search
        self.metric = metric
        self.batch_size = batch_size
        self.namespace = namespace

        self.sparse_encoder = None
        if self.hybrid_search:
            # The sparse encoder lives in an optional package; degrade to
            # dense-only search instead of failing when it is missing.
            try:
                from pinecone_text.sparse import BM25Encoder

                logger.info("Initializing BM25Encoder for sparse vectors...")
                self.sparse_encoder = BM25Encoder.default()
            except ImportError:
                logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
                self.hybrid_search = False

        self.create_col(embedding_model_dims, metric)

    def create_col(self, vector_size: int, metric: str = "cosine"):
        """
        Create the index if needed and bind ``self.index`` to it.

        Args:
            vector_size (int): Size of the vectors to be stored.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        existing_indexes = self.list_cols().names()

        if self.collection_name in existing_indexes:
            logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
            self.index = self.client.Index(self.collection_name)
            return

        # Serverless config wins over pod config; with neither supplied a
        # serverless index on AWS us-west-2 is used.
        if self.serverless_config:
            spec = ServerlessSpec(**self.serverless_config)
        elif self.pod_config:
            spec = PodSpec(**self.pod_config)
        else:
            spec = ServerlessSpec(cloud="aws", region="us-west-2")

        self.client.create_index(
            name=self.collection_name,
            dimension=vector_size,
            metric=metric,
            spec=spec,
        )

        self.index = self.client.Index(self.collection_name)

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[Union[str, int]]] = None,
    ):
        """
        Insert vectors into the index, upserting in batches of ``batch_size``.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When
                omitted, the positional index is used as the ID. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
        items = []

        for idx, vector in enumerate(vectors):
            item_id = str(ids[idx]) if ids is not None else str(idx)
            payload = payloads[idx] if payloads else {}

            vector_record = {"id": item_id, "values": vector, "metadata": payload}

            # Hybrid search stores a sparse encoding of the payload text next
            # to the dense vector.
            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                vector_record["sparse_values"] = sparse_vector

            items.append(vector_record)

            if len(items) >= self.batch_size:
                self.index.upsert(vectors=items, namespace=self.namespace)
                items = []

        # Flush the final, partially filled batch.
        if items:
            self.index.upsert(vectors=items, namespace=self.namespace)

    def _parse_output(self, data: Any) -> Union[OutputData, List[OutputData]]:
        """
        Convert Pinecone results into ``OutputData`` records.

        Args:
            data: Either a single fetched ``Vector`` or an iterable of query
                matches exposing ``.get("id" / "score" / "metadata")``.

        Returns:
            OutputData for a single ``Vector`` (score fixed at 0.0, since a
            fetch carries no score); otherwise a list of OutputData, one per match.
        """
        if isinstance(data, Vector):
            return OutputData(
                id=data.id,
                score=0.0,
                payload=data.metadata,
            )

        return [
            OutputData(
                id=match.get("id"),
                score=match.get("score"),
                payload=match.get("metadata"),
            )
            for match in data
        ]

    def _create_filter(self, filters: Optional[Dict]) -> Dict:
        """
        Translate a simple filter mapping into Pinecone filter syntax.

        A value that is a dict containing both "gte" and "lte" becomes a
        ``$gte``/``$lte`` range; every other value becomes an ``$eq`` match.

        Args:
            filters (dict, optional): Field name to value (or range dict).

        Returns:
            dict: Pinecone filter; empty when ``filters`` is empty or None.
        """
        if not filters:
            return {}

        pinecone_filter = {}

        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
            else:
                pinecone_filter[key] = {"$eq": value}

        return pinecone_filter

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query. Not used by this implementation.
            vectors (list): Dense query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. When
                hybrid search is enabled, a truthy ``filters["text"]`` is also
                encoded as the sparse query vector. Defaults to None.

        Returns:
            list: Search results.
        """
        filter_dict = self._create_filter(filters) if filters else None

        query_params = {
            "vector": vectors,
            "top_k": limit,
            "include_metadata": True,
            "include_values": False,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        # `filters` may be None; check it before looking for the hybrid text
        # so a filterless hybrid search does not raise TypeError.
        if self.hybrid_search and self.sparse_encoder and filters:
            query_text = filters.get("text")
            if query_text:
                sparse_vector = self.sparse_encoder.encode_queries(query_text)
                query_params["sparse_vector"] = sparse_vector

        response = self.index.query(**query_params, namespace=self.namespace)

        return self._parse_output(response.matches)

    def delete(self, vector_id: Union[str, int]):
        """
        Delete a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to delete.
        """
        self.index.delete(ids=[str(vector_id)], namespace=self.namespace)

    def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
        """
        Update a vector and/or its payload.

        With a dense ``vector`` the record is upserted (replaced). With only
        a ``payload`` the metadata is updated in place, because an upsert
        record without dense values is rejected by Pinecone (see the Pinecone
        SDK upsert/update documentation). With neither, nothing is sent.

        Args:
            vector_id (Union[str, int]): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        record_id = str(vector_id)

        if vector is None:
            if payload is None:
                logger.debug(f"No vector or payload supplied for {record_id}; nothing to update")
                return
            # Metadata-only change: sparse values are left untouched here and
            # are only refreshed together with a dense vector below.
            self.index.update(id=record_id, set_metadata=payload, namespace=self.namespace)
            return

        item = {"id": record_id, "values": vector}

        if payload is not None:
            item["metadata"] = payload

            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                item["sparse_values"] = sparse_vector

        self.index.upsert(vectors=[item], namespace=self.namespace)

    def get(self, vector_id: Union[str, int]) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to retrieve.

        Returns:
            OutputData for the vector, or None if it is not found or the
            fetch fails (the failure is logged).
        """
        record_id = str(vector_id)
        try:
            response = self.index.fetch(ids=[record_id], namespace=self.namespace)
            if record_id in response.vectors:
                return self._parse_output(response.vectors[record_id])
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            return None

    def list_cols(self):
        """
        List all indexes/collections.

        Returns:
            The result of the client's ``list_indexes()`` call.
        """
        return self.client.list_indexes()

    def delete_col(self):
        """Delete the index/collection; failures are logged, not raised."""
        try:
            self.client.delete_index(self.collection_name)
            logger.info(f"Index {self.collection_name} deleted successfully")
        except Exception as e:
            logger.error(f"Error deleting index {self.collection_name}: {e}")

    def col_info(self) -> Dict:
        """
        Get information about the index/collection.

        Returns:
            dict: Index information.
        """
        return self.client.describe_index(self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the index with optional filtering.

        Implemented as a query with an all-zero vector of the index's
        dimension, so results are whatever that query returns.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A single-element list wrapping the list of results. On
            failure the error is logged and ``[[]]`` is returned so that the
            shape matches the success path.
        """
        filter_dict = self._create_filter(filters) if filters else None

        stats = self.index.describe_index_stats()
        dimension = stats.dimension

        zero_vector = [0.0] * dimension

        query_params = {
            "vector": zero_vector,
            "top_k": limit,
            "include_metadata": True,
            "include_values": True,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        try:
            response = self.index.query(**query_params, namespace=self.namespace)
            response = response.to_dict()
            results = self._parse_output(response["matches"])
            return [results]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            # Same shape as the success path: callers index element 0.
            return [[]]

    def count(self) -> int:
        """
        Count number of vectors in the index.

        Returns:
            int: Vector count of the configured namespace when one is set
            (0 if the namespace is absent from the stats), otherwise the
            index-wide total.
        """
        stats = self.index.describe_index_stats()
        if self.namespace:
            # Safely get the namespace stats and return vector_count, defaulting to 0 if not found
            namespace_summary = (stats.namespaces or {}).get(self.namespace)
            if namespace_summary:
                return namespace_summary.vector_count or 0
            return 0
        return stats.total_vector_count or 0

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import shutil
|
|
4
|
+
|
|
5
|
+
from qdrant_client import QdrantClient
|
|
6
|
+
from qdrant_client.models import (
|
|
7
|
+
Distance,
|
|
8
|
+
FieldCondition,
|
|
9
|
+
Filter,
|
|
10
|
+
MatchValue,
|
|
11
|
+
PointIdsList,
|
|
12
|
+
PointStruct,
|
|
13
|
+
Range,
|
|
14
|
+
VectorParams,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Qdrant(VectorStoreBase):
    """Vector store implementation backed by a Qdrant collection.

    The collection named ``collection_name`` is created on construction when
    it does not already exist.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: QdrantClient = None,
        host: str = None,
        port: int = None,
        path: str = None,
        url: str = None,
        api_key: str = None,
        on_disk: bool = False,
    ):
        """
        Initialize the Qdrant vector store.

        Args:
            collection_name (str): Name of the collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
            host (str, optional): Host address for Qdrant server. Only used together with ``port``. Defaults to None.
            port (int, optional): Port for Qdrant server. Only used together with ``host``. Defaults to None.
            path (str, optional): Path for local Qdrant database. Used only
                when no api_key, url, or host/port is given. Defaults to None.
            url (str, optional): Full URL for Qdrant server. Defaults to None.
            api_key (str, optional): API key for Qdrant server. Defaults to None.
            on_disk (bool, optional): Enables persistent storage. When False,
                an existing local database directory at ``path`` is removed
                first. Defaults to False.
        """
        if client:
            self.client = client
            self.is_local = False
        else:
            params = {}
            if api_key:
                params["api_key"] = api_key
            if url:
                params["url"] = url
            if host and port:
                params["host"] = host
                params["port"] = port

            if not params:
                # No remote connection details: fall back to a local database.
                params["path"] = path
                self.is_local = True
                # Non-persistent mode starts from an empty directory. `path`
                # may be None, which os.path functions reject, so check it first.
                if not on_disk and path and os.path.isdir(path):
                    shutil.rmtree(path)
            else:
                self.is_local = False

            self.client = QdrantClient(**params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.on_disk = on_disk
        self.create_col(embedding_model_dims, on_disk)

    def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
        """
        Create the collection unless one with the same name already exists.

        Payload indexes are (re)requested in both cases.

        Args:
            vector_size (int): Size of the vectors to be stored.
            on_disk (bool): Enables persistent storage.
            distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
        """
        # Skip creating collection if already exists
        response = self.list_cols()
        for collection in response.collections:
            if collection.name == self.collection_name:
                logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
                self._create_filter_indexes()
                return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
        )
        self._create_filter_indexes()

    def _create_filter_indexes(self):
        """Create indexes for commonly used filter fields to enable filtering."""
        # Only create payload indexes for remote Qdrant servers
        if self.is_local:
            logger.debug("Skipping payload index creation for local Qdrant (not supported)")
            return

        common_fields = ["user_id", "agent_id", "run_id", "actor_id"]

        for field in common_fields:
            # Best effort: a failure (e.g. the index already existing) is
            # logged at debug level and does not stop the remaining fields.
            try:
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field,
                    field_schema="keyword"
                )
                logger.info(f"Created index for {field} in collection {self.collection_name}")
            except Exception as e:
                logger.debug(f"Index for {field} might already exist: {e}")

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """
        Insert vectors into a collection.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When
                omitted, the positional index is used as the ID. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        points = [
            PointStruct(
                id=idx if ids is None else ids[idx],
                vector=vector,
                payload=payloads[idx] if payloads else {},
            )
            for idx, vector in enumerate(vectors)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def _create_filter(self, filters: dict) -> Filter:
        """
        Create a Filter object from the provided filters.

        A value that is a dict containing both "gte" and "lte" becomes a
        range condition; every other value becomes an exact-match condition.
        All conditions are combined with ``must``.

        Args:
            filters (dict): Filters to apply.

        Returns:
            Filter: The created Filter object, or None when there are no filters.
        """
        if not filters:
            return None

        conditions = []
        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
            else:
                conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
        return Filter(must=conditions) if conditions else None

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query. Not used by this implementation.
            vectors (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.query_points(
            collection_name=self.collection_name,
            query=vectors,
            query_filter=query_filter,
            limit=limit,
        )
        return hits.points

    def delete(self, vector_id: int):
        """
        Delete a vector by ID.

        Args:
            vector_id (int): ID of the vector to delete.
        """
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[vector_id],
            ),
        )

    def update(self, vector_id: int, vector: list = None, payload: dict = None):
        """
        Update a vector and its payload.

        With a ``vector`` the point is upserted (replaced). With only a
        ``payload`` the point's payload is overwritten in place, because a
        ``PointStruct`` cannot be built without a vector. With neither,
        nothing is sent.

        Args:
            vector_id (int): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        if vector is None:
            if payload is None:
                logger.debug(f"No vector or payload supplied for {vector_id}; nothing to update")
                return
            # overwrite_payload replaces the whole payload, matching what an
            # upsert would have done to it.
            self.client.overwrite_payload(
                collection_name=self.collection_name,
                payload=payload,
                points=[vector_id],
            )
            return

        point = PointStruct(id=vector_id, vector=vector, payload=payload)
        self.client.upsert(collection_name=self.collection_name, points=[point])

    def get(self, vector_id: int) -> dict:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (int): ID of the vector to retrieve.

        Returns:
            The first retrieved record (with payload), or None if nothing was found.
        """
        result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
        return result[0] if result else None

    def list_cols(self) -> list:
        """
        List all collections.

        Returns:
            The result of the client's ``get_collections()`` call.
        """
        return self.client.get_collections()

    def delete_col(self):
        """Delete a collection."""
        self.client.delete_collection(collection_name=self.collection_name)

    def col_info(self) -> dict:
        """
        Get information about a collection.

        Returns:
            dict: Collection information.
        """
        return self.client.get_collection(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List all vectors in a collection.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            The raw result of the client's ``scroll()`` call (payloads
            included, vectors excluded).
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=query_filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
        return result

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.on_disk)
|