matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,633 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Async vector index manager for MatrixOne async client.
|
17
|
+
"""
|
18
|
+
|
19
|
+
from __future__ import annotations
|
20
|
+
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, List
|
22
|
+
|
23
|
+
from sqlalchemy import text
|
24
|
+
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from .sqlalchemy_ext import VectorOpType
|
27
|
+
|
28
|
+
|
29
|
+
async def _exec_sql_safe_async(connection, sql: str):
|
30
|
+
"""Execute SQL safely for async connections, bypassing bind parameter parsing when possible."""
|
31
|
+
if hasattr(connection, 'exec_driver_sql'):
|
32
|
+
# Escape % to %% for pymysql's format string handling
|
33
|
+
escaped_sql = sql.replace('%', '%%')
|
34
|
+
return await connection.exec_driver_sql(escaped_sql)
|
35
|
+
else:
|
36
|
+
return await connection.execute(text(sql))
|
37
|
+
|
38
|
+
|
39
|
+
class AsyncVectorManager:
    """
    Unified async vector manager for MatrixOne vector operations and chain operations.

    This class provides asynchronous vector functionality:
    - vector index creation and removal;
    - toggling the experimental IVF/HNSW index settings;
    - vector data insertion (single row, batch, or inside an existing transaction);
    - vector similarity search and distance-based range search;
    - IVF index statistics.

    Methods that change state return ``self`` so calls can be chained.

    Key Features:

    - Async vector index creation and management (IVF, HNSW)
    - Async vector data insertion and batch operations
    - Async vector similarity search with multiple distance metrics
    - Async vector range search for distance-based filtering
    - IVF index statistics (index table names and bucket distribution)

    Supported Index Types:
    - IVF (IVFFLAT): created with a configurable number of lists
    - HNSW: created with configurable m / ef_construction / ef_search

    Supported Distance Metrics (``distance_type`` argument):
    - "l2": Euclidean distance
    - "cosine": cosine distance
    - "inner_product": inner product of vectors

    Usage Examples:

        # Initialize async vector manager
        vector_ops = client.vector_ops

        # Assumes a table "documents" with a vecf32(384) column "embedding" exists.

        # Create vector index
        await vector_ops.create_ivf("documents", "idx_embedding", "embedding", lists=100)

        # Vector similarity search
        results = await vector_ops.similarity_search(
            table_name="documents",
            vector_column="embedding",
            query_vector=[0.1, 0.2, 0.3, ...],  # 384-dimensional vector
            limit=10,
            distance_type="l2"
        )
    """

    def __init__(self, client):
        # The async client this manager delegates to. Methods below use its
        # `_engine`, `execute`, `insert`, `batch_insert`,
        # `_execute_with_logging` and `_connection_params` members.
        self.client = client

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_table_name(table_name_or_model) -> str:
        """Return the model's ``__tablename__``, or the argument itself when it is a table name."""
        if hasattr(table_name_or_model, '__tablename__'):
            return table_name_or_model.__tablename__
        return table_name_or_model

    @staticmethod
    def _escape_string_literal(value) -> str:
        """Escape *value* for embedding inside a single-quoted SQL string literal.

        Backslashes are doubled (MySQL treats ``\\`` as an escape character).
        Single quotes are doubled so that input such as ``O'Brien`` cannot
        terminate the literal early.
        """
        return str(value).replace("\\", "\\\\").replace("'", "''")

    @staticmethod
    def _escape_identifier(name) -> str:
        """Escape *name* for use inside a backtick-quoted identifier."""
        return str(name).replace("`", "``")

    @classmethod
    def _format_insert_value(cls, value) -> str:
        """Render one Python value as a SQL literal for a hand-built INSERT.

        Rules:
        - ``None`` becomes ``NULL``.
        - A list becomes the quoted vector text ``'[a,b,...]'``.
        - Anything else is converted with ``str()``, escaped and single-quoted.
        """
        if value is None:
            return "NULL"
        if isinstance(value, list):
            return "'[" + ",".join(map(str, value)) + "]'"
        return "'" + cls._escape_string_literal(value) + "'"

    async def _fetch_all(self, sql: str, connection, context: str) -> list:
        """Execute *sql* and return ``fetchall()`` of the result.

        If *connection* is given it is reused (transaction support).
        Otherwise a fresh ``engine.begin()`` block is opened for the query.
        """
        if connection is not None:
            result = await self.client._execute_with_logging(connection, sql, context=context)
            return result.fetchall()
        async with self.client._engine.begin() as conn:
            result = await self.client._execute_with_logging(conn, sql, context=context)
            return result.fetchall()

    # ------------------------------------------------------------------
    # Index management
    # ------------------------------------------------------------------

    async def create_ivf(
        self,
        table_name: str,
        name: str,
        column: str,
        lists: int = 100,
        op_type: "VectorOpType" = None,
    ) -> "AsyncVectorManager":
        """
        Create an IVFFLAT vector index using chain operations.

        Args:

            table_name: Name of the table
            name: Name of the index
            column: Vector column to index
            lists: Number of lists for IVFFLAT (default: 100)
            op_type: Vector operation type (VectorOpType enum, default: VectorOpType.VECTOR_L2_OPS)

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any failure while building or executing the
                CREATE INDEX statement. The original error is kept as ``__cause__``.
        """
        from .sqlalchemy_ext import IVFVectorIndex, VectorOpType

        # Use default if not provided
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        try:
            index = IVFVectorIndex(name, column, lists, op_type)
            sql = index.create_sql(table_name)

            async with self.client._engine.begin() as conn:
                # Enable IVF indexing on the same connection that runs CREATE INDEX
                # (presumably session-scoped settings -- NOTE(review): confirm).
                await _exec_sql_safe_async(conn, "SET experimental_ivf_index = 1")
                await _exec_sql_safe_async(conn, "SET probe_limit = 1")
                await _exec_sql_safe_async(conn, sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create IVFFLAT vector index {name} on table {table_name}: {e}") from e

    async def create_hnsw(
        self,
        table_name: str,
        name: str,
        column: str,
        m: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
        op_type: "VectorOpType" = None,
    ) -> "AsyncVectorManager":
        """
        Create an HNSW vector index using chain operations.

        Args:

            table_name: Name of the table
            name: Name of the index
            column: Vector column to index
            m: Number of bi-directional links for HNSW (default: 16)
            ef_construction: Size of dynamic candidate list for HNSW construction (default: 200)
            ef_search: Size of dynamic candidate list for HNSW search (default: 50)
            op_type: Vector operation type (VectorOpType enum, default: VectorOpType.VECTOR_L2_OPS)

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any failure while building or executing the
                CREATE INDEX statement. The original error is kept as ``__cause__``.
        """
        from .sqlalchemy_ext import HnswVectorIndex, VectorOpType

        # Use default if not provided
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        try:
            index = HnswVectorIndex(name, column, m, ef_construction, ef_search, op_type)
            sql = index.create_sql(table_name)

            async with self.client._engine.begin() as conn:
                # Enable HNSW indexing on the same connection that runs CREATE INDEX.
                await _exec_sql_safe_async(conn, "SET experimental_hnsw_index = 1")
                await _exec_sql_safe_async(conn, sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create HNSW vector index {name} on table {table_name}: {e}") from e

    async def drop(self, table_name: str, name: str) -> "AsyncVectorManager":
        """
        Drop a vector index using chain operations.

        Issues ``DROP INDEX IF EXISTS``, so dropping a missing index is not an error.

        Args:

            table_name: Name of the table
            name: Name of the index

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any execution failure; the original error is kept as ``__cause__``.
        """
        try:
            async with self.client._engine.begin() as conn:
                await _exec_sql_safe_async(conn, f"DROP INDEX IF EXISTS {name} ON {table_name}")
            return self
        except Exception as e:
            raise Exception(f"Failed to drop vector index {name} from table {table_name}: {e}") from e

    # NOTE(review): the enable_*/disable_* methods below send SET statements
    # through `self.client.execute`. If that call borrows a pooled connection
    # per statement, the setting may not apply to later queries run on other
    # connections. Confirm against the client implementation.

    async def enable_ivf(self, probe_limit: int = 1) -> "AsyncVectorManager":
        """
        Enable IVF indexing with probe limit.

        Args:

            probe_limit: Probe limit for IVF indexing. Converted with ``int()``
                before being embedded in the SET statement.

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            ValueError / TypeError: If ``probe_limit`` cannot be converted to an int.
            Exception: Wrapping any execution failure; the original error is kept as ``__cause__``.
        """
        # Validate before issuing any statement so a bad value can neither be
        # injected into the SQL nor leave IVF half-enabled.
        probe_limit = int(probe_limit)
        try:
            await self.client.execute("SET experimental_ivf_index = 1")
            await self.client.execute(f"SET probe_limit = {probe_limit}")
            return self
        except Exception as e:
            raise Exception(f"Failed to enable IVF indexing: {e}") from e

    async def disable_ivf(self) -> "AsyncVectorManager":
        """
        Disable IVF indexing.

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any execution failure; the original error is kept as ``__cause__``.
        """
        try:
            await self.client.execute("SET experimental_ivf_index = 0")
            return self
        except Exception as e:
            raise Exception(f"Failed to disable IVF indexing: {e}") from e

    async def enable_hnsw(self) -> "AsyncVectorManager":
        """
        Enable HNSW indexing.

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any execution failure; the original error is kept as ``__cause__``.
        """
        try:
            await self.client.execute("SET experimental_hnsw_index = 1")
            return self
        except Exception as e:
            raise Exception(f"Failed to enable HNSW indexing: {e}") from e

    async def disable_hnsw(self) -> "AsyncVectorManager":
        """
        Disable HNSW indexing.

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            Exception: Wrapping any execution failure; the original error is kept as ``__cause__``.
        """
        try:
            await self.client.execute("SET experimental_hnsw_index = 0")
            return self
        except Exception as e:
            raise Exception(f"Failed to disable HNSW indexing: {e}") from e

    # ------------------------------------------------------------------
    # Data operations
    # ------------------------------------------------------------------

    async def insert(self, table_name_or_model, data: dict) -> "AsyncVectorManager":
        """
        Insert vector data using chain operations asynchronously.

        Delegates to ``client.insert`` after resolving the table name.

        Args:

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data: Data to insert (dict with column names as keys)

        Returns:

            AsyncVectorManager: Self for chaining
        """
        table_name = self._resolve_table_name(table_name_or_model)
        await self.client.insert(table_name, data)
        return self

    async def insert_in_transaction(self, table_name_or_model, data: dict, connection) -> "AsyncVectorManager":
        """
        Insert vector data within an existing SQLAlchemy transaction asynchronously.

        Builds a literal INSERT statement with these value rules:
        - ``None`` becomes ``NULL``.
        - A list becomes ``'[a,b,...]'``.
        - Any other value is stringified, escaped (quotes and backslashes) and single-quoted.

        Column names are used verbatim.

        Args:

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data: Data to insert (dict with column names as keys)
            connection: SQLAlchemy connection object (required for transaction support)

        Returns:

            AsyncVectorManager: Self for chaining

        Raises:

            ValueError: If connection is not provided
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        table_name = self._resolve_table_name(table_name_or_model)

        columns_str = ", ".join(data.keys())
        values_str = ", ".join(self._format_insert_value(value) for value in data.values())

        sql = f"INSERT INTO {table_name} ({columns_str}) VALUES ({values_str})"
        await self.client._execute_with_logging(connection, sql, context="Vector insert")

        return self

    async def batch_insert(self, table_name_or_model, data_list: list) -> "AsyncVectorManager":
        """
        Batch insert vector data using chain operations asynchronously.

        Delegates to ``client.batch_insert`` after resolving the table name.

        Args:

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data_list: List of data dictionaries to insert

        Returns:

            AsyncVectorManager: Self for chaining
        """
        table_name = self._resolve_table_name(table_name_or_model)
        await self.client.batch_insert(table_name, data_list)
        return self

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def similarity_search(
        self,
        table_name: str,
        vector_column: str,
        query_vector: list,
        limit: int = 10,
        distance_type: str = "l2",
        select_columns: list = None,
        where_conditions: list = None,
        where_params: list = None,
        connection=None,
    ) -> list:
        """
        Perform similarity search using chain operations.

        Args:

            table_name: Name of the table
            vector_column: Name of the vector column
            query_vector: Query vector as list
            limit: Number of results to return
            distance_type: Type of distance calculation (l2, cosine, inner_product).
                "l2" is mapped to ``DistanceFunction.L2_SQ``. NOTE(review): presumably
                squared L2, whereas ``range_search`` uses ``l2_distance``; confirm.
            select_columns: List of columns to select (None means all columns)
            where_conditions: List of WHERE conditions
            where_params: List of parameters for WHERE conditions
            connection: Optional existing database connection (for transaction support)

        Returns:

            List of search results

        Raises:

            ValueError: If ``distance_type`` is not one of the supported names.
        """
        from .sql_builder import DistanceFunction, build_vector_similarity_query

        # Convert distance type to enum
        if distance_type == "l2":
            distance_func = DistanceFunction.L2_SQ
        elif distance_type == "cosine":
            distance_func = DistanceFunction.COSINE
        elif distance_type == "inner_product":
            distance_func = DistanceFunction.INNER_PRODUCT
        else:
            raise ValueError(f"Unsupported distance type: {distance_type}")

        # Build query using unified SQL builder
        sql = build_vector_similarity_query(
            table_name=table_name,
            vector_column=vector_column,
            query_vector=query_vector,
            distance_func=distance_func,
            limit=limit,
            select_columns=select_columns,
            where_conditions=where_conditions,
            where_params=where_params,
        )

        return await self._fetch_all(sql, connection, "Async vector similarity search")

    async def range_search(
        self,
        table_name: str,
        vector_column: str,
        query_vector: list,
        max_distance: float,
        distance_type: str = "l2",
        select_columns: list = None,
        connection=None,
    ) -> list:
        """
        Perform range search using chain operations.

        Selects rows whose distance to ``query_vector`` is ``<= max_distance``,
        ordered by that distance. The distance is returned as a ``distance`` column.

        Args:

            table_name: Name of the table
            vector_column: Name of the vector column
            query_vector: Query vector as list
            max_distance: Maximum distance threshold. Converted with ``float()``
                before being embedded in the SQL.
            distance_type: Type of distance calculation (l2, cosine, inner_product)
            select_columns: List of columns to select (None means all columns).
                ``vector_column`` is appended if missing.
            connection: Optional existing database connection (for transaction support)

        Returns:

            List of search results within range

        Raises:

            ValueError: If ``distance_type`` is unsupported or ``max_distance`` is not numeric.
            TypeError: If ``max_distance`` cannot be converted to float.
        """
        # Coerce to a number so the threshold cannot inject arbitrary SQL.
        max_distance = float(max_distance)

        # Convert vector to string format
        vector_str = "[" + ",".join(map(str, query_vector)) + "]"

        # Build distance function based on type
        if distance_type == "l2":
            distance_func = "l2_distance"
        elif distance_type == "cosine":
            distance_func = "cosine_distance"
        elif distance_type == "inner_product":
            distance_func = "inner_product"
        else:
            raise ValueError(f"Unsupported distance type: {distance_type}")

        # Build SELECT clause
        if select_columns is None:
            select_clause = "*"
        else:
            # Ensure vector_column is included for distance calculation
            columns_to_select = list(select_columns)
            if vector_column not in columns_to_select:
                columns_to_select.append(vector_column)
            select_clause = ", ".join(columns_to_select)

        # Build query
        sql = f"""
        SELECT {select_clause}, {distance_func}({vector_column}, '{vector_str}') as distance
        FROM {table_name}
        WHERE {distance_func}({vector_column}, '{vector_str}') <= {max_distance}
        ORDER BY distance
        """

        return await self._fetch_all(sql, connection, "Async vector range search")

    # ------------------------------------------------------------------
    # IVF statistics
    # ------------------------------------------------------------------

    async def get_ivf_stats(self, table_name_or_model, column_name: str = None) -> Dict[str, Any]:
        """
        Get IVF index statistics for a table.

        Args:

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            column_name: Name of the vector column (optional, will be inferred if not provided)

        Returns:

            Dict containing IVF index statistics including:
            - index_tables: Dictionary mapping table types to table names
            - distribution: Dictionary containing bucket distribution data
            - database: Database name
            - table_name: Table name
            - column_name: Vector column name

        Raises:

            Exception: If IVF index is not found or if there are errors retrieving stats

        Examples:

            # Get stats for a table with vector column
            stats = await client.vector_ops.get_ivf_stats("my_table", "embedding")
            print(f"Index tables: {stats['index_tables']}")
            print(f"Distribution: {stats['distribution']}")

            # Get stats using model class
            stats = await client.vector_ops.get_ivf_stats(MyModel, "vector_col")
        """
        table_name = self._resolve_table_name(table_name_or_model)

        # Get database name from connection params
        database = self.client._connection_params.get('database')
        if not database:
            raise Exception("No database connection found. Please connect to a database first.")

        # If column_name is not provided, try to infer it
        if not column_name:
            column_name = await self._infer_vector_column(database, table_name)

        async with self.client._engine.begin() as conn:
            # Get IVF index table names
            index_tables = await self._get_ivf_index_table_names(database, table_name, column_name, conn)
            if not index_tables:
                raise Exception(f"No IVF index found for table {table_name}, column {column_name}")

            # Get the entries table name for distribution analysis
            entries_table = index_tables.get('entries')
            if not entries_table:
                raise Exception("No entries table found in IVF index")

            # Get bucket distribution
            distribution = await self._get_ivf_buckets_distribution(database, entries_table, conn)

        return {
            'index_tables': index_tables,
            'distribution': distribution,
            'database': database,
            'table_name': table_name,
            'column_name': column_name,
        }

    async def _infer_vector_column(self, database: str, table_name: str) -> str:
        """
        Return the single vector column of *table_name*, found via information_schema.

        A column counts as a vector column when its ``data_type`` matches
        ``LIKE '%VEC%'`` or ``LIKE '%vec%'``.

        Raises:

            Exception: If no vector column, or more than one, is found.
        """
        db_literal = self._escape_string_literal(database)
        table_literal = self._escape_string_literal(table_name)

        async with self.client._engine.begin() as conn:
            schema_sql = (
                "SELECT column_name, data_type "
                "FROM information_schema.columns "
                f"WHERE table_schema = '{db_literal}' "
                f"AND table_name = '{table_literal}' "
                "AND (data_type LIKE '%VEC%' OR data_type LIKE '%vec%')"
            )
            result = await self.client._execute_with_logging(conn, schema_sql, context="Auto-detect vector column")
            vector_columns = result.fetchall()

        if not vector_columns:
            raise Exception(f"No vector columns found in table {table_name}")
        if len(vector_columns) > 1:
            # Multiple vector columns found, raise error asking user to specify
            column_names = [col[0] for col in vector_columns]
            raise Exception(
                f"Multiple vector columns found in table {table_name}: {column_names}. "
                "Please specify the column_name parameter."
            )
        return vector_columns[0][0]

    async def _get_ivf_index_table_names(
        self,
        database: str,
        table_name: str,
        column_name: str,
        connection,
    ) -> Dict[str, str]:
        """
        Get the table names of the IVF index tables.

        Queries ``mo_catalog.mo_indexes`` joined with ``mo_catalog.mo_tables`` for
        ``ivfflat`` indexes on the given database/table/column. Returns a dict
        mapping ``algo_table_type`` to ``index_table_name``; the dict is empty
        if nothing matches.
        """
        column_literal = self._escape_string_literal(column_name)
        table_literal = self._escape_string_literal(table_name)
        db_literal = self._escape_string_literal(database)
        sql = (
            "SELECT i.algo_table_type, i.index_table_name "
            "FROM `mo_catalog`.`mo_indexes` AS i "
            "JOIN `mo_catalog`.`mo_tables` AS t ON i.table_id = t.rel_id "
            f"AND i.column_name = '{column_literal}' "
            f"AND t.relname = '{table_literal}' "
            f"AND t.reldatabase = '{db_literal}' "
            "AND i.algo='ivfflat'"
        )
        result = await self.client._execute_with_logging(connection, sql, context="Get IVF index table names")
        return {row[0]: row[1] for row in result}

    async def _get_ivf_buckets_distribution(
        self,
        database: str,
        table_name: str,
        connection,
    ) -> Dict[str, List[int]]:
        """
        Get the buckets distribution of the IVF index tables.

        Groups the entries table by centroid id and version. Returns three
        parallel lists under the keys ``centroid_count``, ``centroid_id`` and
        ``centroid_version``, one element per group.
        """
        db_ident = self._escape_identifier(database)
        table_ident = self._escape_identifier(table_name)
        sql = (
            "SELECT "
            "  COUNT(*) AS centroid_count, "
            "  __mo_index_centroid_fk_id AS centroid_id, "
            "  __mo_index_centroid_fk_version AS centroid_version "
            f"FROM `{db_ident}`.`{table_ident}` "
            "GROUP BY `__mo_index_centroid_fk_id`, `__mo_index_centroid_fk_version`"
        )
        result = await self.client._execute_with_logging(connection, sql, context="Get IVF buckets distribution")
        rows = result.fetchall()
        return {
            "centroid_count": [row[0] for row in rows],
            "centroid_id": [row[1] for row in rows],
            "centroid_version": [row[2] for row in rows],
        }
|