agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,382 @@
1
+ import logging
2
+ import os
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ from pydantic import BaseModel
6
+
7
+ try:
8
+ from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector
9
+ except ImportError:
10
+ raise ImportError(
11
+ "Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
12
+ ) from None
13
+
14
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
15
+
16
logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    """One record returned by the Pinecone store's search/get/list helpers."""

    # Identifier of the stored memory/vector.
    id: Union[str, None]
    # Score reported by Pinecone for query matches; set to 0.0 for fetched vectors.
    score: Union[float, None]
    # Metadata dictionary stored alongside the vector.
    payload: Union[Dict, None]
24
+
25
class PineconeDB(VectorStoreBase):
    """Vector store backed by a single Pinecone index.

    Dense vectors are always written. When ``hybrid_search`` is requested and the
    ``pinecone_text`` package can be imported, a BM25 sparse vector is also derived
    from a payload's ``"text"`` field on insert/update and from ``filters["text"]``
    on search.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: Optional["Pinecone"],
        api_key: Optional[str],
        environment: Optional[str],
        serverless_config: Optional[Dict[str, Any]],
        pod_config: Optional[Dict[str, Any]],
        hybrid_search: bool,
        metric: str,
        batch_size: int,
        extra_params: Optional[Dict[str, Any]],
        namespace: Optional[str] = None,
    ):
        """
        Initialize the Pinecone vector store and make sure the index exists.

        Args:
            collection_name (str): Name of the index/collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (Pinecone, optional): Existing Pinecone client instance, or None
                to build one from ``api_key``.
            api_key (str, optional): API key used when ``client`` is None. Falls back
                to the ``PINECONE_API_KEY`` environment variable.
            environment (str, optional): Pinecone environment. Stored on the instance
                but not used elsewhere in this class.
            serverless_config (Dict, optional): Keyword arguments for ``ServerlessSpec``
                used when the index must be created.
            pod_config (Dict, optional): Keyword arguments for ``PodSpec``; used only
                when ``serverless_config`` is empty.
            hybrid_search (bool): Whether to compute BM25 sparse vectors. Disabled
                (with a warning) if ``pinecone_text`` is not installed.
            metric (str): Distance metric used when creating the index.
            batch_size (int): Maximum number of records per upsert call in ``insert``.
            extra_params (Dict, optional): Extra keyword arguments for the ``Pinecone``
                client constructor.
            namespace (str, optional): Namespace used for all data operations.
                Defaults to None.

        Raises:
            ValueError: If no client is given and no API key can be found.
        """
        if client:
            self.client = client
        else:
            api_key = api_key or os.environ.get("PINECONE_API_KEY")
            if not api_key:
                raise ValueError(
                    "Pinecone API key must be provided either as a parameter or as an environment variable"
                )

            params = extra_params or {}
            self.client = Pinecone(api_key=api_key, **params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.environment = environment
        self.serverless_config = serverless_config
        self.pod_config = pod_config
        self.hybrid_search = hybrid_search
        self.metric = metric
        self.batch_size = batch_size
        self.namespace = namespace

        self.sparse_encoder = None
        if self.hybrid_search:
            try:
                from pinecone_text.sparse import BM25Encoder

                logger.info("Initializing BM25Encoder for sparse vectors...")
                self.sparse_encoder = BM25Encoder.default()
            except ImportError:
                # Degrade gracefully to dense-only search instead of failing.
                logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
                self.hybrid_search = False

        self.create_col(embedding_model_dims, metric)

    def create_col(self, vector_size: int, metric: str = "cosine"):
        """
        Create the index if it does not exist yet, then bind ``self.index`` to it.

        Args:
            vector_size (int): Size of the vectors to be stored.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        existing_indexes = self.list_cols().names()

        if self.collection_name in existing_indexes:
            logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
            self.index = self.client.Index(self.collection_name)
            return

        # Spec precedence: explicit serverless config, then pod config, then a
        # serverless default in aws/us-west-2.
        if self.serverless_config:
            spec = ServerlessSpec(**self.serverless_config)
        elif self.pod_config:
            spec = PodSpec(**self.pod_config)
        else:
            spec = ServerlessSpec(cloud="aws", region="us-west-2")

        self.client.create_index(
            name=self.collection_name,
            dimension=vector_size,
            metric=metric,
            spec=spec,
        )

        self.index = self.client.Index(self.collection_name)

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[Union[str, int]]] = None,
    ):
        """
        Insert vectors into the index, upserting in chunks of ``self.batch_size``.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When omitted, the
                position in ``vectors`` (as a string) is used. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
        items = []

        for idx, vector in enumerate(vectors):
            item_id = str(ids[idx]) if ids is not None else str(idx)
            payload = payloads[idx] if payloads else {}

            vector_record = {"id": item_id, "values": vector, "metadata": payload}

            # Guard against a None payload entry before the membership test.
            if self.hybrid_search and self.sparse_encoder and payload and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                vector_record["sparse_values"] = sparse_vector

            items.append(vector_record)

            if len(items) >= self.batch_size:
                self.index.upsert(vectors=items, namespace=self.namespace)
                items = []

        if items:
            self.index.upsert(vectors=items, namespace=self.namespace)

    def _parse_output(self, data: Any) -> Union[OutputData, List[OutputData]]:
        """
        Convert Pinecone results into ``OutputData``.

        Args:
            data: Either a single ``Vector`` (as returned by ``fetch``) or an
                iterable of match mappings supporting ``.get`` (query matches).

        Returns:
            A single ``OutputData`` (score 0.0) for a ``Vector``, otherwise a list of
            ``OutputData``, one per match.
        """
        if isinstance(data, Vector):
            result = OutputData(
                id=data.id,
                score=0.0,
                payload=data.metadata,
            )
            return result
        else:
            result = []
            for match in data:
                entry = OutputData(
                    id=match.get("id"),
                    score=match.get("score"),
                    payload=match.get("metadata"),
                )
                result.append(entry)

            return result

    def _create_filter(self, filters: Optional[Dict]) -> Dict:
        """
        Translate simple filters into Pinecone metadata-filter syntax.

        A value that is a dict containing both ``"gte"`` and ``"lte"`` becomes a
        ``$gte``/``$lte`` range; any other value becomes an ``$eq`` match. Empty or
        None filters yield an empty dict.
        """
        if not filters:
            return {}

        pinecone_filter = {}

        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                pinecone_filter[key] = {"$gte": value["gte"], "$lte": value["lte"]}
            else:
                pinecone_filter[key] = {"$eq": value}

        return pinecone_filter

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for the vectors most similar to ``vectors``.

        Args:
            query (str): Original query text. Not used here; the sparse query for
                hybrid search is built from ``filters["text"]``.
            vectors (list): Dense query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Metadata filters (see ``_create_filter``). Defaults to None.

        Returns:
            list: Matches parsed into ``OutputData`` objects.
        """
        filter_dict = self._create_filter(filters) if filters else None

        query_params = {
            "vector": vectors,
            "top_k": limit,
            "include_metadata": True,
            "include_values": False,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        # `filters` may be None (the default); check it before the membership
        # test, which would otherwise raise TypeError when hybrid search is on.
        if self.hybrid_search and self.sparse_encoder and filters and "text" in filters:
            query_text = filters.get("text")
            if query_text:
                sparse_vector = self.sparse_encoder.encode_queries(query_text)
                query_params["sparse_vector"] = sparse_vector

        response = self.index.query(**query_params, namespace=self.namespace)

        results = self._parse_output(response.matches)
        return results

    def delete(self, vector_id: Union[str, int]):
        """
        Delete a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to delete.
        """
        self.index.delete(ids=[str(vector_id)], namespace=self.namespace)

    def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
        """
        Write a new vector and/or payload for ``vector_id`` via a single upsert.

        Args:
            vector_id (Union[str, int]): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.

        NOTE(review): the record is sent through ``upsert``; whether Pinecone accepts a
        record without ``values`` (vector=None) should be confirmed.
        """
        item = {
            "id": str(vector_id),
        }

        if vector is not None:
            item["values"] = vector

        if payload is not None:
            item["metadata"] = payload

            # Only derive a sparse vector when a payload was supplied; the
            # membership test on None would raise TypeError.
            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                item["sparse_values"] = sparse_vector

        self.index.upsert(vectors=[item], namespace=self.namespace)

    def get(self, vector_id: Union[str, int]) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to retrieve.

        Returns:
            OutputData or None: The record, or None if it is missing or the fetch
            failed (errors are logged, not raised).
        """
        try:
            response = self.index.fetch(ids=[str(vector_id)], namespace=self.namespace)
            if str(vector_id) in response.vectors:
                return self._parse_output(response.vectors[str(vector_id)])
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            return None

    def list_cols(self):
        """
        List all indexes/collections.

        Returns:
            The client's index listing (``create_col`` relies on its ``.names()``).
        """
        return self.client.list_indexes()

    def delete_col(self):
        """Delete the index/collection; failures are logged rather than raised."""
        try:
            self.client.delete_index(self.collection_name)
            logger.info(f"Index {self.collection_name} deleted successfully")
        except Exception as e:
            logger.error(f"Error deleting index {self.collection_name}: {e}")

    def col_info(self) -> Dict:
        """
        Get information about an index/collection.

        Returns:
            dict: Index information.
        """
        return self.client.describe_index(self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in the index with optional filtering.

        Pinecone has no plain "list with metadata" call here, so this issues a
        query with an all-zero vector of the index dimension.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A one-element list wrapping the parsed ``OutputData`` results
            (``[[]]`` if the query fails).
        """
        filter_dict = self._create_filter(filters) if filters else None

        stats = self.index.describe_index_stats()
        dimension = stats.dimension

        zero_vector = [0.0] * dimension

        query_params = {
            "vector": zero_vector,
            "top_k": limit,
            "include_metadata": True,
            "include_values": True,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        try:
            response = self.index.query(**query_params, namespace=self.namespace)
            response = response.to_dict()
            results = self._parse_output(response["matches"])
            return [results]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            # Same shape as the success path so callers can unwrap result[0].
            return [[]]

    def count(self) -> int:
        """
        Count number of vectors in the index (or in ``self.namespace`` if set).

        Returns:
            int: Total number of vectors; 0 when the namespace has no stats.
        """
        stats = self.index.describe_index_stats()
        if self.namespace:
            # Safely get the namespace stats and return vector_count, defaulting to 0 if not found
            namespace_summary = (stats.namespaces or {}).get(self.namespace)
            if namespace_summary:
                return namespace_summary.vector_count or 0
            return 0
        return stats.total_vector_count or 0

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)
@@ -0,0 +1,270 @@
1
+ import logging
2
+ import os
3
+ import shutil
4
+
5
+ from qdrant_client import QdrantClient
6
+ from qdrant_client.models import (
7
+ Distance,
8
+ FieldCondition,
9
+ Filter,
10
+ MatchValue,
11
+ PointIdsList,
12
+ PointStruct,
13
+ Range,
14
+ VectorParams,
15
+ )
16
+
17
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
18
+
19
logger = logging.getLogger(__name__)


class Qdrant(VectorStoreBase):
    """Vector store backed by a Qdrant collection (remote server or local path)."""

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: QdrantClient = None,
        host: str = None,
        port: int = None,
        path: str = None,
        url: str = None,
        api_key: str = None,
        on_disk: bool = False,
    ):
        """
        Initialize the Qdrant vector store and ensure the collection exists.

        Connection selection: an explicit ``client`` wins; otherwise ``api_key``,
        ``url`` and ``host``+``port`` (both required) are passed to ``QdrantClient``.
        If none of those are given, the local ``path`` is used. In that case, unless
        ``on_disk`` is True, an existing directory at ``path`` is deleted first.

        Args:
            collection_name (str): Name of the collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
            host (str, optional): Host address for Qdrant server. Defaults to None.
            port (int, optional): Port for Qdrant server. Defaults to None.
            path (str, optional): Path for local Qdrant database. Defaults to None.
            url (str, optional): Full URL for Qdrant server. Defaults to None.
            api_key (str, optional): API key for Qdrant server. Defaults to None.
            on_disk (bool, optional): Keep existing local data and store vectors on disk.
                Defaults to False.
        """
        if client:
            self.client = client
            self.is_local = False
        else:
            params = {}
            if api_key:
                params["api_key"] = api_key
            if url:
                params["url"] = url
            if host and port:
                params["host"] = host
                params["port"] = port

            if not params:
                params["path"] = path
                self.is_local = True
                # `path` may be None; os.path.exists(None) raises TypeError, so
                # only try to wipe when a real directory path was given.
                if not on_disk and path and os.path.isdir(path):
                    shutil.rmtree(path)
            else:
                self.is_local = False

            self.client = QdrantClient(**params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.on_disk = on_disk
        self.create_col(embedding_model_dims, on_disk)

    def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
        """
        Create the collection if missing, then ensure payload filter indexes exist.

        Args:
            vector_size (int): Size of the vectors to be stored.
            on_disk (bool): Enables persistent storage.
            distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
        """
        # Skip creating collection if already exists
        response = self.list_cols()
        for collection in response.collections:
            if collection.name == self.collection_name:
                logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
                self._create_filter_indexes()
                return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
        )
        self._create_filter_indexes()

    def _create_filter_indexes(self):
        """Create keyword payload indexes for common filter fields (remote servers only)."""
        # Only create payload indexes for remote Qdrant servers
        if self.is_local:
            logger.debug("Skipping payload index creation for local Qdrant (not supported)")
            return

        common_fields = ["user_id", "agent_id", "run_id", "actor_id"]

        for field in common_fields:
            try:
                self.client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field,
                    field_schema="keyword"
                )
                logger.info(f"Created index for {field} in collection {self.collection_name}")
            except Exception as e:
                # Best effort: an existing index (or other failure) must not abort startup.
                logger.debug(f"Index for {field} might already exist: {e}")

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        """
        Insert vectors into a collection.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When omitted, the
                position in ``vectors`` is used. Defaults to None.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        points = [
            PointStruct(
                id=idx if ids is None else ids[idx],
                vector=vector,
                payload=payloads[idx] if payloads else {},
            )
            for idx, vector in enumerate(vectors)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def _create_filter(self, filters: dict) -> Filter:
        """
        Create a Filter object from the provided filters.

        A value that is a dict with both ``"gte"`` and ``"lte"`` becomes a range
        condition; any other value becomes an exact match. All conditions are
        combined with ``must`` (logical AND).

        Args:
            filters (dict): Filters to apply.

        Returns:
            Filter: The created Filter object, or None for empty filters.
        """
        if not filters:
            return None

        conditions = []
        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
            else:
                conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
        return Filter(must=conditions) if conditions else None

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query text (not used by this implementation).
            vectors (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: The ``points`` of the ``query_points`` response.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.query_points(
            collection_name=self.collection_name,
            query=vectors,
            query_filter=query_filter,
            limit=limit,
        )
        return hits.points

    def delete(self, vector_id: int):
        """
        Delete a vector by ID.

        Args:
            vector_id (int): ID of the vector to delete.
        """
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[vector_id],
            ),
        )

    def update(self, vector_id: int, vector: list = None, payload: dict = None):
        """
        Update a vector and/or its payload.

        When ``vector`` is given, the point is upserted with ``vector`` and
        ``payload``. When only ``payload`` is given, the payload is overwritten in
        place and the stored vector is kept. With neither, nothing is done.

        Args:
            vector_id (int): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        if vector is None:
            # PointStruct requires a vector, so a payload-only update must not go
            # through upsert; overwrite the payload on the existing point instead.
            if payload is not None:
                self.client.overwrite_payload(
                    collection_name=self.collection_name,
                    payload=payload,
                    points=[vector_id],
                )
            return

        point = PointStruct(id=vector_id, vector=vector, payload=payload)
        self.client.upsert(collection_name=self.collection_name, points=[point])

    def get(self, vector_id: int) -> dict:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (int): ID of the vector to retrieve.

        Returns:
            The first record returned by ``client.retrieve`` (with payload), or None if not found.
        """
        result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
        return result[0] if result else None

    def list_cols(self) -> list:
        """
        List all collections.

        Returns:
            The ``get_collections`` response; its ``.collections`` entries expose ``.name``.
        """
        return self.client.get_collections()

    def delete_col(self):
        """Delete a collection."""
        self.client.delete_collection(collection_name=self.collection_name)

    def col_info(self) -> dict:
        """
        Get information about a collection.

        Returns:
            dict: Collection information.
        """
        return self.client.get_collection(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List vectors in a collection (payloads only, no vector values).

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            The raw ``client.scroll`` result (NOTE(review): qdrant-client documents this
            as a ``(points, next_page_offset)`` tuple — confirm for the pinned version).
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=query_filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
        return result

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.on_disk)