agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from agentrun_mem0.configs.vector_stores.milvus import MetricType
|
|
7
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import pymilvus # noqa: F401
|
|
11
|
+
except ImportError:
|
|
12
|
+
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
|
|
13
|
+
|
|
14
|
+
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OutputData(BaseModel):
    """Record shape returned by the Milvus vector store methods."""

    # Memory id.
    id: Optional[str]
    # Distance reported by a search; None when the record was not produced by a search.
    score: Optional[float]
    # Metadata stored with the vector.
    payload: Optional[Dict]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MilvusDB(VectorStoreBase):
|
|
26
|
+
def __init__(
    self,
    url: str,
    token: str,
    collection_name: str,
    embedding_model_dims: int,
    metric_type: MetricType,
    db_name: str,
) -> None:
    """Connect to Milvus and make sure the target collection exists.

    Args:
        url (str): Full URL for the Milvus/Zilliz server.
        token (str): Token/api_key passed to the client.
        collection_name (str): Name of the collection to use.
        embedding_model_dims (int): Dimensions of the embedding model.
        metric_type (MetricType): Metric type for similarity search.
        db_name (str): Name of the database.
    """
    # Remember the collection settings; ``reset`` reads them back later.
    self.collection_name = collection_name
    self.embedding_model_dims = embedding_model_dims
    self.metric_type = metric_type

    self.client = MilvusClient(uri=url, token=token, db_name=db_name)

    # ``create_col`` skips creation when the collection is already present.
    self.create_col(
        collection_name=collection_name,
        vector_size=embedding_model_dims,
        metric_type=metric_type,
    )
|
|
54
|
+
|
|
55
|
+
def create_col(
    self,
    collection_name: str,
    vector_size: int,
    metric_type: MetricType = MetricType.COSINE,
) -> None:
    """Create a new collection with index_type AUTOINDEX, unless it already exists.

    Args:
        collection_name (str): Name of the collection.
        vector_size (int): Dimensions of the embedding model.
        metric_type (MetricType, optional): Metric type for similarity search.
            Defaults to MetricType.COSINE.
    """
    if self.client.has_collection(collection_name):
        logger.info(f"Collection {collection_name} already exists. Skipping creation.")
        return

    # Three fixed fields: string primary key, the embedding, and a JSON metadata blob.
    id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512)
    vector_field = FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size)
    metadata_field = FieldSchema(name="metadata", dtype=DataType.JSON)
    schema = CollectionSchema([id_field, vector_field, metadata_field], enable_dynamic_field=True)

    index_params = self.client.prepare_index_params(
        field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
    )
    self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)
|
|
84
|
+
|
|
85
|
+
def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
    """Insert vectors into the collection in a single batch.

    Args:
        ids (List[str]): IDs, one per vector.
        vectors (List[List[float]]): Vectors to insert.
        payloads (List[Dict]): Payloads, one per vector; stored under "metadata".
        **kwargs: Forwarded unchanged to the client's ``insert`` call.
    """
    # The three sequences are walked in lockstep; zip stops at the shortest one.
    columns = ("id", "vectors", "metadata")
    rows = [dict(zip(columns, row)) for row in zip(ids, vectors, payloads)]
    self.client.insert(collection_name=self.collection_name, data=rows, **kwargs)
|
|
99
|
+
|
|
100
|
+
def _create_filter(self, filters: dict):
|
|
101
|
+
"""Prepare filters for efficient query.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
filters (dict): filters [user_id, agent_id, run_id]
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
str: formated filter.
|
|
108
|
+
"""
|
|
109
|
+
operands = []
|
|
110
|
+
for key, value in filters.items():
|
|
111
|
+
if isinstance(value, str):
|
|
112
|
+
operands.append(f'(metadata["{key}"] == "{value}")')
|
|
113
|
+
else:
|
|
114
|
+
operands.append(f'(metadata["{key}"] == {value})')
|
|
115
|
+
|
|
116
|
+
return " and ".join(operands)
|
|
117
|
+
|
|
118
|
+
def _parse_output(self, data: list):
    """Convert raw search hits into OutputData records.

    Args:
        data (list): Hits, each a mapping with "id", "distance" and an
            optional "entity" mapping that holds "metadata".

    Returns:
        List[OutputData]: One record per hit, in the same order.
    """
    return [
        OutputData(
            id=hit.get("id"),
            score=hit.get("distance"),
            # A hit without an "entity" mapping yields a None payload.
            payload=hit.get("entity", {}).get("metadata"),
        )
        for hit in data
    ]
