agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,547 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Any, List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import mysql.connector
|
|
10
|
+
except ImportError as e:
|
|
11
|
+
raise ImportError(
|
|
12
|
+
"mysql.connector is not available. Please install it using 'pip install mysql-connector-python'"
|
|
13
|
+
) from e
|
|
14
|
+
|
|
15
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
16
|
+
|
|
17
|
+
# Module-level logger; the vector store methods in this module report errors through it.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OutputData(BaseModel):
    """
    One row returned by the MySQL vector store.

    ``score`` carries the distance computed by ``search``; lookups that do not
    rank rows (``get``/``list``) leave it as ``None``.
    """

    id: Optional[str]  # value of the table's primary-key column
    score: Optional[float]  # distance to the query vector, when ranked
    payload: Optional[dict]  # decoded JSON payload column
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MySQLVector(VectorStoreBase):
    """
    Vector store backed by AlibabaCloud MySQL's native ``VECTOR`` column type.

    Each collection is one table with an ``id VARCHAR(255)`` primary key, a
    ``VECTOR`` column named ``embedding`` carrying a vector index, and a JSON
    ``payload`` column. Every public method opens its own short-lived
    connection through ``_get_connection``.
    """

    # Distance names this store accepts. The value is interpolated into the
    # index DDL and into the VEC_DISTANCE_<NAME> function name used by
    # ``search``, so it is whitelisted instead of being passed through as-is.
    _SUPPORTED_DISTANCES = ("euclidean", "cosine")

    def __init__(
        self,
        dbname,
        collection_name,
        embedding_model_dims,
        user,
        password,
        host,
        port,
        distance_function="euclidean",
        m_value=16,
        ssl_disabled=False,
        ssl_ca=None,
        ssl_cert=None,
        ssl_key=None,
        connection_string=None,
        charset="utf8mb4",
        autocommit=True,
    ):
        """
        Initialize the AlibabaCloud MySQL Vector database.

        Args:
            dbname (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            user (str): Database user
            password (str): Database password
            host (str): Database host
            port (int): Database port (3306 when falsy)
            distance_function (str): Distance function for vector index ('euclidean' or 'cosine', case-insensitive)
            m_value (int): M parameter for HNSW index (3-200). Higher values = more accurate but slower
            ssl_disabled (bool): Disable SSL connection
            ssl_ca (str, optional): SSL CA certificate file path
            ssl_cert (str, optional): SSL certificate file path
            ssl_key (str, optional): SSL key file path
            connection_string (str, optional): AlibabaCloud MySQL connection string of the form
                ``mysql://user:password@host:port/database`` (overrides individual connection
                parameters; percent-escapes in user, password and database are decoded)
            charset (str): Character set for the connection
            autocommit (bool): Enable autocommit mode

        Raises:
            ValueError: If ``distance_function`` is not a supported distance.
        """
        normalized_distance = str(distance_function).lower()
        if normalized_distance not in self._SUPPORTED_DISTANCES:
            raise ValueError(
                f"Unsupported distance_function {distance_function!r}; "
                f"expected one of {self._SUPPORTED_DISTANCES}"
            )

        self.collection_name = collection_name
        # Backtick-quoted form of the table name, used in every SQL statement so
        # the collection name is always parsed as a single identifier.
        self._table = self._quote_identifier(collection_name)
        self.embedding_model_dims = embedding_model_dims
        self.distance_function = normalized_distance
        self.m_value = m_value

        # Connection parameters
        if connection_string:
            self.connection_params = self._parse_connection_string(connection_string, dbname)
        else:
            self.connection_params = {
                "user": user,
                "password": password,
                "host": host,
                "port": port or 3306,
                "database": dbname,
            }
        self.connection_params["charset"] = charset
        self.connection_params["autocommit"] = autocommit

        # SSL configuration. mysql.connector expects flat ssl_* keyword
        # arguments and rejects an ``ssl`` dict as an unsupported argument.
        if ssl_disabled:
            self.connection_params["ssl_disabled"] = True
        else:
            for key, value in (("ssl_ca", ssl_ca), ("ssl_cert", ssl_cert), ("ssl_key", ssl_key)):
                if value:
                    self.connection_params[key] = value

        # Test connection and create collection if needed
        if collection_name not in self.list_cols():
            self.create_col()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_connection_string(connection_string, default_dbname):
        """
        Translate ``mysql://user:password@host:port/database`` into connect() kwargs.

        ``urlparse`` leaves percent-escapes in the user name, password and path
        as-is, so they are decoded here; otherwise a password such as ``p@ss``
        (written ``p%40ss`` in the URL) would be sent to the server verbatim.
        The port defaults to 3306 and the database to ``default_dbname``.
        """
        import urllib.parse

        parsed = urllib.parse.urlparse(connection_string)

        def _decode(value):
            return urllib.parse.unquote(value) if value is not None else None

        return {
            "user": _decode(parsed.username),
            "password": _decode(parsed.password),
            "host": parsed.hostname,
            "port": parsed.port or 3306,
            "database": _decode(parsed.path.lstrip("/")) or default_dbname,
        }

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a backtick-quoted MySQL identifier, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _vector_to_text(vector):
        """Serialize a vector into the ``[x1,x2,...]`` text form passed to VEC_FromText."""
        return "[" + ",".join(map(str, vector)) + "]"

    @staticmethod
    def _load_payload(raw):
        """Decode a JSON ``payload`` column value; NULL/empty values become ``{}``."""
        return json.loads(raw) if raw else {}

    @staticmethod
    def _build_filter_conditions(filters):
        """
        Translate equality ``filters`` into SQL conditions on the JSON payload.

        Returns a ``(conditions, params)`` pair of lists. Each key becomes a
        quoted JSON path (``$."key"``) so keys containing characters such as
        ``-`` or spaces still form a valid path. Each value is serialized with
        ``json.dumps`` and compared as JSON (``CAST(%s AS JSON)``), so numbers
        and booleans match their stored JSON values, not only strings.
        """
        conditions = []
        params = []
        for key, value in (filters or {}).items():
            escaped_key = str(key).replace("\\", "\\\\").replace('"', '\\"')
            conditions.append("JSON_EXTRACT(payload, %s) = CAST(%s AS JSON)")
            params.extend([f'$."{escaped_key}"', json.dumps(value)])
        return conditions, params

    @staticmethod
    def _where(conditions):
        """Join SQL conditions into a ``WHERE`` clause, or return ``""`` when there are none."""
        return ("WHERE " + " AND ".join(conditions)) if conditions else ""

    @contextmanager
    def _get_connection(self):
        """
        Context manager to get a database connection.

        A new connection is opened from ``self.connection_params`` and always
        closed on exit. If the ``with`` body raises, the connection is rolled
        back (best effort) and the original exception is re-raised.
        """
        try:
            conn = mysql.connector.connect(**self.connection_params)
        except Exception as e:
            logger.error(f"Database connection error: {e}")
            raise
        try:
            yield conn
        except Exception:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # A broken connection cannot roll back; keep the original error.
                logger.warning(f"Rollback failed: {rollback_error}")
            raise
        finally:
            conn.close()

    @contextmanager
    def _cursor(self):
        """Yield ``(conn, cursor)`` on a fresh connection; the cursor is always closed."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            try:
                yield conn, cursor
            finally:
                cursor.close()

    # --------------------------------------------------------------- public API

    def create_col(self) -> None:
        """
        Create a new collection (table in AlibabaCloud MySQL).

        Uses ``CREATE TABLE IF NOT EXISTS`` and declares the vector index with
        the configured ``m_value`` and ``distance_function``.
        """
        with self._cursor() as (conn, cursor):
            try:
                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self._table} (
                        id VARCHAR(255) PRIMARY KEY,
                        embedding VECTOR({self.embedding_model_dims}) NOT NULL,
                        payload JSON,
                        VECTOR INDEX (embedding) M={self.m_value} DISTANCE={self.distance_function}
                    )
                """)
                conn.commit()
                logger.info(f"Created collection {self.collection_name} with vector index")
            except Exception as e:
                logger.error(f"Error creating collection: {e}")
                raise

    def insert(self, vectors: List[List[float]], payloads=None, ids=None) -> None:
        """
        Insert vectors into the collection as a single transaction.

        Args:
            vectors: List of vectors to insert
            payloads: List of payload dictionaries, one per vector. Missing or
                empty payloads are stored as SQL NULL.
            ids: List of IDs for the vectors, one per vector. Random UUID4
                strings are generated when omitted.

        Raises:
            ValueError: If ``payloads`` or ``ids`` does not have exactly one entry per vector.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if payloads is None:
            payloads = [{}] * len(vectors)
        if ids is None:
            import uuid

            ids = [str(uuid.uuid4()) for _ in vectors]
        if len(payloads) != len(vectors) or len(ids) != len(vectors):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"Length mismatch: {len(vectors)} vectors, {len(payloads)} payloads, {len(ids)} ids"
            )
        if not vectors:
            return

        with self._cursor() as (conn, cursor):
            try:
                # Explicit transaction: with autocommit enabled each INSERT would
                # otherwise commit on its own and a failure would leave a partial batch.
                conn.start_transaction()
                for vector_id, vector, payload in zip(ids, vectors, payloads):
                    cursor.execute(
                        f"INSERT INTO {self._table} (id, embedding, payload) "
                        f"VALUES (%s, VEC_FromText(%s), %s)",
                        (vector_id, self._vector_to_text(vector), json.dumps(payload) if payload else None),
                    )
                conn.commit()
            except Exception as e:
                logger.error(f"Error inserting vectors: {e}")
                raise

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: Optional[int] = 5,
        filters: Optional[dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors using AlibabaCloud MySQL Vector distance functions.

        Args:
            query (str): Query string (unused by this store; kept for the interface)
            vectors (List[float]): Query vector
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Equality filters on payload keys. Defaults to None.

        Returns:
            List[OutputData]: Up to ``limit`` rows ordered by ascending distance;
            ``score`` is the raw distance (smaller means closer).
        """
        conditions, params = self._build_filter_conditions(filters)
        # VEC_DISTANCE_EUCLIDEAN or VEC_DISTANCE_COSINE, matching the index metric.
        distance_func = f"VEC_DISTANCE_{self.distance_function.upper()}"
        sql = (
            f"SELECT id, {distance_func}(embedding, VEC_FromText(%s)) AS distance, payload "
            f"FROM {self._table} {self._where(conditions)} "
            f"ORDER BY distance LIMIT %s"
        )
        query_params = [self._vector_to_text(vectors), *params, limit]

        logger.debug(f"SQL query: {sql}")
        logger.debug(f"Query params: {query_params}")

        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(sql, query_params)
                return [
                    OutputData(id=str(r[0]), score=float(r[1]), payload=self._load_payload(r[2]))
                    for r in cursor.fetchall()
                ]
            except Exception as e:
                logger.error(f"Error searching vectors: {e}")
                raise

    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        with self._cursor() as (conn, cursor):
            try:
                cursor.execute(f"DELETE FROM {self._table} WHERE id = %s", (vector_id,))
                conn.commit()
            except Exception as e:
                logger.error(f"Error deleting vector: {e}")
                raise

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[dict] = None,
    ) -> None:
        """
        Update a vector and/or its payload in a single statement.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector; skipped when None or empty.
            payload (Dict, optional): Updated payload; skipped when None or empty.
        """
        assignments = []
        params = []
        if vector is not None and len(vector) > 0:
            assignments.append("embedding = VEC_FromText(%s)")
            params.append(self._vector_to_text(vector))
        if payload:
            assignments.append("payload = %s")
            params.append(json.dumps(payload))
        if not assignments:
            return
        params.append(vector_id)

        with self._cursor() as (conn, cursor):
            try:
                # One UPDATE keeps vector and payload changes atomic even under autocommit.
                cursor.execute(
                    f"UPDATE {self._table} SET {', '.join(assignments)} WHERE id = %s",
                    params,
                )
                conn.commit()
            except Exception as e:
                logger.error(f"Error updating vector: {e}")
                raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved row (``score`` is None) or None if not found.
        """
        with self._cursor() as (_conn, cursor):
            try:
                # The embedding itself is never returned, so it is not fetched.
                cursor.execute(f"SELECT id, payload FROM {self._table} WHERE id = %s", (vector_id,))
                result = cursor.fetchone()
                if not result:
                    return None
                return OutputData(id=str(result[0]), score=None, payload=self._load_payload(result[1]))
            except Exception as e:
                logger.error(f"Error retrieving vector: {e}")
                raise

    def list_cols(self) -> List[str]:
        """
        List all collections (tables) in the connected database.

        Returns:
            List[str]: List of collection names.
        """
        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute("SHOW TABLES")
                return [row[0] for row in cursor.fetchall()]
            except Exception as e:
                logger.error(f"Error listing collections: {e}")
                raise

    def delete_col(self) -> None:
        """Delete the collection (table) if it exists."""
        with self._cursor() as (conn, cursor):
            try:
                cursor.execute(f"DROP TABLE IF EXISTS {self._table}")
                conn.commit()
                logger.info(f"Deleted collection {self.collection_name}")
            except Exception as e:
                logger.error(f"Error deleting collection: {e}")
                raise

    def col_info(self) -> dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: ``name``, row ``count`` and ``size`` (data + index
            size from information_schema, formatted as ``"<n> MB"``).
        """
        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(f"SELECT COUNT(*) FROM {self._table}")
                row_count = cursor.fetchone()[0]

                cursor.execute(
                    """
                    SELECT
                        table_name,
                        ROUND(((data_length + index_length) / 1024 / 1024), 2) AS total_size_mb
                    FROM information_schema.tables
                    WHERE table_schema = DATABASE() AND table_name = %s
                    """,
                    (self.collection_name,),
                )
                result = cursor.fetchone()
                size_mb = result[1] if result else 0

                return {
                    "name": self.collection_name,
                    "count": row_count,
                    "size": f"{size_mb} MB",
                }
            except Exception as e:
                logger.error(f"Error getting collection info: {e}")
                raise

    def list(
        self,
        filters: Optional[dict] = None,
        limit: Optional[int] = 100
    ) -> List[OutputData]:
        """
        List vectors in the collection.

        NOTE(review): this returns a flat list of OutputData; confirm it matches
        what the memory layer expects from ``list`` (some vector stores return a
        nested ``[[...]]`` list).

        Args:
            filters (Dict, optional): Equality filters on payload keys.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of rows (``score`` is None).
        """
        conditions, params = self._build_filter_conditions(filters)

        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(
                    f"SELECT id, payload FROM {self._table} {self._where(conditions)} LIMIT %s",
                    (*params, limit),
                )
                return [
                    OutputData(id=str(r[0]), score=None, payload=self._load_payload(r[1]))
                    for r in cursor.fetchall()
                ]
            except Exception as e:
                logger.error(f"Error listing vectors: {e}")
                raise

    def list_paginated(
        self,
        filters: Optional[dict] = None,
        limit: int = 100,
        next_token: Optional[str] = None
    ) -> dict:
        """
        List vectors with keyset pagination ordered by ``id``.

        Args:
            filters (Dict, optional): Equality filters on payload keys.
            limit (int): Maximum number of vectors to return (max 1000). Defaults to 100.
            next_token (str, optional): Token for pagination, pass the token from previous
                                      response to get next page. This is the last id from
                                      the previous page.

        Returns:
            dict: {
                "memories": list of OutputData objects,
                "next_token": token for next page (None if no more pages),
                "has_more": boolean indicating if there are more pages
            }
        """
        limit = min(limit, 1000)

        conditions, params = self._build_filter_conditions(filters)
        if next_token:
            # Keyset cursor: resume strictly after the last id of the previous page.
            conditions.append("id > %s")
            params.append(next_token)

        with self._cursor() as (_conn, cursor):
            try:
                # Fetch one extra row to learn whether another page exists.
                cursor.execute(
                    f"SELECT id, payload FROM {self._table} {self._where(conditions)} "
                    f"ORDER BY id LIMIT %s",
                    (*params, limit + 1),
                )
                results = cursor.fetchall()

                has_more = len(results) > limit
                memories = [
                    OutputData(id=str(r[0]), score=None, payload=self._load_payload(r[1]))
                    for r in results[:limit]
                ]
                new_next_token = memories[-1].id if memories and has_more else None

                return {
                    "memories": memories,
                    "next_token": new_next_token,
                    "has_more": has_more,
                }
            except Exception as e:
                logger.error(f"Error listing vectors with pagination: {e}")
                raise

    def reset(self) -> None:
        """Reset the collection by deleting and recreating it."""
        logger.warning(f"Resetting collection {self.collection_name}...")
        self.delete_col()
        self.create_col()