agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/faiss.py
@@ -0,0 +1,479 @@
+ import logging
+ import os
+ import pickle
+ import uuid
+ import warnings
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ import numpy as np
+ from pydantic import BaseModel
+
+ try:
+     # Suppress SWIG deprecation warnings from FAISS
+     warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+     warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*swigvarlink.*")
+
+     logging.getLogger("faiss").setLevel(logging.WARNING)
+     logging.getLogger("faiss.loader").setLevel(logging.WARNING)
+
+     import faiss
+ except ImportError:
+     raise ImportError(
+         "Could not import the faiss python package. "
+         "Please install it with `pip install faiss-gpu` (for CUDA-capable GPUs) "
+         "or `pip install faiss-cpu` (for CPU-only environments)."
+     )
+
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ logger = logging.getLogger(__name__)
+
+
+ class OutputData(BaseModel):
+     id: Optional[str]  # memory id
+     score: Optional[float]  # distance
+     payload: Optional[Dict]  # metadata
+
+
+ class FAISS(VectorStoreBase):
+     def __init__(
+         self,
+         collection_name: str,
+         path: Optional[str] = None,
+         distance_strategy: str = "euclidean",
+         normalize_L2: bool = False,
+         embedding_model_dims: int = 1536,
+     ):
+         """
+         Initialize the FAISS vector store.
+
+         Args:
+             collection_name (str): Name of the collection.
+             path (str, optional): Directory for the local FAISS index files. Defaults to None.
+             distance_strategy (str, optional): Distance strategy to use. Options: 'euclidean',
+                 'inner_product', 'cosine'. Defaults to "euclidean".
+             normalize_L2 (bool, optional): Whether to L2-normalize vectors. Only applied for
+                 euclidean distance. Defaults to False.
+             embedding_model_dims (int, optional): Dimension of the embedding vectors.
+                 Defaults to 1536.
+         """
+         self.collection_name = collection_name
+         self.path = path or f"/tmp/faiss/{collection_name}"
+         self.distance_strategy = distance_strategy
+         self.normalize_L2 = normalize_L2
+         self.embedding_model_dims = embedding_model_dims
+
+         # Initialize storage structures
+         self.index = None
+         self.docstore = {}
+         self.index_to_id = {}
+
+         # Create the index directory if it doesn't exist (the index files live inside self.path)
+         if self.path:
+             os.makedirs(self.path, exist_ok=True)
+
+         # Try to load an existing index if available
+         index_path = f"{self.path}/{collection_name}.faiss"
+         docstore_path = f"{self.path}/{collection_name}.pkl"
+         if os.path.exists(index_path) and os.path.exists(docstore_path):
+             self._load(index_path, docstore_path)
+         else:
+             self.create_col(collection_name)
+
+     def _load(self, index_path: str, docstore_path: str):
+         """
+         Load the FAISS index and docstore from disk.
+
+         Args:
+             index_path (str): Path to the FAISS index file.
+             docstore_path (str): Path to the docstore pickle file.
+         """
+         try:
+             self.index = faiss.read_index(index_path)
+             with open(docstore_path, "rb") as f:
+                 self.docstore, self.index_to_id = pickle.load(f)
+             logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors")
+         except Exception as e:
+             logger.warning(f"Failed to load FAISS index: {e}")
+             # Fall back to a fresh, empty collection so later inserts don't fail
+             self.docstore = {}
+             self.index_to_id = {}
+             self.create_col(self.collection_name)
+
+     def _save(self):
+         """Save the FAISS index and docstore to disk."""
+         if not self.path or not self.index:
+             return
+
+         try:
+             os.makedirs(self.path, exist_ok=True)
+             index_path = f"{self.path}/{self.collection_name}.faiss"
+             docstore_path = f"{self.path}/{self.collection_name}.pkl"
+
+             faiss.write_index(self.index, index_path)
+             with open(docstore_path, "wb") as f:
+                 pickle.dump((self.docstore, self.index_to_id), f)
+         except Exception as e:
+             logger.warning(f"Failed to save FAISS index: {e}")
+
+     def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
+         """
+         Parse the output data.
+
+         Args:
+             scores: Similarity scores from FAISS.
+             ids: Indices from FAISS.
+             limit: Maximum number of results to return.
+
+         Returns:
+             List[OutputData]: Parsed output data.
+         """
+         if limit is None:
+             limit = len(ids)
+
+         results = []
+         for i in range(min(len(ids), limit)):
+             if ids[i] == -1:  # FAISS returns -1 for empty result slots
+                 continue
+
+             index_id = int(ids[i])
+             vector_id = self.index_to_id.get(index_id)
+             if vector_id is None:
+                 continue
+
+             payload = self.docstore.get(vector_id)
+             if payload is None:
+                 continue
+
+             results.append(
+                 OutputData(
+                     id=vector_id,
+                     score=float(scores[i]),
+                     payload=payload.copy(),
+                 )
+             )
+
+         return results
+
+     def create_col(self, name: str, distance: str = None):
+         """
+         Create a new collection.
+
+         Args:
+             name (str): Name of the collection.
+             distance (str, optional): Distance metric to use. Overrides the distance_strategy
+                 passed during initialization. Defaults to None.
+
+         Returns:
+             self: The FAISS instance.
+         """
+         distance_strategy = distance or self.distance_strategy
+
+         # Create an index based on the distance strategy
+         if distance_strategy.lower() in ("inner_product", "cosine"):
+             self.index = faiss.IndexFlatIP(self.embedding_model_dims)
+         else:
+             self.index = faiss.IndexFlatL2(self.embedding_model_dims)
+
+         self.collection_name = name
+         self._save()
+         return self
+
+     def insert(
+         self,
+         vectors: List[list],
+         payloads: Optional[List[Dict]] = None,
+         ids: Optional[List[str]] = None,
+     ):
+         """
+         Insert vectors into a collection.
+
+         Args:
+             vectors (List[list]): List of vectors to insert.
+             payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
+             ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
+         """
+         if self.index is None:
+             raise ValueError("Collection not initialized. Call create_col first.")
+
+         if ids is None:
+             ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
+
+         if payloads is None:
+             payloads = [{} for _ in range(len(vectors))]
+
+         if len(vectors) != len(ids) or len(vectors) != len(payloads):
+             raise ValueError("Vectors, payloads, and IDs must have the same length")
+
+         vectors_np = np.array(vectors, dtype=np.float32)
+
+         if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
+             faiss.normalize_L2(vectors_np)
+
+         # New vectors are appended at position index.ntotal; capture it before adding
+         # so the position-to-id mapping stays correct even after earlier deletions.
+         starting_idx = self.index.ntotal
+         self.index.add(vectors_np)
+
+         for i, (vector_id, payload) in enumerate(zip(ids, payloads)):
+             self.docstore[vector_id] = payload.copy()
+             self.index_to_id[starting_idx + i] = vector_id
+
+         self._save()
+
+         logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}")
+
+     def search(
+         self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
+     ) -> List[OutputData]:
+         """
+         Search for similar vectors.
+
+         Args:
+             query (str): Query text (not used, kept for API compatibility).
+             vectors (List[list]): Query vector to search with.
+             limit (int, optional): Number of results to return. Defaults to 5.
+             filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
+
+         Returns:
+             List[OutputData]: Search results.
+         """
+         if self.index is None:
+             raise ValueError("Collection not initialized. Call create_col first.")
+
+         query_vectors = np.array(vectors, dtype=np.float32)
+
+         if len(query_vectors.shape) == 1:
+             query_vectors = query_vectors.reshape(1, -1)
+
+         if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
+             faiss.normalize_L2(query_vectors)
+
+         # Over-fetch when filtering so enough candidates survive the filter
+         fetch_k = limit * 2 if filters else limit
+         scores, indices = self.index.search(query_vectors, fetch_k)
+
+         results = self._parse_output(scores[0], indices[0], fetch_k)
+
+         if filters:
+             filtered_results = []
+             for result in results:
+                 if self._apply_filters(result.payload, filters):
+                     filtered_results.append(result)
+                     if len(filtered_results) >= limit:
+                         break
+             results = filtered_results[:limit]
+
+         return results
+
+     def _apply_filters(self, payload: Dict, filters: Dict) -> bool:
+         """
+         Apply filters to a payload.
+
+         Args:
+             payload (Dict): Payload to filter.
+             filters (Dict): Filters to apply.
+
+         Returns:
+             bool: True if the payload passes the filters, False otherwise.
+         """
+         if not filters or not payload:
+             return True
+
+         for key, value in filters.items():
+             if key not in payload:
+                 return False
+
+             if isinstance(value, list):
+                 if payload[key] not in value:
+                     return False
+             elif payload[key] != value:
+                 return False
+
+         return True
+
+     def delete(self, vector_id: str):
+         """
+         Delete a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to delete.
+         """
+         if self.index is None:
+             raise ValueError("Collection not initialized. Call create_col first.")
+
+         index_to_delete = None
+         for idx, vid in self.index_to_id.items():
+             if vid == vector_id:
+                 index_to_delete = idx
+                 break
+
+         if index_to_delete is not None:
+             # A flat FAISS index cannot remove vectors in place; dropping the docstore
+             # entry and the position mapping makes the vector unreachable instead.
+             self.docstore.pop(vector_id, None)
+             self.index_to_id.pop(index_to_delete, None)
+
+             self._save()
+
+             logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}")
+         else:
+             logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}")
+
+     def update(
+         self,
+         vector_id: str,
+         vector: Optional[List[float]] = None,
+         payload: Optional[Dict] = None,
+     ):
+         """
+         Update a vector and its payload.
+
+         Args:
+             vector_id (str): ID of the vector to update.
+             vector (Optional[List[float]], optional): Updated vector. Defaults to None.
+             payload (Optional[Dict], optional): Updated payload. Defaults to None.
+         """
+         if self.index is None:
+             raise ValueError("Collection not initialized. Call create_col first.")
+
+         if vector_id not in self.docstore:
+             raise ValueError(f"Vector {vector_id} not found")
+
+         current_payload = self.docstore[vector_id].copy()
+
+         if payload is not None:
+             self.docstore[vector_id] = payload.copy()
+             current_payload = self.docstore[vector_id].copy()
+
+         if vector is not None:
+             # Replacing the embedding requires a delete + re-insert cycle
+             self.delete(vector_id)
+             self.insert([vector], [current_payload], [vector_id])
+         else:
+             self._save()
+
+         logger.info(f"Updated vector {vector_id} in collection {self.collection_name}")
+
+     def get(self, vector_id: str) -> OutputData:
+         """
+         Retrieve a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to retrieve.
+
+         Returns:
+             OutputData: Retrieved vector, or None if it does not exist.
+         """
+         if self.index is None:
+             raise ValueError("Collection not initialized. Call create_col first.")
+
+         if vector_id not in self.docstore:
+             return None
+
+         return OutputData(
+             id=vector_id,
+             score=None,
+             payload=self.docstore[vector_id].copy(),
+         )
+
+     def list_cols(self) -> List[str]:
+         """
+         List all collections.
+
+         Returns:
+             List[str]: List of collection names.
+         """
+         if not self.path:
+             return [self.collection_name] if self.index else []
+
+         try:
+             # Each collection is stored as a .faiss file inside the index directory
+             return [file.stem for file in Path(self.path).glob("*.faiss")]
+         except Exception as e:
+             logger.warning(f"Failed to list collections: {e}")
+             return [self.collection_name] if self.index else []
+
+     def delete_col(self):
+         """
+         Delete a collection.
+         """
+         if self.path:
+             try:
+                 index_path = f"{self.path}/{self.collection_name}.faiss"
+                 docstore_path = f"{self.path}/{self.collection_name}.pkl"
+
+                 if os.path.exists(index_path):
+                     os.remove(index_path)
+                 if os.path.exists(docstore_path):
+                     os.remove(docstore_path)
+
+                 logger.info(f"Deleted collection {self.collection_name}")
+             except Exception as e:
+                 logger.warning(f"Failed to delete collection: {e}")
+
+         self.index = None
+         self.docstore = {}
+         self.index_to_id = {}
+
+     def col_info(self) -> Dict:
+         """
+         Get information about a collection.
+
+         Returns:
+             Dict: Collection information.
+         """
+         if self.index is None:
+             return {"name": self.collection_name, "count": 0}
+
+         return {
+             "name": self.collection_name,
+             "count": self.index.ntotal,
+             "dimension": self.index.d,
+             "distance": self.distance_strategy,
+         }
+
+     def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
+         """
+         List all vectors in a collection.
+
+         Args:
+             filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
+             limit (int, optional): Number of vectors to return. Defaults to 100.
+
+         Returns:
+             List[List[OutputData]]: A single-element list containing the vectors, matching
+                 the return shape of the other vector store backends.
+         """
+         if self.index is None:
+             return []
+
+         results = []
+         for vector_id, payload in self.docstore.items():
+             if filters and not self._apply_filters(payload, filters):
+                 continue
+
+             results.append(
+                 OutputData(
+                     id=vector_id,
+                     score=None,
+                     payload=payload.copy(),
+                 )
+             )
+
+             if len(results) >= limit:
+                 break
+
+         return [results]
+
+     def reset(self):
+         """Reset the index by deleting and recreating it."""
+         logger.warning(f"Resetting index {self.collection_name}...")
+         self.delete_col()
+         self.create_col(self.collection_name)
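
For orientation, here is a minimal usage sketch of the FAISS backend above, assuming the wheel and `faiss-cpu` (plus `numpy`) are installed. The collection name, directory, and 8-dimensional vectors are illustrative placeholders, not values shipped in the package.

```python
import numpy as np

from agentrun_mem0.vector_stores.faiss import FAISS

# Illustrative parameters: a tiny 8-dimensional collection in a scratch directory.
store = FAISS(
    collection_name="demo",
    path="/tmp/faiss_demo",
    distance_strategy="euclidean",
    embedding_model_dims=8,
)

# Insert two vectors with payloads; IDs are generated when omitted.
vectors = np.random.rand(2, 8).astype(np.float32).tolist()
payloads = [
    {"data": "first memory", "user_id": "alice"},
    {"data": "second memory", "user_id": "bob"},
]
store.insert(vectors, payloads=payloads)

# Query with a single embedding; `query` is unused and kept for API parity.
hits = store.search(query="", vectors=vectors[0], limit=1, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload)
```

Note that `distance_strategy="cosine"` maps to an inner-product index, and as written L2 normalization is only applied on the euclidean path, so callers wanting true cosine similarity need to normalize their embeddings before insert and search.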
agentrun_mem0/vector_stores/langchain.py
@@ -0,0 +1,180 @@
+ import logging
+ from typing import Dict, List, Optional
+
+ from pydantic import BaseModel
+
+ try:
+     from langchain_community.vectorstores import VectorStore
+ except ImportError:
+     raise ImportError(
+         "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
+     )
+
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ logger = logging.getLogger(__name__)
+
+
+ class OutputData(BaseModel):
+     id: Optional[str]  # memory id
+     score: Optional[float]  # distance
+     payload: Optional[Dict]  # metadata
+
+
+ class Langchain(VectorStoreBase):
+     def __init__(self, client: VectorStore, collection_name: str = "mem0"):
+         self.client = client
+         self.collection_name = collection_name
+
+     def _parse_output(self, data) -> List[OutputData]:
+         """
+         Parse the output data.
+
+         Args:
+             data: Raw result dict, or a list of LangChain Document objects.
+
+         Returns:
+             List[OutputData]: Parsed output data.
+         """
+         # Check if the input is a list of Document objects
+         if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
+             result = []
+             for doc in data:
+                 entry = OutputData(
+                     id=getattr(doc, "id", None),
+                     score=None,  # Document objects typically don't include scores
+                     payload=getattr(doc, "metadata", {}),
+                 )
+                 result.append(entry)
+             return result
+
+         # Original dict format handling
+         keys = ["ids", "distances", "metadatas"]
+         values = []
+
+         for key in keys:
+             value = data.get(key, [])
+             # Unwrap nested lists such as {"ids": [[...]]}
+             if isinstance(value, list) and value and isinstance(value[0], list):
+                 value = value[0]
+             values.append(value)
+
+         ids, distances, metadatas = values
+         max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
+
+         result = []
+         for i in range(max_length):
+             entry = OutputData(
+                 id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
+                 score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
+                 payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
+             )
+             result.append(entry)
+
+         return result
+
+     def create_col(self, name, vector_size=None, distance=None):
+         self.collection_name = name
+         return self.client
+
+     def insert(
+         self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
+     ):
+         """
+         Insert vectors into the LangChain vectorstore.
+         """
+         if hasattr(self.client, "add_embeddings"):
+             # Some LangChain vectorstores expose a direct add_embeddings method
+             self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
+         else:
+             # Fall back to add_texts, which re-embeds the raw text from each payload
+             texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
+             self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)
+
+     def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
+         """
+         Search for similar vectors in the LangChain vectorstore.
+         """
+         # `vectors` is the single query embedding produced by the embedder
+         if filters:
+             results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
+         else:
+             results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)
+
+         return self._parse_output(results)
+
+     def delete(self, vector_id):
+         """
+         Delete a vector by ID.
+         """
+         self.client.delete(ids=[vector_id])
+
+     def update(self, vector_id, vector=None, payload=None):
+         """
+         Update a vector and its payload by deleting and re-inserting it.
+         """
+         self.delete(vector_id)
+         self.insert([vector], [payload], [vector_id])
+
+     def get(self, vector_id):
+         """
+         Retrieve a vector by ID.
+         """
+         docs = self.client.get_by_ids([vector_id])
+         if docs:
+             return self._parse_output([docs[0]])[0]
+         return None
+
+     def list_cols(self):
+         """
+         List all collections.
+         """
+         # LangChain vectorstores don't expose collections
+         return [self.collection_name]
+
+     def delete_col(self):
+         """
+         Delete a collection.
+         """
+         logger.warning("Deleting collection")
+         if hasattr(self.client, "delete_collection"):
+             self.client.delete_collection()
+         elif hasattr(self.client, "reset_collection"):
+             self.client.reset_collection()
+         else:
+             self.client.delete(ids=None)
+
+     def col_info(self):
+         """
+         Get information about a collection.
+         """
+         return {"name": self.collection_name}
+
+     def list(self, filters=None, limit=None):
+         """
+         List all vectors in a collection.
+         """
+         try:
+             if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
+                 # Chroma-specific path: pass all mem0 filters through as the where clause
+                 where_clause = filters if filters else None
+                 result = self.client._collection.get(where=where_clause, limit=limit)
+
+                 # Convert the result to the expected single-element list format
+                 if result and isinstance(result, dict):
+                     return [self._parse_output(result)]
+             return []
+         except Exception as e:
+             logger.error(f"Error listing vectors from Chroma: {e}")
+             return []
+
+     def reset(self):
+         """Reset the collection by deleting it."""
+         logger.warning(f"Resetting collection: {self.collection_name}")
+         self.delete_col()
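
And a comparable sketch for the LangChain adapter, assuming `langchain-community` and `chromadb` are installed. `FakeEmbeddings`, the collection name, and the 8-dimensional vectors below are illustrative stand-ins for a real embedder and production values.

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Chroma

from agentrun_mem0.vector_stores.langchain import Langchain

# Illustrative setup: an in-memory Chroma store with fake 8-dim embeddings.
client = Chroma(collection_name="mem0", embedding_function=FakeEmbeddings(size=8))
store = Langchain(client=client, collection_name="mem0")

# Chroma has no add_embeddings method, so insert falls back to add_texts and
# re-embeds the "data" field from each payload.
embedding = [0.1] * 8
store.insert([embedding], payloads=[{"data": "hello world"}], ids=["mem-1"])

# Search takes the single query embedding produced upstream by the embedder.
results = store.search(query="hello", vectors=embedding, limit=1)
for r in results:
    print(r.id, r.score, r.payload)
```

Because the `add_texts` fallback re-embeds text rather than storing the supplied vectors, the embeddings Chroma holds here come from `FakeEmbeddings`, not from `embedding`; backends that do expose `add_embeddings` keep the vectors as given.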