|
|
141
|
+
|
|
142
|
+
def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
    """Search for vectors similar to the query vector.

    Args:
        query (str): Query text; not used by this method.
        vectors (List[float]): Query vector.
        limit (int, optional): Number of results to return. Defaults to 5.
        filters (Dict, optional): Metadata filters to apply. Defaults to None.

    Returns:
        list: Parsed results for the single query vector.
    """
    query_filter = None
    if filters:
        query_filter = self._create_filter(filters)

    # The client takes a batch of query vectors; only one is sent here,
    # so only the first result set is parsed.
    result_sets = self.client.search(
        collection_name=self.collection_name,
        data=[vectors],
        limit=limit,
        filter=query_filter,
        output_fields=["*"],
    )
    return self._parse_output(data=result_sets[0])
|
|
165
|
+
|
|
166
|
+
def delete(self, vector_id):
    """Delete a vector by ID.

    Args:
        vector_id (str): ID of the vector to delete.
    """
    target_collection = self.collection_name
    self.client.delete(collection_name=target_collection, ids=vector_id)
|
|
174
|
+
|
|
175
|
+
def update(self, vector_id=None, vector=None, payload=None):
    """Upsert a vector and its payload under the given ID.

    Args:
        vector_id (str): ID of the vector to update.
        vector (List[float], optional): Updated vector.
        payload (Dict, optional): Updated payload, stored under "metadata".
    """
    # All three values are always sent, including any that are None.
    record = {"id": vector_id, "vectors": vector, "metadata": payload}
    self.client.upsert(collection_name=self.collection_name, data=record)
|
|
186
|
+
|
|
187
|
+
def get(self, vector_id):
    """Retrieve a vector by ID.

    Args:
        vector_id (str): ID of the vector to retrieve.

    Returns:
        Optional[OutputData]: The stored record with ``score`` set to None,
        or None when the client returns no rows for ``vector_id``.
    """
    rows = self.client.get(collection_name=self.collection_name, ids=vector_id)
    if not rows:
        # Unknown ID: indexing the empty result used to raise IndexError.
        return None

    record = rows[0]
    return OutputData(
        id=record.get("id", None),
        score=None,
        payload=record.get("metadata", None),
    )
|
|
204
|
+
|
|
205
|
+
def list_cols(self):
    """List all collections known to the client.

    Returns:
        List[str]: List of collection names.
    """
    collection_names = self.client.list_collections()
    return collection_names
|
|
213
|
+
|
|
214
|
+
def delete_col(self):
    """Drop this store's collection and return the client's result."""
    outcome = self.client.drop_collection(collection_name=self.collection_name)
    return outcome
|
|
217
|
+
|
|
218
|
+
def col_info(self):
    """Get statistics for this store's collection.

    Returns:
        Dict[str, Any]: Collection information as reported by the client.
    """
    stats = self.client.get_collection_stats(collection_name=self.collection_name)
    return stats
|
|
226
|
+
|
|
227
|
+
def list(self, filters: dict = None, limit: int = 100) -> list:
    """List vectors in the collection, optionally filtered by metadata.

    Args:
        filters (Dict, optional): Filters to apply to the list.
        limit (int, optional): Number of vectors to return. Defaults to 100.

    Returns:
        list: A one-element list whose only item is the List[OutputData]
        of matching records (``score`` is always None).
    """
    query_filter = None
    if filters:
        query_filter = self._create_filter(filters)

    rows = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
    records = [OutputData(id=row.get("id"), score=None, payload=row.get("metadata")) for row in rows]
    # The records are deliberately wrapped in an outer list.
    return [records]
|
|
245
|
+
|
|
246
|
+
def reset(self):
    """Reset the index by dropping the collection and creating it again."""
    logger.warning(f"Resetting index {self.collection_name}...")
    self.delete_col()
    # Recreate with the settings captured at construction time.
    self.create_col(
        self.collection_name,
        self.embedding_model_dims,
        self.metric_type,
    )
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from pymongo import MongoClient
|
|
8
|
+
from pymongo.errors import PyMongoError
|
|
9
|
+
from pymongo.operations import SearchIndexModel
|
|
10
|
+
except ImportError:
|
|
11
|
+
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
|
|
12
|
+
|
|
13
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
logging.basicConfig(level=logging.INFO)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OutputData(BaseModel):
    """Record shape returned by the MongoDB vector store methods."""

    # Document ``_id`` rendered as a string.
    id: Optional[str]
    # Vector search score; None for records not produced by a search.
    score: Optional[float]
    # Contents of the document's "payload" field.
    payload: Optional[dict]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MongoDB(VectorStoreBase):
