agentrun_mem0ai-0.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/vector_stores/azure_mysql.py
@@ -0,0 +1,463 @@
+ import json
+ import logging
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel
+
+ try:
+     import pymysql
+     from pymysql.cursors import DictCursor
+     from dbutils.pooled_db import PooledDB
+ except ImportError:
+     raise ImportError(
+         "Azure MySQL vector store requires PyMySQL and DBUtils. "
+         "Please install them using 'pip install pymysql dbutils'"
+     )
+
+ try:
+     from azure.identity import DefaultAzureCredential
+     AZURE_IDENTITY_AVAILABLE = True
+ except ImportError:
+     AZURE_IDENTITY_AVAILABLE = False
+
+ from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+ logger = logging.getLogger(__name__)
+
+
+ class OutputData(BaseModel):
+     id: Optional[str]
+     score: Optional[float]
+     payload: Optional[dict]
+
+
+ class AzureMySQL(VectorStoreBase):
+     def __init__(
+         self,
+         host: str,
+         port: int,
+         user: str,
+         password: Optional[str],
+         database: str,
+         collection_name: str,
+         embedding_model_dims: int,
+         use_azure_credential: bool = False,
+         ssl_ca: Optional[str] = None,
+         ssl_disabled: bool = False,
+         minconn: int = 1,
+         maxconn: int = 5,
+         connection_pool: Optional[Any] = None,
+     ):
+         """
+         Initialize the Azure MySQL vector store.
+
+         Args:
+             host (str): MySQL server host
+             port (int): MySQL server port
+             user (str): Database user
+             password (str, optional): Database password (not required if using Azure credential)
+             database (str): Database name
+             collection_name (str): Collection/table name
+             embedding_model_dims (int): Dimension of the embedding vector
+             use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
+             ssl_ca (str, optional): Path to SSL CA certificate
+             ssl_disabled (bool): Disable SSL connection
+             minconn (int): Minimum number of connections in the pool
+             maxconn (int): Maximum number of connections in the pool
+             connection_pool (Any, optional): Pre-configured connection pool
+         """
+         self.host = host
+         self.port = port
+         self.user = user
+         self.password = password
+         self.database = database
+         self.collection_name = collection_name
+         self.embedding_model_dims = embedding_model_dims
+         self.use_azure_credential = use_azure_credential
+         self.ssl_ca = ssl_ca
+         self.ssl_disabled = ssl_disabled
+         self.connection_pool = connection_pool
+
+         # Handle Azure authentication
+         if use_azure_credential:
+             if not AZURE_IDENTITY_AVAILABLE:
+                 raise ImportError(
+                     "Azure Identity is required for Azure credential authentication. "
+                     "Please install it using 'pip install azure-identity'"
+                 )
+             self._setup_azure_auth()
+
+         # Setup connection pool
+         if self.connection_pool is None:
+             self._setup_connection_pool(minconn, maxconn)
+
+         # Create collection if it doesn't exist
+         collections = self.list_cols()
+         if collection_name not in collections:
+             self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")
+
+     def _setup_azure_auth(self):
+         """Setup Azure authentication using DefaultAzureCredential."""
+         try:
+             credential = DefaultAzureCredential()
+             # Get access token for Azure Database for MySQL
+             token = credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
+             # Use token as password
+             self.password = token.token
+             logger.info("Successfully authenticated using Azure DefaultAzureCredential")
+         except Exception as e:
+             logger.error(f"Failed to authenticate with Azure: {e}")
+             raise
+
+     def _setup_connection_pool(self, minconn: int, maxconn: int):
+         """Setup MySQL connection pool."""
+         connect_kwargs = {
+             "host": self.host,
+             "port": self.port,
+             "user": self.user,
+             "password": self.password,
+             "database": self.database,
+             "charset": "utf8mb4",
+             "cursorclass": DictCursor,
+             "autocommit": False,
+         }
+
+         # SSL configuration
+         if not self.ssl_disabled:
+             ssl_config = {"ssl_verify_cert": True}
+             if self.ssl_ca:
+                 ssl_config["ssl_ca"] = self.ssl_ca
+             connect_kwargs["ssl"] = ssl_config
+
+         try:
+             self.connection_pool = PooledDB(
+                 creator=pymysql,
+                 mincached=minconn,
+                 maxcached=maxconn,
+                 maxconnections=maxconn,
+                 blocking=True,
+                 **connect_kwargs
+             )
+             logger.info("Successfully created MySQL connection pool")
+         except Exception as e:
+             logger.error(f"Failed to create connection pool: {e}")
+             raise
+
+     @contextmanager
+     def _get_cursor(self, commit: bool = False):
+         """
+         Context manager to get a cursor from the connection pool.
+         Auto-commits or rolls back based on exception.
+         """
+         conn = self.connection_pool.connection()
+         cur = conn.cursor()
+         try:
+             yield cur
+             if commit:
+                 conn.commit()
+         except Exception as exc:
+             conn.rollback()
+             logger.error(f"Database error: {exc}", exc_info=True)
+             raise
+         finally:
+             cur.close()
+             conn.close()
+
+     def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
+         """
+         Create a new collection (table in MySQL).
+         Enables vector extension and creates appropriate indexes.
+
+         Args:
+             name (str, optional): Collection name (uses self.collection_name if not provided)
+             vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not provided)
+             distance (str): Distance metric (cosine, euclidean, dot_product)
+         """
+         table_name = name or self.collection_name
+         dims = vector_size or self.embedding_model_dims
+
+         with self._get_cursor(commit=True) as cur:
+             # Create table with vector column
+             cur.execute(f"""
+                 CREATE TABLE IF NOT EXISTS `{table_name}` (
+                     id VARCHAR(255) PRIMARY KEY,
+                     vector JSON,
+                     payload JSON,
+                     INDEX idx_payload_keys ((CAST(payload AS CHAR(255)) ARRAY))
+                 )
+             """)
+             logger.info(f"Created collection '{table_name}' with vector dimension {dims}")
+
+     def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
+         """
+         Insert vectors into the collection.
+
+         Args:
+             vectors (List[List[float]]): List of vectors to insert
+             payloads (List[Dict], optional): List of payloads corresponding to vectors
+             ids (List[str], optional): List of IDs corresponding to vectors
+         """
+         logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+
+         if payloads is None:
+             payloads = [{}] * len(vectors)
+         if ids is None:
+             import uuid
+             ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
+
+         data = []
+         for vector, payload, vec_id in zip(vectors, payloads, ids):
+             data.append((vec_id, json.dumps(vector), json.dumps(payload)))
+
+         with self._get_cursor(commit=True) as cur:
+             cur.executemany(
+                 f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
+                 f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
+                 data
+             )
+
+     def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
+         """Generate SQL for cosine distance calculation."""
+         # For MySQL, we need to calculate cosine similarity manually
+         # This is a simplified version - in production, you'd use stored procedures or UDFs
+         return """
+         1 - (
+             (SELECT SUM(a.val * b.val) /
+             (SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
+             FROM (
+                 SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
+                 FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
+                 WHERE idx < JSON_LENGTH(vector)
+             ) a,
+             (
+                 SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
+                 FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
+                 WHERE idx < JSON_LENGTH(%s)
+             ) b
+             WHERE a.idx = b.idx
+         )
+         """
+
+     def search(
+         self,
+         query: str,
+         vectors: List[float],
+         limit: int = 5,
+         filters: Optional[Dict] = None,
+     ) -> List[OutputData]:
+         """
+         Search for similar vectors using cosine similarity.
+
+         Args:
+             query (str): Query string (not used in vector search)
+             vectors (List[float]): Query vector
+             limit (int): Number of results to return
+             filters (Dict, optional): Filters to apply to the search
+
+         Returns:
+             List[OutputData]: Search results
+         """
+         filter_conditions = []
+         filter_params = []
+
+         if filters:
+             for k, v in filters.items():
+                 filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
+                 filter_params.extend([f"$.{k}", json.dumps(v)])
+
+         filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+         # For simplicity, we'll compute cosine similarity in Python
+         # In production, you'd want to use MySQL stored procedures or UDFs
+         with self._get_cursor() as cur:
+             query_sql = f"""
+                 SELECT id, vector, payload
+                 FROM `{self.collection_name}`
+                 {filter_clause}
+             """
+             cur.execute(query_sql, filter_params)
+             results = cur.fetchall()
+
+         # Calculate cosine similarity in Python
+         import numpy as np
+         query_vec = np.array(vectors)
+         scored_results = []
+
+         for row in results:
+             vec = np.array(json.loads(row['vector']))
+             # Cosine similarity
+             similarity = np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec))
+             distance = 1 - similarity
+             scored_results.append((row['id'], distance, row['payload']))
+
+         # Sort by distance and limit
+         scored_results.sort(key=lambda x: x[1])
+         scored_results = scored_results[:limit]
+
+         return [
+             OutputData(id=r[0], score=float(r[1]), payload=json.loads(r[2]) if isinstance(r[2], str) else r[2])
+             for r in scored_results
+         ]
+
+     def delete(self, vector_id: str):
+         """
+         Delete a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to delete
+         """
+         with self._get_cursor(commit=True) as cur:
+             cur.execute(f"DELETE FROM `{self.collection_name}` WHERE id = %s", (vector_id,))
+
+     def update(
+         self,
+         vector_id: str,
+         vector: Optional[List[float]] = None,
+         payload: Optional[Dict] = None,
+     ):
+         """
+         Update a vector and its payload.
+
+         Args:
+             vector_id (str): ID of the vector to update
+             vector (List[float], optional): Updated vector
+             payload (Dict, optional): Updated payload
+         """
+         with self._get_cursor(commit=True) as cur:
+             if vector is not None:
+                 cur.execute(
+                     f"UPDATE `{self.collection_name}` SET vector = %s WHERE id = %s",
+                     (json.dumps(vector), vector_id),
+                 )
+             if payload is not None:
+                 cur.execute(
+                     f"UPDATE `{self.collection_name}` SET payload = %s WHERE id = %s",
+                     (json.dumps(payload), vector_id),
+                 )
+
+     def get(self, vector_id: str) -> Optional[OutputData]:
+         """
+         Retrieve a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to retrieve
+
+         Returns:
+             OutputData: Retrieved vector or None if not found
+         """
+         with self._get_cursor() as cur:
+             cur.execute(
+                 f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
+                 (vector_id,),
+             )
+             result = cur.fetchone()
+             if not result:
+                 return None
+             return OutputData(
+                 id=result['id'],
+                 score=None,
+                 payload=json.loads(result['payload']) if isinstance(result['payload'], str) else result['payload']
+             )
+
+     def list_cols(self) -> List[str]:
+         """
+         List all collections (tables).
+
+         Returns:
+             List[str]: List of collection names
+         """
+         with self._get_cursor() as cur:
+             cur.execute("SHOW TABLES")
+             return [row[f"Tables_in_{self.database}"] for row in cur.fetchall()]
+
+     def delete_col(self):
+         """Delete the collection (table)."""
+         with self._get_cursor(commit=True) as cur:
+             cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
+             logger.info(f"Deleted collection '{self.collection_name}'")
+
+     def col_info(self) -> Dict[str, Any]:
+         """
+         Get information about the collection.
+
+         Returns:
+             Dict[str, Any]: Collection information
+         """
+         with self._get_cursor() as cur:
+             cur.execute("""
+                 SELECT
+                     TABLE_NAME as name,
+                     TABLE_ROWS as count,
+                     ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
+                 FROM information_schema.TABLES
+                 WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
+             """, (self.database, self.collection_name))
+             result = cur.fetchone()
+
+         if result:
+             return {
+                 "name": result['name'],
+                 "count": result['count'],
+                 "size": f"{result['size_mb']} MB"
+             }
+         return {}
+
+     def list(
+         self,
+         filters: Optional[Dict] = None,
+         limit: int = 100
+     ) -> List[List[OutputData]]:
+         """
+         List all vectors in the collection.
+
+         Args:
+             filters (Dict, optional): Filters to apply
+             limit (int): Number of vectors to return
+
+         Returns:
+             List[List[OutputData]]: List of vectors
+         """
+         filter_conditions = []
+         filter_params = []
+
+         if filters:
+             for k, v in filters.items():
+                 filter_conditions.append("JSON_EXTRACT(payload, %s) = %s")
+                 filter_params.extend([f"$.{k}", json.dumps(v)])
+
+         filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+         with self._get_cursor() as cur:
+             cur.execute(
+                 f"""
+                 SELECT id, vector, payload
+                 FROM `{self.collection_name}`
+                 {filter_clause}
+                 LIMIT %s
+                 """,
+                 (*filter_params, limit)
+             )
+             results = cur.fetchall()
+
+         return [[
+             OutputData(
+                 id=r['id'],
+                 score=None,
+                 payload=json.loads(r['payload']) if isinstance(r['payload'], str) else r['payload']
+             ) for r in results
+         ]]
+
+     def reset(self):
+         """Reset the collection by deleting and recreating it."""
+         logger.warning(f"Resetting collection {self.collection_name}...")
+         self.delete_col()
+         self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)
+
+     def __del__(self):
+         """Close the connection pool when the object is deleted."""
+         try:
+             if hasattr(self, 'connection_pool') and self.connection_pool:
+                 self.connection_pool.close()
+         except Exception:
+             pass
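
For orientation, a minimal usage sketch of the AzureMySQL store added above, with the call signatures taken from the diff. The host, user, database name, and the 768-dimension embedding size are placeholder assumptions for illustration, not values shipped with the package.

# Hypothetical usage sketch of the AzureMySQL vector store shown in the diff above.
# Connection details and the embedding dimension are placeholders, not package defaults.
from agentrun_mem0.vector_stores.azure_mysql import AzureMySQL

store = AzureMySQL(
    host="example.mysql.database.azure.com",  # placeholder Azure Database for MySQL host
    port=3306,
    user="mem0_user",                         # placeholder user
    password=None,                            # replaced by an AAD token when use_azure_credential=True
    database="mem0",
    collection_name="memories",
    embedding_model_dims=768,                 # must match the embedder's output size
    use_azure_credential=True,                # fetches a token via DefaultAzureCredential
)

# Insert two vectors with payloads; IDs are generated with uuid4 if omitted.
store.insert(
    vectors=[[0.1] * 768, [0.2] * 768],
    payloads=[
        {"user_id": "alice", "data": "likes tea"},
        {"user_id": "alice", "data": "works remotely"},
    ],
)

# Search: matching rows are fetched via the JSON_EXTRACT filter, then ranked by
# cosine distance in Python (see the search() implementation above).
results = store.search(query="", vectors=[0.1] * 768, limit=3, filters={"user_id": "alice"})
for item in results:
    print(item.id, item.score, item.payload)

In practice the store would more likely be constructed through the package's configuration and factory layer (see agentrun_mem0/utils/factory.py and agentrun_mem0/configs/vector_stores/azure_mysql.py in the file list) rather than instantiated directly; the direct construction above is only meant to make the constructor and method signatures concrete.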