agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import pymysql
|
|
10
|
+
from pymysql.cursors import DictCursor
|
|
11
|
+
from dbutils.pooled_db import PooledDB
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"Azure MySQL vector store requires PyMySQL and DBUtils. "
|
|
15
|
+
"Please install them using 'pip install pymysql dbutils'"
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from azure.identity import DefaultAzureCredential
|
|
20
|
+
AZURE_IDENTITY_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
AZURE_IDENTITY_AVAILABLE = False
|
|
23
|
+
|
|
24
|
+
from agentrun_mem0.vector_stores.base import VectorStoreBase
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class OutputData(BaseModel):
    """One record handed back by the Azure MySQL vector store.

    Search results carry a cosine distance in ``score`` (smaller means more
    similar); direct lookups and listings leave ``score`` as ``None``.
    """

    # Primary key of the stored row.
    id: Optional[str]
    # Cosine distance for search hits, None for unscored results.
    score: Optional[float]
    # Decoded JSON metadata stored next to the vector.
    payload: Optional[dict]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class AzureMySQL(VectorStoreBase):
    """Vector store that keeps embeddings in a MySQL table (e.g. Azure Database for MySQL).

    Each collection is one table with an ``id`` primary key, the embedding stored
    as a JSON array in ``vector`` and arbitrary metadata stored as a JSON object in
    ``payload``. Similarity search fetches the (optionally filtered) rows and ranks
    them by cosine distance in Python, so it scans every matching row.
    """

    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: Optional[str],
        database: str,
        collection_name: str,
        embedding_model_dims: int,
        use_azure_credential: bool = False,
        ssl_ca: Optional[str] = None,
        ssl_disabled: bool = False,
        minconn: int = 1,
        maxconn: int = 5,
        connection_pool: Optional[Any] = None,
    ):
        """
        Initialize the Azure MySQL vector store.

        Args:
            host (str): MySQL server host
            port (int): MySQL server port
            user (str): Database user
            password (str, optional): Database password (not required if using Azure credential)
            database (str): Database name
            collection_name (str): Collection/table name
            embedding_model_dims (int): Dimension of the embedding vector
            use_azure_credential (bool): Use Azure DefaultAzureCredential for authentication
            ssl_ca (str, optional): Path to a CA certificate; when given, the server
                certificate is verified against it
            ssl_disabled (bool): Disable SSL connection
            minconn (int): Minimum number of connections in the pool
            maxconn (int): Maximum number of connections in the pool
            connection_pool (Any, optional): Pre-configured connection pool (must yield
                connections whose cursors return dict rows)

        Raises:
            ImportError: If ``use_azure_credential`` is set but ``azure-identity`` is missing.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.use_azure_credential = use_azure_credential
        self.ssl_ca = ssl_ca
        self.ssl_disabled = ssl_disabled
        self.connection_pool = connection_pool

        # Handle Azure authentication (replaces self.password with an access token).
        if use_azure_credential:
            if not AZURE_IDENTITY_AVAILABLE:
                raise ImportError(
                    "Azure Identity is required for Azure credential authentication. "
                    "Please install it using 'pip install azure-identity'"
                )
            self._setup_azure_auth()

        # Only build our own pool when the caller did not supply one.
        if self.connection_pool is None:
            self._setup_connection_pool(minconn, maxconn)

        # Create collection if it doesn't exist.
        if collection_name not in self.list_cols():
            self.create_col(name=collection_name, vector_size=embedding_model_dims, distance="cosine")

    def _setup_azure_auth(self):
        """Fetch an access token via DefaultAzureCredential and use it as the password.

        Raises:
            Exception: Anything raised by ``azure.identity``; it is logged and re-raised.
        """
        # NOTE(review): the token is fetched once and captured in the pool's connect
        # arguments, so connections opened after it expires will presumably fail to
        # authenticate. Consider a pool creator that refreshes the token — TODO confirm.
        try:
            credential = DefaultAzureCredential()
            # Get access token for Azure Database for MySQL
            token = credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
            self.password = token.token
            logger.info("Successfully authenticated using Azure DefaultAzureCredential")
        except Exception as e:
            logger.error(f"Failed to authenticate with Azure: {e}")
            raise

    def _setup_connection_pool(self, minconn: int, maxconn: int):
        """Create a DBUtils ``PooledDB`` of PyMySQL connections returning dict rows.

        Args:
            minconn: Initial number of idle connections (``mincached``).
            maxconn: Upper bound for idle and total connections.

        Raises:
            Exception: Any error from pool creation; it is logged and re-raised.
        """
        connect_kwargs = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": False,
        }

        if not self.ssl_disabled:
            if self.ssl_ca:
                # PyMySQL's ``ssl`` dict reads the CA file from "ca"; verify_mode=True
                # requires the server certificate to chain to that CA.
                connect_kwargs["ssl"] = {"ca": self.ssl_ca, "verify_mode": True}
            else:
                # No CA supplied: TLS-encrypted but unverified connection. A non-empty
                # dict is required for PyMySQL to enable TLS at all.
                connect_kwargs["ssl"] = {"verify_mode": False}

        try:
            self.connection_pool = PooledDB(
                creator=pymysql,
                mincached=minconn,
                maxcached=maxconn,
                maxconnections=maxconn,
                blocking=True,  # wait for a free connection instead of raising
                **connect_kwargs,
            )
            logger.info("Successfully created MySQL connection pool")
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    @contextmanager
    def _get_cursor(self, commit: bool = False):
        """
        Yield a cursor on a pooled connection.

        Commits after the ``with`` body when ``commit`` is True; on any exception
        the transaction is rolled back, the error logged and re-raised. The cursor
        is closed and the connection released back via ``close()`` in every case,
        including when cursor creation itself fails.
        """
        conn = self.connection_pool.connection()
        try:
            cur = conn.cursor()
            try:
                yield cur
                if commit:
                    conn.commit()
            except Exception as exc:
                conn.rollback()
                logger.error(f"Database error: {exc}", exc_info=True)
                raise
            finally:
                cur.close()
        finally:
            conn.close()

    @staticmethod
    def _load_json(value: Any) -> Any:
        """Decode a JSON column value that the driver returned as text; pass others through."""
        if isinstance(value, (str, bytes, bytearray)):
            return json.loads(value)
        return value

    @staticmethod
    def _build_filter_clause(filters: Optional[Dict]) -> tuple:
        """Build a ``WHERE`` clause matching payload keys for equality.

        Args:
            filters: Mapping of top-level payload key -> expected JSON-serializable value.

        Returns:
            ``(clause, params)`` where ``clause`` is ``""`` when there are no filters.
        """
        conditions = []
        params = []
        for key, value in (filters or {}).items():
            # CAST(... AS JSON) makes the comparison JSON-to-JSON, so strings, numbers
            # and booleans match by value instead of comparing against a SQL string.
            conditions.append("JSON_EXTRACT(payload, %s) = CAST(%s AS JSON)")
            params.extend([f"$.{key}", json.dumps(value)])
        clause = "WHERE " + " AND ".join(conditions) if conditions else ""
        return clause, params

    def create_col(self, name: str = None, vector_size: int = None, distance: str = "cosine"):
        """
        Create a new collection (a MySQL table) if it does not already exist.

        Args:
            name (str, optional): Collection name (uses self.collection_name if not provided)
            vector_size (int, optional): Vector dimension (uses self.embedding_model_dims if not
                provided). Only logged: vectors are stored as JSON, so the schema does not enforce it.
            distance (str): Accepted for interface compatibility; ``search`` always uses cosine distance.
        """
        table_name = name or self.collection_name
        dims = vector_size or self.embedding_model_dims

        with self._get_cursor(commit=True) as cur:
            # No index on ``payload``: filters extract individual keys with JSON_EXTRACT,
            # which a whole-column key part cannot serve. (MySQL multi-valued indexes also
            # require the form CAST(expr AS type ARRAY) over a JSON array.)
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS `{table_name}` (
                    id VARCHAR(255) PRIMARY KEY,
                    vector JSON,
                    payload JSON
                )
                """
            )
        logger.info(f"Created collection '{table_name}' with vector dimension {dims}")

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Upsert vectors into the collection (existing ids are overwritten).

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): Payloads aligned with ``vectors``; defaults to ``{}`` each
            ids (List[str], optional): IDs aligned with ``vectors``; random UUID4 strings if omitted
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if payloads is None:
            payloads = [{}] * len(vectors)  # safe to share: each dict is only serialized
        if ids is None:
            import uuid

            ids = [str(uuid.uuid4()) for _ in vectors]

        rows = [
            (vec_id, json.dumps(vector), json.dumps(payload))
            for vector, payload, vec_id in zip(vectors, payloads, ids)
        ]

        with self._get_cursor(commit=True) as cur:
            cur.executemany(
                f"INSERT INTO `{self.collection_name}` (id, vector, payload) VALUES (%s, %s, %s) "
                f"ON DUPLICATE KEY UPDATE vector = VALUES(vector), payload = VALUES(payload)",
                rows,
            )

    def _cosine_distance(self, vec1_json: str, vec2: List[float]) -> str:
        """Return a SQL cosine-distance template (unused; kept for compatibility).

        NOTE(review): ``search`` scores in Python and never calls this. The returned SQL
        ignores both arguments, relies on an uninitialised ``@row`` user variable and
        enumerates at most 16 indices, so it is not usable as-is.
        """
        return """
            1 - (
                (SELECT SUM(a.val * b.val) /
                    (SQRT(SUM(a.val * a.val)) * SQRT(SUM(b.val * b.val))))
                FROM (
                    SELECT JSON_EXTRACT(vector, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(vector)
                ) a,
                (
                    SELECT JSON_EXTRACT(%s, CONCAT('$[', idx, ']')) as val
                    FROM (SELECT @row := @row + 1 as idx FROM (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t1, (SELECT 0 UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) t2) indices
                    WHERE idx < JSON_LENGTH(%s)
                ) b
                WHERE a.idx = b.idx
            )
        """

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors using cosine distance.

        All rows matching ``filters`` are fetched and scored in Python.

        Args:
            query (str): Query string (not used in vector search)
            vectors (List[float]): Query vector
            limit (int): Number of results to return
            filters (Dict, optional): Payload key/value pairs that must match exactly

        Returns:
            List[OutputData]: Up to ``limit`` hits ordered by ascending cosine distance
            (``score``). Rows where either vector has zero norm get distance 1.0.
        """
        import numpy as np

        filter_clause, filter_params = self._build_filter_clause(filters)

        with self._get_cursor() as cur:
            cur.execute(
                f"""
                SELECT id, vector, payload
                FROM `{self.collection_name}`
                {filter_clause}
                """,
                filter_params or None,
            )
            rows = cur.fetchall()

        query_vec = np.asarray(vectors, dtype=float)
        query_norm = np.linalg.norm(query_vec)  # loop-invariant

        scored = []
        for row in rows:
            raw_vector = self._load_json(row["vector"])
            if raw_vector is None:
                continue  # NULL vector column: nothing to score
            vec = np.asarray(raw_vector, dtype=float)
            denom = query_norm * np.linalg.norm(vec)
            # Avoid NaN (which would break the sort) for zero-length vectors.
            similarity = float(np.dot(query_vec, vec) / denom) if denom else 0.0
            scored.append((row["id"], 1.0 - similarity, self._load_json(row["payload"])))

        scored.sort(key=lambda item: item[1])

        return [
            OutputData(id=vec_id, score=distance, payload=payload)
            for vec_id, distance, payload in scored[:limit]
        ]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID (no error if it does not exist).

        Args:
            vector_id (str): ID of the vector to delete
        """
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DELETE FROM `{self.collection_name}` WHERE id = %s", (vector_id,))

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and/or its payload in a single transaction.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Replacement vector; left unchanged if None
            payload (Dict, optional): Replacement payload; left unchanged if None
        """
        if vector is None and payload is None:
            return  # nothing to change; skip the round trip

        with self._get_cursor(commit=True) as cur:
            if vector is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET vector = %s WHERE id = %s",
                    (json.dumps(vector), vector_id),
                )
            if payload is not None:
                cur.execute(
                    f"UPDATE `{self.collection_name}` SET payload = %s WHERE id = %s",
                    (json.dumps(payload), vector_id),
                )

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector's record by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            OutputData: The record (``score`` is None), or None if not found
        """
        with self._get_cursor() as cur:
            cur.execute(
                f"SELECT id, vector, payload FROM `{self.collection_name}` WHERE id = %s",
                (vector_id,),
            )
            result = cur.fetchone()
        if not result:
            return None
        return OutputData(id=result["id"], score=None, payload=self._load_json(result["payload"]))

    def list_cols(self) -> List[str]:
        """
        List all collections (tables) in the configured database.

        Returns:
            List[str]: List of collection names
        """
        with self._get_cursor() as cur:
            cur.execute("SHOW TABLES")
            # SHOW TABLES returns a single column whose header embeds the database name;
            # read it positionally rather than rebuilding that header string.
            return [next(iter(row.values())) for row in cur.fetchall()]

    def delete_col(self):
        """Drop the collection table if it exists."""
        with self._get_cursor(commit=True) as cur:
            cur.execute(f"DROP TABLE IF EXISTS `{self.collection_name}`")
        logger.info(f"Deleted collection '{self.collection_name}'")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection from ``information_schema``.

        Returns:
            Dict[str, Any]: ``name``, ``count`` (TABLE_ROWS, an estimate for InnoDB) and
            ``size`` in MB, or ``{}`` if the table is not found
        """
        with self._get_cursor() as cur:
            cur.execute("""
                SELECT
                    TABLE_NAME as name,
                    TABLE_ROWS as count,
                    ROUND(((DATA_LENGTH + INDEX_LENGTH) / 1024 / 1024), 2) as size_mb
                FROM information_schema.TABLES
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
            """, (self.database, self.collection_name))
            result = cur.fetchone()

        if result:
            return {
                "name": result['name'],
                "count": result['count'],
                "size": f"{result['size_mb']} MB"
            }
        return {}

    def list(
        self,
        filters: Optional[Dict] = None,
        limit: int = 100
    ) -> List[List[OutputData]]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Payload key/value pairs that must match exactly
            limit (int): Maximum number of vectors to return

        Returns:
            List[List[OutputData]]: A single-element list wrapping the records
            (``score`` is None), matching the other vector stores' return shape
        """
        filter_clause, filter_params = self._build_filter_clause(filters)

        with self._get_cursor() as cur:
            cur.execute(
                f"""
                SELECT id, vector, payload
                FROM `{self.collection_name}`
                {filter_clause}
                LIMIT %s
                """,
                (*filter_params, limit),
            )
            rows = cur.fetchall()

        return [[
            OutputData(id=row["id"], score=None, payload=self._load_json(row["payload"]))
            for row in rows
        ]]

    def reset(self):
        """Reset the collection by dropping and recreating its table."""
        logger.warning(f"Resetting collection {self.collection_name}...")
        self.delete_col()
        self.create_col(name=self.collection_name, vector_size=self.embedding_model_dims)

    def __del__(self):
        """Best-effort close of the connection pool when the object is garbage-collected."""
        try:
            pool = getattr(self, "connection_pool", None)
            if pool:
                pool.close()
        except Exception:
            # Destructors must never raise (they may run during interpreter shutdown).
            pass
|