|
|
26
|
+
VECTOR_TYPE = "knnVector"
|
|
27
|
+
SIMILARITY_METRIC = "cosine"
|
|
28
|
+
|
|
29
|
+
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
    """Initialize the MongoDB vector store with vector search capabilities.

    Args:
        db_name (str): Database name
        collection_name (str): Collection name
        embedding_model_dims (int): Dimension of the embedding vector
        mongo_uri (str): MongoDB connection URI
    """
    self.collection_name = collection_name
    self.embedding_model_dims = embedding_model_dims
    self.db_name = db_name

    client = MongoClient(mongo_uri)
    self.client = client
    self.db = client[db_name]
    # create_col returns None when setup fails with a PyMongoError.
    self.collection = self.create_col()
|
|
46
|
+
|
|
47
|
+
def create_col(self):
    """Ensure the collection and its vector search index exist.

    Sets ``self.index_name`` as a side effect.

    Returns:
        The collection handle, or None when a PyMongoError occurs.
    """
    try:
        database = self.client[self.db_name]
        already_present = self.collection_name in database.list_collection_names()
        collection = database[self.collection_name]

        if not already_present:
            logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
            # Insert and remove a placeholder document to create the collection
            collection.insert_one({"_id": 0, "placeholder": True})
            collection.delete_one({"_id": 0})
            logger.info(f"Collection '{self.collection_name}' created successfully.")

        self.index_name = f"{self.collection_name}_vector_index"
        if list(collection.list_search_indexes(name=self.index_name)):
            logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
            return collection

        # Only the "embedding" field is indexed; dynamic mapping is off.
        embedding_mapping = {
            "type": self.VECTOR_TYPE,
            "dimensions": self.embedding_model_dims,
            "similarity": self.SIMILARITY_METRIC,
        }
        index_definition = {"mappings": {"dynamic": False, "fields": {"embedding": embedding_mapping}}}
        collection.create_search_index(SearchIndexModel(name=self.index_name, definition=index_definition))
        logger.info(
            f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
        )
        return collection
    except PyMongoError as e:
        logger.error(f"Error creating collection and search index: {e}")
        return None
|
|
90
|
+
|
|
91
|
+
def insert(
    self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> None:
    """Insert vectors into the collection.

    PyMongo errors raised by the bulk insert are logged, not re-raised.

    Args:
        vectors (List[List[float]]): List of vectors to insert.
        payloads (List[Dict], optional): List of payloads corresponding to vectors.
            Each vector gets an empty payload when omitted.
        ids (List[str], optional): List of IDs corresponding to vectors. When
            omitted (or when an individual ID is None) no ``_id`` is sent.
    """
    logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")
    if not vectors:
        # Nothing to write; avoids handing insert_many an empty batch.
        return

    count = len(vectors)
    # A fresh dict per document so default payloads are never shared.
    payload_seq = payloads or [{} for _ in range(count)]
    id_seq = ids or [None] * count

    data = []
    for vector, payload, _id in zip(vectors, payload_seq, id_seq):
        document = {}
        if _id is not None:
            # Leaving "_id" out lets an identifier be generated; writing
            # "_id": None for every document makes them collide on one key.
            document["_id"] = _id
        document["embedding"] = vector
        document["payload"] = payload
        data.append(document)

    try:
        self.collection.insert_many(data)
        logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
    except PyMongoError as e:
        logger.error(f"Error inserting data: {e}")
|
|
113
|
+
|
|
114
|
+
def search(self, query: str, vectors: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
    """Search for similar vectors using the vector search index.

    Args:
        query (str): Query string; only used in the error log message.
        vectors (List[float]): Query vector.
        limit (int, optional): Number of results to return. Defaults to 5.
        filters (Dict, optional): Payload key/value pairs every result must match.

    Returns:
        List[OutputData]: Search results; empty when the index is missing
        or the aggregation raises.
    """
    if not list(self.collection.list_search_indexes(name=self.index_name)):
        logger.error(f"Index '{self.index_name}' does not exist.")
        return []

    try:
        collection = self.client[self.db_name][self.collection_name]
        vector_stage = {
            "$vectorSearch": {
                "index": self.index_name,
                "limit": limit,
                "numCandidates": limit,
                "queryVector": vectors,
                "path": "embedding",
            }
        }
        match_clauses = [{"payload." + key: value} for key, value in filters.items()] if filters else []

        pipeline = [vector_stage]
        if match_clauses:
            # Filters run as a $match stage directly after the vector search.
            pipeline.append({"$match": {"$and": match_clauses}})
        pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}})
        pipeline.append({"$project": {"embedding": 0}})

        documents = list(collection.aggregate(pipeline))
        logger.info(f"Vector search completed. Found {len(documents)} documents.")
    except Exception as e:
        logger.error(f"Error during vector search for query {query}: {e}")
        return []

    return [
        OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload"))
        for doc in documents
    ]
