agentrun-mem0ai 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,250 @@
1
+ import logging
2
+ from typing import Dict, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from agentrun_mem0.configs.vector_stores.milvus import MetricType
7
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
8
+
9
+ try:
10
+ import pymilvus # noqa: F401
11
+ except ImportError:
12
+ raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
13
+
14
+ from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
15
+
16
# Module-scoped logger for the Milvus vector store; this module attaches no
# handlers and sets no levels of its own.
logger = logging.getLogger(name=__name__)
17
+
18
+
19
class OutputData(BaseModel):
    """One record handed back by the Milvus store (search hit, get, or list entry)."""

    id: Optional[str]  # primary key of the stored memory
    score: Optional[float]  # distance reported by a search; None for get/list results
    payload: Optional[dict]  # contents of the record's JSON "metadata" field
23
+
24
+
25
class MilvusDB(VectorStoreBase):
    """Vector store backed by a single Milvus / Zilliz collection.

    Rows use the schema created in ``create_col``:
    ``id`` (VARCHAR primary key), ``vectors`` (FLOAT_VECTOR) and
    ``metadata`` (JSON payload).
    """

    def __init__(
        self,
        url: str,
        token: str,
        collection_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
        db_name: str,
    ) -> None:
        """Connect to Milvus and make sure the target collection exists.

        This constructor has no defaults of its own; defaults such as the
        collection name or vector size are presumably supplied by the caller's
        config object.

        Args:
            url (str): Full URI of the Milvus/Zilliz server.
            token (str): Token/API key for Zilliz; may be None for a local setup.
            collection_name (str): Name of the collection to use.
            embedding_model_dims (int): Dimensionality of the stored vectors.
            metric_type (MetricType): Metric used by the vector index.
            db_name (str): Name of the Milvus database.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type
        self.client = MilvusClient(uri=url, token=token, db_name=db_name)
        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type,
        )

    def create_col(
        self,
        collection_name: str,
        vector_size: int,
        metric_type: MetricType = MetricType.COSINE,
    ) -> None:
        """Create the collection with an AUTOINDEX vector index unless it already exists.

        Args:
            collection_name (str): Name of the collection.
            vector_size (int): Dimensionality of the ``vectors`` field.
            metric_type (MetricType, optional): Metric type for similarity search.
                Defaults to MetricType.COSINE.
        """
        if self.client.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            return

        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]
        schema = CollectionSchema(fields, enable_dynamic_field=True)
        index = self.client.prepare_index_params(
            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
        )
        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads=None, **kwargs):
        """Insert records into the collection in a single batch.

        Args:
            ids (List[str]): Primary keys, one per vector.
            vectors (List[List[float]]): Embeddings to store.
            payloads (List[Dict], optional): Metadata per vector; an empty dict is
                stored for every vector when omitted.
            **kwargs: Forwarded unchanged to ``MilvusClient.insert``.

        Raises:
            ValueError: If ids, vectors and payloads differ in length (zip would
                otherwise silently drop the surplus records).
        """
        ids = list(ids)
        vectors = list(vectors)
        payloads = [{} for _ in vectors] if payloads is None else list(payloads)
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                "ids, vectors and payloads must have the same length "
                f"(got {len(ids)}, {len(vectors)}, {len(payloads)})"
            )

        # One batched insert is faster and keeps the batch consistent.
        data = [
            {"id": idx, "vectors": embedding, "metadata": metadata}
            for idx, embedding, metadata in zip(ids, vectors, payloads)
        ]
        self.client.insert(collection_name=self.collection_name, data=data, **kwargs)

    @staticmethod
    def _quote(text) -> str:
        """Render *text* as a double-quoted Milvus string literal with ``\\`` and ``"`` escaped."""
        escaped = str(text).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _create_filter(self, filters: dict):
        """Build a Milvus boolean expression matching every key/value on ``metadata``.

        Keys and string values are escaped, so quotes or backslashes inside them
        (e.g. in a user id) cannot break or alter the expression.

        Args:
            filters (dict): Equality filters, e.g. {user_id, agent_id, run_id}.

        Returns:
            str: Conditions joined with ``and``.
        """
        operands = []
        for key, value in filters.items():
            literal = self._quote(value) if isinstance(value, str) else value
            operands.append(f"(metadata[{self._quote(key)}] == {literal})")
        return " and ".join(operands)

    def _parse_output(self, data: list):
        """Convert the hits of one search query into OutputData objects.

        Args:
            data (list): Hit dicts carrying ``id``, ``distance`` and ``entity``.

        Returns:
            List[OutputData]: One entry per hit, with the distance as ``score``.
        """
        return [
            OutputData(
                id=hit.get("id"),
                score=hit.get("distance"),
                payload=hit.get("entity", {}).get("metadata"),
            )
            for hit in data
        ]

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """Search for the vectors nearest to *vectors*.

        Args:
            query (str): Original query text (not used by Milvus itself).
            vectors (List[float]): Query embedding.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Equality filters on metadata. Defaults to None.

        Returns:
            List[OutputData]: Hits ordered as returned by Milvus.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[vectors],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        # A single query vector was sent, so only the first result list is relevant.
        return self._parse_output(data=hits[0])

    def delete(self, vector_id):
        """Delete a record by ID.

        Args:
            vector_id (str): ID of the record to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)

    def update(self, vector_id=None, vector=None, payload=None):
        """Replace a record's vector and payload via upsert.

        The whole row for ``vector_id`` is rewritten, so both ``vector`` and
        ``payload`` should be supplied. A None vector is presumably rejected by
        Milvus for the FLOAT_VECTOR field; verify before relying on partial updates.

        Args:
            vector_id (str): ID of the record to update.
            vector (List[float], optional): New embedding.
            payload (Dict, optional): New metadata.
        """
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
        """Retrieve a record by ID.

        Args:
            vector_id (str): ID of the record to retrieve.

        Returns:
            Optional[OutputData]: The record, or None when no row has that ID.
        """
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        if not result:
            return None
        record = result[0]
        return OutputData(id=record.get("id", None), score=None, payload=record.get("metadata", None))

    def list_cols(self):
        """List all collections.

        Returns:
            List[str]: Collection names reported by the client.
        """
        return self.client.list_collections()

    def delete_col(self):
        """Drop this store's collection."""
        return self.client.drop_collection(collection_name=self.collection_name)

    def col_info(self):
        """Get statistics about the collection.

        Returns:
            Dict[str, Any]: Result of ``MilvusClient.get_collection_stats``.
        """
        return self.client.get_collection_stats(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """List records in the collection.

        Args:
            filters (Dict, optional): Equality filters on metadata.
            limit (int, optional): Maximum number of records. Defaults to 100.

        Returns:
            List[List[OutputData]]: The records wrapped in an outer single-element
            list (kept as-is; callers presumably index ``[0]``).
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
        memories = [OutputData(id=row.get("id"), score=None, payload=row.get("metadata")) for row in result]
        return [memories]

    def reset(self):
        """Drop the collection and recreate it with the same settings."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)
@@ -0,0 +1,310 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+ try:
7
+ from pymongo import MongoClient
8
+ from pymongo.errors import PyMongoError
9
+ from pymongo.operations import SearchIndexModel
10
+ except ImportError:
11
+ raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
12
+
13
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
14
+
15
# NOTE(review): basicConfig() at import time configures the *root* logger of the
# whole host process (it is a no-op if the root logger already has handlers).
# Libraries usually leave this to the application; kept for backward compatibility.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger for the MongoDB vector store.
logger = logging.getLogger(name=__name__)
17
+
18
+
19
class OutputData(BaseModel):
    """One record handed back by the MongoDB store (search hit, get, or list entry)."""

    id: Optional[str]  # string form of the document's _id
    score: Optional[float]  # vectorSearchScore for search hits; None for get/list results
    payload: Optional[Dict]  # the document's "payload" sub-document
23
+
24
+
25
class MongoDB(VectorStoreBase):
    """Vector store backed by a MongoDB collection plus a search index.

    Documents are stored as ``{"_id": <id>, "embedding": [...], "payload": {...}}``.
    """

    VECTOR_TYPE = "knnVector"
    SIMILARITY_METRIC = "cosine"

    def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
        """
        Initialize the MongoDB vector store with vector search capabilities.

        Args:
            db_name (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            mongo_uri (str): MongoDB connection URI
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.db_name = db_name

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        # create_col() returns None when a PyMongoError occurred (already logged).
        self.collection = self.create_col()

    def create_col(self):
        """Ensure the collection and its vector search index exist.

        Returns:
            The pymongo collection, or None if a PyMongoError occurred (logged).
        """
        # Set before any database call so that search() can still reference the
        # index name even if collection or index creation below fails.
        self.index_name = f"{self.collection_name}_vector_index"
        try:
            database = self.client[self.db_name]
            collection_names = database.list_collection_names()
            if self.collection_name not in collection_names:
                logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
                collection = database[self.collection_name]
                # Insert and remove a placeholder document to create the collection
                collection.insert_one({"_id": 0, "placeholder": True})
                collection.delete_one({"_id": 0})
                logger.info(f"Collection '{self.collection_name}' created successfully.")
            else:
                collection = database[self.collection_name]

            found_indexes = list(collection.list_search_indexes(name=self.index_name))
            if found_indexes:
                logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
            else:
                search_index_model = SearchIndexModel(
                    name=self.index_name,
                    definition={
                        "mappings": {
                            "dynamic": False,
                            "fields": {
                                "embedding": {
                                    "type": self.VECTOR_TYPE,
                                    "dimensions": self.embedding_model_dims,
                                    "similarity": self.SIMILARITY_METRIC,
                                }
                            },
                        }
                    },
                )
                collection.create_search_index(search_index_model)
                logger.info(
                    f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
                )
            return collection
        except PyMongoError as e:
            logger.error(f"Error creating collection and search index: {e}")
            return None

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): IDs for the vectors. When missing (or None for
                an entry), MongoDB generates an ObjectId for that document.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")
        # insert_many() raises TypeError on an empty list, so there is nothing to do.
        if not vectors:
            return

        payloads = payloads or [{} for _ in vectors]
        ids = ids or [None] * len(vectors)
        data = []
        for vector, payload, _id in zip(vectors, payloads, ids):
            # Omit _id when absent: an explicit null _id would make every id-less
            # insert after the first fail with a duplicate-key error.
            document = {"_id": _id} if _id is not None else {}
            document["embedding"] = vector
            document["payload"] = payload
            data.append(document)
        try:
            self.collection.insert_many(data)
            logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error inserting data: {e}")

    def search(self, query: str, vectors: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors using the vector search index.

        Args:
            query (str): Query string (used only in log messages).
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Equality filters on payload fields.

        Returns:
            List[OutputData]: Up to ``limit`` results; an empty list on error.
        """
        found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
        if not found_indexes:
            logger.error(f"Index '{self.index_name}' does not exist.")
            return []

        # $vectorSearch is approximate: numCandidates must be >= limit, and a pool
        # larger than limit is needed for reasonable recall. 10000 is the documented
        # Atlas maximum for numCandidates.
        num_candidates = max(limit, min(limit * 10, 10000))

        match_conditions = [{"payload." + key: value} for key, value in (filters or {}).items()]
        # The index maps only "embedding", so filters are applied after $vectorSearch.
        # When filtering, take the whole candidate pool first and trim to `limit` after
        # $match; otherwise non-matching neighbours would crowd out matching documents.
        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.index_name,
                    "limit": num_candidates if match_conditions else limit,
                    "numCandidates": num_candidates,
                    "queryVector": vectors,
                    "path": "embedding",
                }
            }
        ]
        if match_conditions:
            pipeline.append({"$match": {"$and": match_conditions}})
            pipeline.append({"$limit": limit})
        pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}})
        pipeline.append({"$project": {"embedding": 0}})

        try:
            results = list(self.collection.aggregate(pipeline))
            logger.info(f"Vector search completed. Found {len(results)} documents.")
        except Exception as e:
            logger.error(f"Error during vector search for query {query}: {e}")
            return []

        return [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]

    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        try:
            result = self.collection.delete_one({"_id": vector_id})
            if result.deleted_count > 0:
                logger.info(f"Deleted document with ID '{vector_id}'.")
            else:
                logger.warning(f"No document found with ID '{vector_id}' to delete.")
        except PyMongoError as e:
            logger.error(f"Error deleting document: {e}")

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """
        Update a vector and/or its payload; fields passed as None are left unchanged.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload (replaces the whole payload).
        """
        update_fields = {}
        if vector is not None:
            update_fields["embedding"] = vector
        if payload is not None:
            update_fields["payload"] = payload

        if update_fields:
            try:
                result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
                if result.matched_count > 0:
                    logger.info(f"Updated document with ID '{vector_id}'.")
                else:
                    logger.warning(f"No document found with ID '{vector_id}' to update.")
            except PyMongoError as e:
                logger.error(f"Error updating document: {e}")

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector, or None if not found or on error.
        """
        try:
            doc = self.collection.find_one({"_id": vector_id})
            if doc:
                logger.info(f"Retrieved document with ID '{vector_id}'.")
                return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
            else:
                logger.warning(f"Document with ID '{vector_id}' not found.")
                return None
        except PyMongoError as e:
            logger.error(f"Error retrieving document: {e}")
            return None

    def list_cols(self) -> List[str]:
        """
        List all collections in the database.

        Returns:
            List[str]: List of collection names (empty on error).
        """
        try:
            collections = self.db.list_collection_names()
            logger.info(f"Listing collections in database '{self.db_name}': {collections}")
            return collections
        except PyMongoError as e:
            logger.error(f"Error listing collections: {e}")
            return []

    def delete_col(self) -> None:
        """Drop the collection (errors are logged, not raised)."""
        try:
            self.collection.drop()
            logger.info(f"Deleted collection '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error deleting collection: {e}")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Name, document count and size; empty dict on error.
        """
        try:
            stats = self.db.command("collstats", self.collection_name)
            info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
            logger.info(f"Collection info: {info}")
            return info
        except PyMongoError as e:
            logger.error(f"Error getting collection info: {e}")
            return {}

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Equality filters on payload fields.
            limit (int, optional): Number of vectors to return.

        Returns:
            List[OutputData]: List of vectors (empty on error).
        """
        try:
            query = {}
            filter_conditions = [{"payload." + key: value} for key, value in (filters or {}).items()]
            if filter_conditions:
                query = {"$and": filter_conditions}

            cursor = self.collection.find(query).limit(limit)
            results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
            logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
            return results
        except PyMongoError as e:
            logger.error(f"Error listing documents: {e}")
            return []

    def reset(self):
        """Drop the collection and recreate it together with its search index."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        # create_col() takes no arguments; it reads the name from self.collection_name.
        self.collection = self.create_col()

    def __del__(self) -> None:
        """Close the database connection when the object is deleted."""
        if hasattr(self, "client"):
            self.client.close()
            logger.info("MongoClient connection closed.")
+ logger.info("MongoClient connection closed.")