agentrun-mem0ai 0.0.11 (py3-none-any.whl)
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/faiss.py
@@ -0,0 +1,479 @@

import logging
import os
import pickle
import uuid
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
from pydantic import BaseModel

try:
    # Suppress SWIG deprecation warnings from FAISS
    warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
    warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*swigvarlink.*")

    logging.getLogger("faiss").setLevel(logging.WARNING)
    logging.getLogger("faiss.loader").setLevel(logging.WARNING)

    import faiss
except ImportError:
    raise ImportError(
        "Could not import faiss python package. "
        "Please install it with `pip install faiss-gpu` (for CUDA-enabled GPUs) "
        "or `pip install faiss-cpu` (for CPU-only environments)."
    )

from agentrun_mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class FAISS(VectorStoreBase):
    def __init__(
        self,
        collection_name: str,
        path: Optional[str] = None,
        distance_strategy: str = "euclidean",
        normalize_L2: bool = False,
        embedding_model_dims: int = 1536,
    ):
        """
        Initialize the FAISS vector store.

        Args:
            collection_name (str): Name of the collection.
            path (str, optional): Directory for the local FAISS index files.
                Defaults to None, in which case "/tmp/faiss/{collection_name}" is used.
            distance_strategy (str, optional): Distance strategy to use.
                Options: 'euclidean', 'inner_product', 'cosine'. Defaults to "euclidean".
            normalize_L2 (bool, optional): Whether to L2-normalize vectors. Only applicable
                for euclidean distance. Defaults to False.
            embedding_model_dims (int, optional): Dimensionality of the embedding vectors.
                Defaults to 1536.
        """
        self.collection_name = collection_name
        self.path = path or f"/tmp/faiss/{collection_name}"
        self.distance_strategy = distance_strategy
        self.normalize_L2 = normalize_L2
        self.embedding_model_dims = embedding_model_dims

        # Initialize storage structures
        self.index = None
        self.docstore = {}
        self.index_to_id = {}

        # Create the index directory if it doesn't exist; the index and docstore
        # files live inside self.path, so the directory itself must exist.
        if self.path:
            os.makedirs(self.path, exist_ok=True)

        # Try to load an existing index if available
        index_path = f"{self.path}/{collection_name}.faiss"
        docstore_path = f"{self.path}/{collection_name}.pkl"
        if os.path.exists(index_path) and os.path.exists(docstore_path):
            self._load(index_path, docstore_path)
        else:
            self.create_col(collection_name)

    def _load(self, index_path: str, docstore_path: str):
        """
        Load FAISS index and docstore from disk.

        Args:
            index_path (str): Path to FAISS index file.
            docstore_path (str): Path to docstore pickle file.
        """
        try:
            self.index = faiss.read_index(index_path)
            with open(docstore_path, "rb") as f:
                self.docstore, self.index_to_id = pickle.load(f)
            logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors")
        except Exception as e:
            logger.warning(f"Failed to load FAISS index: {e}")
            self.docstore = {}
            self.index_to_id = {}

    def _save(self):
        """Save FAISS index and docstore to disk."""
        if not self.path or not self.index:
            return

        try:
            os.makedirs(self.path, exist_ok=True)
            index_path = f"{self.path}/{self.collection_name}.faiss"
            docstore_path = f"{self.path}/{self.collection_name}.pkl"

            faiss.write_index(self.index, index_path)
            with open(docstore_path, "wb") as f:
                pickle.dump((self.docstore, self.index_to_id), f)
        except Exception as e:
            logger.warning(f"Failed to save FAISS index: {e}")

    def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            scores: Similarity scores from FAISS.
            ids: Indices from FAISS.
            limit: Maximum number of results to return.

        Returns:
            List[OutputData]: Parsed output data.
        """
        if limit is None:
            limit = len(ids)

        results = []
        for i in range(min(len(ids), limit)):
            if ids[i] == -1:  # FAISS returns -1 for empty results
                continue

            index_id = int(ids[i])
            vector_id = self.index_to_id.get(index_id)
            if vector_id is None:
                continue

            payload = self.docstore.get(vector_id)
            if payload is None:
                continue

            payload_copy = payload.copy()

            score = float(scores[i])
            entry = OutputData(
                id=vector_id,
                score=score,
                payload=payload_copy,
            )
            results.append(entry)

        return results

    def create_col(self, name: str, distance: str = None):
        """
        Create a new collection.

        Args:
            name (str): Name of the collection.
            distance (str, optional): Distance metric to use. Overrides the distance_strategy
                passed during initialization. Defaults to None.

        Returns:
            self: The FAISS instance.
        """
        distance_strategy = distance or self.distance_strategy

        # Create index based on distance strategy
        if distance_strategy.lower() in ("inner_product", "cosine"):
            self.index = faiss.IndexFlatIP(self.embedding_model_dims)
        else:
            self.index = faiss.IndexFlatL2(self.embedding_model_dims)

        self.collection_name = name

        self._save()

        return self

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors into a collection.

        Args:
            vectors (List[list]): List of vectors to insert.
            payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
            ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in range(len(vectors))]

        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]

        if len(vectors) != len(ids) or len(vectors) != len(payloads):
            raise ValueError("Vectors, payloads, and IDs must have the same length")

        vectors_np = np.array(vectors, dtype=np.float32)

        if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
            faiss.normalize_L2(vectors_np)

        self.index.add(vectors_np)

        # New vectors occupy FAISS positions [ntotal - len(vectors), ntotal). Deriving the
        # starting position from ntotal (rather than len(self.index_to_id)) keeps the mapping
        # correct even after deletions have removed entries from index_to_id.
        starting_idx = self.index.ntotal - len(vectors)
        for i, (vector_id, payload) in enumerate(zip(ids, payloads)):
            self.docstore[vector_id] = payload.copy()
            self.index_to_id[starting_idx + i] = vector_id

        self._save()

        logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}")

    def search(
        self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query (not used, kept for API compatibility).
            vectors (List[list]): Query vector to search with.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        query_vectors = np.array(vectors, dtype=np.float32)

        if len(query_vectors.shape) == 1:
            query_vectors = query_vectors.reshape(1, -1)

        if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
            faiss.normalize_L2(query_vectors)

        # Over-fetch when filtering so that post-filtering can still fill up to `limit`.
        fetch_k = limit * 2 if filters else limit
        scores, indices = self.index.search(query_vectors, fetch_k)

        # Parse everything that was fetched; filtering below trims back to `limit`.
        results = self._parse_output(scores[0], indices[0], fetch_k)

        if filters:
            filtered_results = []
            for result in results:
                if self._apply_filters(result.payload, filters):
                    filtered_results.append(result)
                    if len(filtered_results) >= limit:
                        break
            results = filtered_results[:limit]

        return results

    def _apply_filters(self, payload: Dict, filters: Dict) -> bool:
        """
        Apply filters to a payload.

        Args:
            payload (Dict): Payload to filter.
            filters (Dict): Filters to apply.

        Returns:
            bool: True if payload passes filters, False otherwise.
        """
        if not filters or not payload:
            return True

        for key, value in filters.items():
            if key not in payload:
                return False

            if isinstance(value, list):
                if payload[key] not in value:
                    return False
            elif payload[key] != value:
                return False

        return True

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        index_to_delete = None
        for idx, vid in self.index_to_id.items():
            if vid == vector_id:
                index_to_delete = idx
                break

        if index_to_delete is not None:
            # Remove the payload and the position mapping. The raw vector remains in the
            # flat index but becomes unreachable, since results without a mapping are skipped.
            self.docstore.pop(vector_id, None)
            self.index_to_id.pop(index_to_delete, None)

            self._save()

            logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}")
        else:
            logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}")

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (Optional[List[float]], optional): Updated vector. Defaults to None.
            payload (Optional[Dict], optional): Updated payload. Defaults to None.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if vector_id not in self.docstore:
            raise ValueError(f"Vector {vector_id} not found")

        current_payload = self.docstore[vector_id].copy()

        if payload is not None:
            self.docstore[vector_id] = payload.copy()
            current_payload = self.docstore[vector_id].copy()

        if vector is not None:
            self.delete(vector_id)
            self.insert([vector], [current_payload], [vector_id])
        else:
            self._save()

        logger.info(f"Updated vector {vector_id} in collection {self.collection_name}")

    def get(self, vector_id: str) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if vector_id not in self.docstore:
            return None

        payload = self.docstore[vector_id].copy()

        return OutputData(
            id=vector_id,
            score=None,
            payload=payload,
        )

    def list_cols(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names.
        """
        if not self.path:
            return [self.collection_name] if self.index else []

        try:
            # Index files are written inside self.path (see _save), so glob there.
            collections = []
            for file in Path(self.path).glob("*.faiss"):
                collections.append(file.stem)
            return collections
        except Exception as e:
            logger.warning(f"Failed to list collections: {e}")
            return [self.collection_name] if self.index else []

    def delete_col(self):
        """
        Delete a collection.
        """
        if self.path:
            try:
                index_path = f"{self.path}/{self.collection_name}.faiss"
                docstore_path = f"{self.path}/{self.collection_name}.pkl"

                if os.path.exists(index_path):
                    os.remove(index_path)
                if os.path.exists(docstore_path):
                    os.remove(docstore_path)

                logger.info(f"Deleted collection {self.collection_name}")
            except Exception as e:
                logger.warning(f"Failed to delete collection: {e}")

        self.index = None
        self.docstore = {}
        self.index_to_id = {}

    def col_info(self) -> Dict:
        """
        Get information about a collection.

        Returns:
            Dict: Collection information.
        """
        if self.index is None:
            return {"name": self.collection_name, "count": 0}

        return {
            "name": self.collection_name,
            "count": self.index.ntotal,
            "dimension": self.index.d,
            "distance": self.distance_strategy,
        }

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in a collection.

        Args:
            filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        if self.index is None:
            return []

        results = []
        count = 0

        for vector_id, payload in self.docstore.items():
            if filters and not self._apply_filters(payload, filters):
                continue

            payload_copy = payload.copy()

            results.append(
                OutputData(
                    id=vector_id,
                    score=None,
                    payload=payload_copy,
                )
            )

            count += 1
            if count >= limit:
                break

        # Wrapped in an outer list, consistent with the LangChain backend's list() below.
        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name)
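For orientation, a minimal usage sketch of the FAISS store defined above. It is illustrative only: the collection name, the 4-dimensional vectors, and the payload values are invented for the example, not values taken from the package.

    from agentrun_mem0.vector_stores.faiss import FAISS

    # Illustrative values; real callers pass embeddings from an embedding model.
    store = FAISS(collection_name="demo", distance_strategy="inner_product", embedding_model_dims=4)

    store.insert(
        vectors=[[0.1, 0.2, 0.3, 0.4]],
        payloads=[{"data": "hello", "user_id": "u1"}],
        ids=["m1"],
    )

    hits = store.search(query="", vectors=[0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "u1"})
    print(hits[0].id, hits[0].score)  # -> "m1" and an inner-product score

One design note: create_col maps both 'inner_product' and 'cosine' to IndexFlatIP, while normalize_L2 is applied only under the euclidean strategy, so the 'cosine' setting yields true cosine similarity only if the caller supplies pre-normalized embeddings.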
agentrun_mem0/vector_stores/langchain.py
@@ -0,0 +1,180 @@

import logging
from typing import Dict, List, Optional

from pydantic import BaseModel

try:
    from langchain_community.vectorstores import VectorStore
except ImportError:
    raise ImportError(
        "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
    )

from agentrun_mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class Langchain(VectorStoreBase):
    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data or list of Document objects.

        Returns:
            List[OutputData]: Parsed output data.
        """
        # Check if input is a list of Document objects
        if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
            result = []
            for doc in data:
                entry = OutputData(
                    id=getattr(doc, "id", None),
                    score=None,  # Document objects typically don't include scores
                    payload=getattr(doc, "metadata", {}),
                )
                result.append(entry)
            return result

        # Original format handling
        keys = ["ids", "distances", "metadatas"]
        values = []

        for key in keys:
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

        return result

    def create_col(self, name, vector_size=None, distance=None):
        self.collection_name = name
        return self.client

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the LangChain vectorstore.
        """
        # Check if client has an add_embeddings method
        if hasattr(self.client, "add_embeddings"):
            # Some LangChain vectorstores have a direct add_embeddings method
            self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
        else:
            # Fall back to add_texts, pulling the text from each payload's "data" field
            texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
            self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """
        Search for similar vectors in LangChain.
        """
        # Run a similarity search with the query embedding
        if filters:
            results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
        else:
            results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)

        final_results = self._parse_output(results)
        return final_results

    def delete(self, vector_id):
        """
        Delete a vector by ID.
        """
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.
        """
        self.delete(vector_id)
        # insert() expects parallel lists, so wrap the single vector and payload
        self.insert([vector], [payload], [vector_id])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.
        """
        docs = self.client.get_by_ids([vector_id])
        if docs and len(docs) > 0:
            doc = docs[0]
            return self._parse_output([doc])[0]
        return None

    def list_cols(self):
        """
        List all collections.
        """
        # LangChain doesn't have collections
        return [self.collection_name]

    def delete_col(self):
        """
        Delete a collection.
        """
        logger.warning("Deleting collection")
        if hasattr(self.client, "delete_collection"):
            self.client.delete_collection()
        elif hasattr(self.client, "reset_collection"):
            self.client.reset_collection()
        else:
            self.client.delete(ids=None)

    def col_info(self):
        """
        Get information about a collection.
        """
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """
        List all vectors in a collection.
        """
        try:
            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
                # Convert mem0 filters to a Chroma where clause if needed
                where_clause = None
                if filters:
                    # Handle all filters, not just user_id
                    where_clause = filters

                result = self.client._collection.get(where=where_clause, limit=limit)

                # Convert the result to the expected format
                if result and isinstance(result, dict):
                    return [self._parse_output(result)]
            return []
        except Exception as e:
            logger.error(f"Error listing vectors from Chroma: {e}")
            return []

    def reset(self):
        """Reset the collection by deleting it."""
        logger.warning(f"Resetting collection: {self.collection_name}")
        self.delete_col()
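A comparable sketch for the Langchain adapter. The backing store and embedder here (InMemoryVectorStore and DeterministicFakeEmbedding, both from langchain_core) are stand-ins chosen to keep the example self-contained and are not part of this package; any LangChain vectorstore exposing add_texts/get_by_ids should wire up the same way.

    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.vectorstores import InMemoryVectorStore

    from agentrun_mem0.vector_stores.langchain import Langchain

    emb = DeterministicFakeEmbedding(size=64)  # deterministic stand-in embedder
    store = Langchain(client=InMemoryVectorStore(embedding=emb), collection_name="demo")

    # insert() falls back to add_texts() here, pulling each text from payload["data"]
    store.insert(
        vectors=[emb.embed_query("hello world")],
        payloads=[{"data": "hello world"}],
        ids=["m1"],
    )

    hits = store.search(query="", vectors=emb.embed_query("hello"), limit=1)
    print(hits[0].id, hits[0].payload)  # -> "m1" {"data": "hello world"}

Note that the adapter delegates embedding-space behavior entirely to the wrapped client: search() passes the single query embedding straight to similarity_search_by_vector, and list() reaches into a Chroma-style private _collection attribute, so listing is effectively Chroma-specific.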