|
|
168
|
+
|
|
169
|
+
def delete(self, vector_id: str) -> None:
    """Delete a vector by ID.

    PyMongo errors are logged, not re-raised.

    Args:
        vector_id (str): ID of the vector to delete.
    """
    try:
        outcome = self.collection.delete_one({"_id": vector_id})
    except PyMongoError as e:
        logger.error(f"Error deleting document: {e}")
        return

    if outcome.deleted_count > 0:
        logger.info(f"Deleted document with ID '{vector_id}'.")
        return
    logger.warning(f"No document found with ID '{vector_id}' to delete.")
|
|
184
|
+
|
|
185
|
+
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
    """Update a vector and/or its payload.

    Nothing is sent when both ``vector`` and ``payload`` are None.
    PyMongo errors are logged, not re-raised.

    Args:
        vector_id (str): ID of the vector to update.
        vector (List[float], optional): Updated vector.
        payload (Dict, optional): Updated payload.
    """
    candidates = (("embedding", vector), ("payload", payload))
    changes = {field: value for field, value in candidates if value is not None}
    if not changes:
        return

    try:
        outcome = self.collection.update_one({"_id": vector_id}, {"$set": changes})
    except PyMongoError as e:
        logger.error(f"Error updating document: {e}")
        return

    if outcome.matched_count > 0:
        logger.info(f"Updated document with ID '{vector_id}'.")
    else:
        logger.warning(f"No document found with ID '{vector_id}' to update.")
|
|
209
|
+
|
|
210
|
+
def get(self, vector_id: str) -> Optional[OutputData]:
    """Retrieve a vector by ID.

    Args:
        vector_id (str): ID of the vector to retrieve.

    Returns:
        Optional[OutputData]: The record with ``score`` set to None, or None
        when no document matches or a PyMongoError occurs.
    """
    try:
        doc = self.collection.find_one({"_id": vector_id})
    except PyMongoError as e:
        logger.error(f"Error retrieving document: {e}")
        return None

    if not doc:
        logger.warning(f"Document with ID '{vector_id}' not found.")
        return None

    logger.info(f"Retrieved document with ID '{vector_id}'.")
    return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
|
|
231
|
+
|
|
232
|
+
def list_cols(self) -> List[str]:
    """List all collections in the database.

    Returns:
        List[str]: List of collection names; empty when a PyMongoError occurs.
    """
    try:
        collections = self.db.list_collection_names()
    except PyMongoError as e:
        logger.error(f"Error listing collections: {e}")
        return []
    else:
        logger.info(f"Listing collections in database '{self.db_name}': {collections}")
        return collections
|
|
246
|
+
|
|
247
|
+
def delete_col(self) -> None:
    """Drop the collection; PyMongo errors are logged, not re-raised."""
    try:
        self.collection.drop()
    except PyMongoError as e:
        logger.error(f"Error deleting collection: {e}")
    else:
        logger.info(f"Deleted collection '{self.collection_name}'.")
|
|
254
|
+
|
|
255
|
+
def col_info(self) -> Dict[str, Any]:
    """Get information about the collection.

    Returns:
        Dict[str, Any]: The collection name plus the "count" and "size"
        values from the ``collstats`` command; empty when a PyMongoError occurs.
    """
    try:
        stats = self.db.command("collstats", self.collection_name)
    except PyMongoError as e:
        logger.error(f"Error getting collection info: {e}")
        return {}
    else:
        info = {
            "name": self.collection_name,
            "count": stats.get("count"),
            "size": stats.get("size"),
        }
        logger.info(f"Collection info: {info}")
        return info
|
|
270
|
+
|
|
271
|
+
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
    """List vectors in the collection.

    Args:
        filters (Dict, optional): Payload key/value pairs every document must match.
        limit (int, optional): Number of vectors to return.

    Returns:
        List[OutputData]: Matching records with ``score`` set to None;
        empty when a PyMongoError occurs.
    """
    try:
        query = {}
        if filters:
            # Every filter key is looked up under the "payload" sub-document.
            query = {"$and": [{"payload." + key: value} for key, value in filters.items()]}

        results = []
        for doc in self.collection.find(query).limit(limit):
            results.append(OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")))
        logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
        return results
    except PyMongoError as e:
        logger.error(f"Error listing documents: {e}")
        return []
|
|
299
|
+
|
|
300
|
+
def reset(self):
    """Reset the index by deleting the collection and recreating it."""
    logger.warning(f"Resetting index {self.collection_name}...")
    self.delete_col()
    # create_col takes no arguments (it reads self.collection_name itself);
    # passing the name positionally raised TypeError.
    self.collection = self.create_col()
|
|
305
|
+
|
|
306
|
+
def __del__(self) -> None:
    """Close the database connection when the object is deleted."""
    # ``client`` is missing if __init__ failed before it was assigned.
    if not hasattr(self, "client"):
        return
    self.client.close()
    logger.info("MongoClient connection closed.")
|