matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
matrixone/client.py
ADDED
@@ -0,0 +1,4672 @@
# Copyright 2021 - 2022 Matrix Origin
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
MatrixOne Client - Basic implementation
"""

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union

if TYPE_CHECKING:
    from sqlalchemy.engine import Connection
    from .sqlalchemy_ext import VectorOpType

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

from .account import AccountManager, TransactionAccountManager
from .base_client import BaseMatrixOneClient, BaseMatrixOneExecutor
from .connection_hooks import ConnectionHook, ConnectionAction, create_connection_hook
from .exceptions import ConnectionError, QueryError
from .logger import MatrixOneLogger, create_default_logger
from .metadata import MetadataManager, TransactionMetadataManager
from .moctl import MoCtlManager
from .pitr import PitrManager, TransactionPitrManager
from .pubsub import PubSubManager, TransactionPubSubManager
from .restore import RestoreManager, TransactionRestoreManager
from .snapshot import CloneManager, Snapshot, SnapshotLevel, SnapshotManager
from .sqlalchemy_ext import MatrixOneDialect
from .version import get_version_manager


class ClientExecutor(BaseMatrixOneExecutor):
    """Client executor that uses Client's execute method"""

    def __init__(self, client):
        super().__init__(client)
        self.client = client

    def _execute(self, sql: str):
        return self.client.execute(sql)

    def _get_empty_result(self):
        return ResultSet([], [], affected_rows=0)


class Client(BaseMatrixOneClient):
    """
    MatrixOne Client - High-level interface for MatrixOne database operations.

    This class provides a comprehensive interface for connecting to and interacting
    with MatrixOne databases. It supports modern API patterns including table creation,
    data insertion, querying, vector operations, and transaction management.

    Key Features:

    - High-level table operations (create_table, drop_table, insert, batch_insert)
    - Query builder interface for complex queries
    - Vector operations (similarity search, range search, indexing)
    - Transaction management with context managers
    - Snapshot and restore operations
    - Account and user management
    - Fulltext search capabilities
    - Connection pooling and SSL support

    Examples::

        from matrixone import Client

        # Basic usage
        client = Client(
            host='localhost',
            port=6001,
            user='root',
            password='111',
            database='test'
        )

        # Create table
        client.create_table("users", {
            "id": "int primary key",
            "name": "varchar(100)",
            "email": "varchar(255)"
        })

        # Insert and query data
        client.insert("users", {"id": 1, "name": "John", "email": "john@example.com"})
        result = client.query("users").where("id = ?", 1).all()

        # Vector operations
        client.create_table("documents", {
            "id": "int primary key",
            "content": "text",
            "embedding": "vecf32(384)"
        })

        results = client.vector_ops.similarity_search(
            "documents",
            vector_column="embedding",
            query_vector=[0.1, 0.2, 0.3, ...],
            limit=10
        )

        # Transaction
        with client.transaction() as tx:
            tx.execute("INSERT INTO users (name) VALUES ('John')")

    Attributes::

        engine (Engine): SQLAlchemy engine instance
        connected (bool): Connection status
        backend_version (str): Detected backend version
        vector_ops (VectorManager): Vector operations manager
        snapshots (SnapshotManager): Snapshot operations manager
        query (QueryBuilder): Query builder for complex queries
    """

    def __init__(
        self,
        host: str = None,
        port: int = None,
        user: str = None,
        password: str = None,
        database: str = None,
        ssl_mode: str = "preferred",
        ssl_ca: Optional[str] = None,
        ssl_cert: Optional[str] = None,
        ssl_key: Optional[str] = None,
        account: Optional[str] = None,
        role: Optional[str] = None,
        pool_size: int = 10,
        max_overflow: int = 20,
        pool_timeout: int = 30,
        pool_recycle: int = 3600,
        connection_timeout: int = 30,
        query_timeout: int = 300,
        auto_commit: bool = True,
        charset: str = "utf8mb4",
        logger: Optional[MatrixOneLogger] = None,
        sql_log_mode: str = "auto",
        slow_query_threshold: float = 1.0,
        max_sql_display_length: int = 500,
    ):
        """
        Initialize MatrixOne client

        Args::

            host: Database host (optional, can be set later via connect)
            port: Database port (optional, can be set later via connect)
            user: Username (optional, can be set later via connect)
            password: Password (optional, can be set later via connect)
            database: Database name (optional, can be set later via connect)
            ssl_mode: SSL mode (disabled, preferred, required)
            ssl_ca: SSL CA certificate path
            ssl_cert: SSL client certificate path
            ssl_key: SSL client key path
            account: Optional account name
            role: Optional role name
            pool_size: Connection pool size
            max_overflow: Maximum overflow connections
            pool_timeout: Pool timeout in seconds
            pool_recycle: Connection recycle time in seconds
            connection_timeout: Connection timeout in seconds
            query_timeout: Query timeout in seconds
            auto_commit: Enable auto-commit mode
            charset: Character set for connection
            logger: Custom logger instance. If None, creates a default logger
            sql_log_mode: SQL logging mode ('off', 'auto', 'simple', 'full')
                - 'off': No SQL logging
                - 'auto': Smart logging - short SQL shown fully, long SQL summarized (default)
                - 'simple': Show operation summary only
                - 'full': Show complete SQL regardless of length
            slow_query_threshold: Threshold in seconds for slow query warnings (default: 1.0)
            max_sql_display_length: Maximum SQL length in auto mode before summarizing (default: 500)
        """
        self.connection_timeout = connection_timeout
        self.query_timeout = query_timeout
        self.auto_commit = auto_commit
        self.charset = charset
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.pool_timeout = pool_timeout
        self.pool_recycle = pool_recycle

        # Initialize logger
        if logger is not None:
            self.logger = logger
        else:
            self.logger = create_default_logger(
                sql_log_mode=sql_log_mode,
                slow_query_threshold=slow_query_threshold,
                max_sql_display_length=max_sql_display_length,
            )

        self._engine = None
        self._connection_params = {}
        self._login_info = None
        self._snapshots = None
        self._clone = None
        self._moctl = None
        self._restore = None
        self._pitr = None
        self._pubsub = None
        self._account = None
        self._vector_index = None
        # self._vector_query = None  # Removed - functionality moved to vector_ops
        self._vector_data = None
        self._fulltext_index = None
        self._metadata = None

        # Initialize version manager
        self._version_manager = get_version_manager()
        self._backend_version = None

        # Auto-connect if connection parameters are provided
        if all([host, port, user, password, database]):
            self.connect(
                host=host,
                port=port,
                user=user,
                password=password,
                database=database,
                ssl_mode=ssl_mode,
                ssl_ca=ssl_ca,
                ssl_cert=ssl_cert,
                ssl_key=ssl_key,
                account=account,
                role=role,
            )

    def connect(
        self,
        host: str,
        port: int,
        user: str,
        password: str,
        database: str,
        ssl_mode: str = "preferred",
        ssl_ca: Optional[str] = None,
        ssl_cert: Optional[str] = None,
        ssl_key: Optional[str] = None,
        account: Optional[str] = None,
        role: Optional[str] = None,
        charset: str = "utf8mb4",
        connection_timeout: int = 30,
        auto_commit: bool = True,
        on_connect: Optional[Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]] = None,
    ) -> None:
        """
        Connect to MatrixOne database using SQLAlchemy engine

        Args::

            host: Database host
            port: Database port
            user: Username or login info in format "user", "account#user", or "account#user#role"
            password: Password
            database: Database name
            ssl_mode: SSL mode (disabled, preferred, required)
            ssl_ca: SSL CA certificate path
            ssl_cert: SSL client certificate path
            ssl_key: SSL client key path
            account: Optional account name (will be combined with user if user doesn't contain '#')
            role: Optional role name (will be combined with user if user doesn't contain '#')
            charset: Character set for the connection (default: utf8mb4)
            connection_timeout: Connection timeout in seconds (default: 30)
            auto_commit: Enable autocommit (default: True)
            on_connect: Connection hook to execute after successful connection.
                Can be:
                - ConnectionHook instance
                - List of ConnectionAction or string action names
                - Custom callback function

        Examples::

            # Enable all features after connection
            client.connect(host, port, user, password, database,
                           on_connect=[ConnectionAction.ENABLE_ALL])

            # Enable only vector operations with custom charset
            client.connect(host, port, user, password, database,
                           charset="utf8mb4",
                           on_connect=[ConnectionAction.ENABLE_VECTOR])

            # Custom callback
            def my_callback(client):
                print(f"Connected to {client._connection_params['host']}")

            client.connect(host, port, user, password, database,
                           on_connect=my_callback)
        """
        # Build final login info based on user parameter and optional account/role
        final_user, parsed_info = self._build_login_info(user, account, role)

        # Store parsed info for later use
        self._login_info = parsed_info

        self._connection_params = {
            "host": host,
            "port": port,
            "user": final_user,
            "password": password,
            "database": database,
            "charset": charset,
            "connect_timeout": connection_timeout,
            "autocommit": auto_commit,
            "ssl_disabled": ssl_mode == "disabled",
            "ssl_verify_cert": ssl_mode == "required",
            "ssl_verify_identity": ssl_mode == "required",
        }

        # Add SSL parameters if provided
        if ssl_ca:
            self._connection_params["ssl_ca"] = ssl_ca
        if ssl_cert:
            self._connection_params["ssl_cert"] = ssl_cert
        if ssl_key:
            self._connection_params["ssl_key"] = ssl_key

        try:
            # Create SQLAlchemy engine with connection pooling
            self._engine = self._create_engine()
            self.logger.log_connection(host, port, final_user, database, success=True)

            # Initialize managers after engine is created
            self._initialize_managers()

            # Try to detect backend version after successful connection
            try:
                self._detect_backend_version()
            except Exception as e:
                self.logger.warning(f"Failed to detect backend version: {e}")

            # Setup connection hook if provided
            if on_connect:
                self._setup_connection_hook(on_connect)
                # Execute the hook once immediately for the initial connection
                self._execute_connection_hook_immediately(on_connect)

        except Exception as e:
            self.logger.log_connection(host, port, final_user, database, success=False)
            self.logger.log_error(e, context="Connection")
            raise ConnectionError(f"Failed to connect to MatrixOne: {e}")

    def _setup_connection_hook(
        self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
    ) -> None:
        """Setup connection hook to be executed on each new connection"""
        try:
            if isinstance(on_connect, ConnectionHook):
                # Direct ConnectionHook instance
                hook = on_connect
            elif isinstance(on_connect, list):
                # List of actions - create a hook
                hook = create_connection_hook(actions=on_connect)
            elif callable(on_connect):
                # Custom callback function
                hook = create_connection_hook(custom_hook=on_connect)
            else:
                self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
                return

            # Set the client reference and attach to engine
            hook.set_client(self)
            hook.attach_to_engine(self._engine)

        except Exception as e:
            self.logger.warning(f"Connection hook setup failed: {e}")

    def _execute_connection_hook_immediately(
        self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
    ) -> None:
        """Execute connection hook immediately for the initial connection"""
        try:
            if isinstance(on_connect, ConnectionHook):
                # Direct ConnectionHook instance
                hook = on_connect
            elif isinstance(on_connect, list):
                # List of actions - create a hook
                hook = create_connection_hook(actions=on_connect)
            elif callable(on_connect):
                # Custom callback function
                hook = create_connection_hook(custom_hook=on_connect)
            else:
                self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
                return

            # Execute the hook immediately
            hook.execute_sync(self)

        except Exception as e:
            self.logger.warning(f"Immediate connection hook execution failed: {e}")

    @classmethod
    def from_engine(cls, engine: Engine, **kwargs) -> "Client":
        """
        Create Client instance from existing SQLAlchemy Engine

        Args::

            engine: SQLAlchemy Engine instance (must use MySQL driver)
            **kwargs: Additional client configuration options

        Returns::

            Client: Configured client instance

        Raises::

            ConnectionError: If engine doesn't use MySQL driver

        Examples

        Basic usage::

            from sqlalchemy import create_engine
            from matrixone import Client

            engine = create_engine("mysql+pymysql://user:pass@host:port/db")
            client = Client.from_engine(engine)

        With custom configuration::

            engine = create_engine("mysql+pymysql://user:pass@host:port/db")
            client = Client.from_engine(
                engine,
                sql_log_mode='auto',
                slow_query_threshold=0.5
            )
        """
        # Check if engine uses MySQL driver
        if not cls._is_mysql_engine(engine):
            raise ConnectionError(
                "MatrixOne Client only supports MySQL drivers. "
                "Please use mysql+pymysql:// or mysql+mysqlconnector:// connection strings. "
                f"Current engine uses: {engine.dialect.name}"
            )

        # Create client instance with default parameters
        client = cls(**kwargs)

        # Set the provided engine
        client._engine = engine

        # Replace the dialect with MatrixOne dialect for proper vector type support
        original_dbapi = engine.dialect.dbapi
        engine.dialect = MatrixOneDialect()
        engine.dialect.dbapi = original_dbapi

        # Initialize managers after engine is set
        client._initialize_managers()

        # Try to detect backend version
        try:
            client._detect_backend_version()
        except Exception as e:
            client.logger.warning(f"Failed to detect backend version: {e}")

        return client

    @staticmethod
    def _is_mysql_engine(engine: Engine) -> bool:
        """
        Check if the engine uses a MySQL driver

        Args::

            engine: SQLAlchemy Engine instance

        Returns::

            bool: True if engine uses MySQL driver, False otherwise
        """
        # Check dialect name
        dialect_name = engine.dialect.name.lower()

        # Check if it's a MySQL dialect
        if dialect_name == "mysql":
            return True

        # Check connection string for MySQL drivers
        url = str(engine.url)
        mysql_drivers = [
            "mysql+pymysql",
            "mysql+mysqlconnector",
            "mysql+cymysql",
            "mysql+oursql",
            "mysql+gaerdbms",
            "mysql+pyodbc",
        ]

        return any(driver in url.lower() for driver in mysql_drivers)
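
    # Illustrative sketch (editorial addition, not part of the released file):
    # how the driver check above behaves for typical engine URLs. Note that
    # create_engine() loads the DBAPI, so each example needs its driver installed.
    #
    #     from sqlalchemy import create_engine
    #     mysql_engine = create_engine("mysql+pymysql://root:111@localhost:6001/test")
    #     Client._is_mysql_engine(mysql_engine)   # True - dialect name is "mysql"
    #
    #     pg_engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")
    #     Client._is_mysql_engine(pg_engine)      # False - from_engine() would raise
    #                                             # ConnectionError for this engine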

    def _create_engine(self) -> Engine:
        """Create SQLAlchemy engine with connection pooling"""
        # Build connection string
        connection_string = (
            f"mysql+pymysql://{self._connection_params['user']}:"
            f"{self._connection_params['password']}@"
            f"{self._connection_params['host']}:"
            f"{self._connection_params['port']}/"
            f"{self._connection_params['database']}"
        )

        # Add SSL parameters if needed
        if "ssl_ca" in self._connection_params:
            connection_string += f"?ssl_ca={self._connection_params['ssl_ca']}"

        # Create engine with connection pooling
        engine = create_engine(
            connection_string,
            pool_size=self.pool_size,
            max_overflow=self.max_overflow,
            pool_timeout=self.pool_timeout,
            pool_recycle=self.pool_recycle,
            pool_pre_ping=True,  # Enable connection health checks
        )

        # Replace the dialect with MatrixOne dialect for proper vector type support
        original_dbapi = engine.dialect.dbapi
        engine.dialect = MatrixOneDialect()
        engine.dialect.dbapi = original_dbapi

        return engine

    def _initialize_managers(self) -> None:
        """Initialize all manager instances after engine is created"""
        self._snapshots = SnapshotManager(self)
        self._clone = CloneManager(self)
        self._moctl = MoCtlManager(self)
        self._restore = RestoreManager(self)
        self._pitr = PitrManager(self)
        self._pubsub = PubSubManager(self)
        self._account = AccountManager(self)
        self._vector = VectorManager(self)
        # self._vector_query = VectorQueryManager(self)  # Removed - functionality moved to vector_ops
        self._fulltext_index = FulltextIndexManager(self)
        self._metadata = MetadataManager(self)

    def disconnect(self) -> None:
        """
        Disconnect from MatrixOne database and dispose engine.

        This method properly closes all database connections and disposes of the
        SQLAlchemy engine. It should be called when the client is no longer needed
        to free up resources.

        After calling this method, the client will need to be reconnected using
        the connect() method before any database operations can be performed.

        Raises::

            Exception: If disconnection fails (logged but re-raised)

        Example

            >>> client = Client('localhost', 6001, 'root', '111', 'test')
            >>> client.connect()
            >>> # ... perform database operations ...
            >>> client.disconnect()  # Clean up resources
        """
        if self._engine:
            try:
                self._engine.dispose()
                self._engine = None
                self.logger.log_disconnection(success=True)
            except Exception as e:
                self.logger.log_disconnection(success=False)
                self.logger.log_error(e, context="Disconnection")
                raise

    def get_login_info(self) -> Optional[dict]:
        """
        Get parsed login information used for database connection.

        Returns the login information dictionary that was used to establish
        the database connection. This includes user, account, role, and other
        authentication details.

        Returns::

            Optional[dict]: Dictionary containing login information with keys:
                - user: Username
                - account: Account name (if specified)
                - role: Role name (if specified)
                - host: Database host
                - port: Database port
                - database: Database name
                Returns None if not connected or no login info available.

        Example

            >>> client = Client('localhost', 6001, 'root', '111', 'test')
            >>> client.connect()
            >>> login_info = client.get_login_info()
            >>> print(f"Connected as {login_info['user']} to {login_info['database']}")
        """
        return self._login_info

    def _escape_identifier(self, identifier: str) -> str:
        """Escapes an identifier to prevent SQL injection."""
        return f"`{identifier}`"

    def _escape_string(self, value: str) -> str:
        """Escapes a string value for SQL queries."""
        return f"'{value}'"
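
    # Illustrative sketch (editorial addition, not part of the released file):
    # the escape helpers above wrap values directly. Note that _escape_string()
    # quotes but does not double embedded single quotes; quote escaping for
    # user-supplied parameters is handled in _substitute_parameters() below.
    #
    #     client._escape_identifier("users")   # -> `users`
    #     client._escape_string("John")        # -> 'John'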

    def _build_login_info(self, user: str, account: Optional[str] = None, role: Optional[str] = None) -> tuple[str, dict]:
        """
        Build final login info based on user parameter and optional account/role

        Args::

            user: Username or login info in format "user", "account#user", or "account#user#role"
            account: Optional account name
            role: Optional role name

        Returns::

            tuple: (final_user_string, parsed_info_dict)

        Rules:
            1. If user contains '#', it's already in format "account#user" or "account#user#role"
               - If account or role is also provided, raise error (conflict)
            2. If user doesn't contain '#', combine with optional account/role:
               - No account/role: use user as-is
               - Only role: use "sys#user#role"
               - Only account: use "account#user"
               - Both: use "account#user#role"
        """
        # Check if user already contains login format
        if "#" in user:
            # User is already in format "account#user" or "account#user#role"
            if account is not None or role is not None:
                raise ValueError(
                    f"Conflict: user parameter '{user}' already contains account/role info, "
                    f"but account='{account}' and role='{role}' are also provided. "
                    f"Use either user format or separate account/role parameters, not both."
                )

            # Parse the existing format
            parts = user.split("#")
            if len(parts) == 2:
                # "account#user" format
                final_account, final_user, final_role = parts[0], parts[1], None
            elif len(parts) == 3:
                # "account#user#role" format
                final_account, final_user, final_role = parts[0], parts[1], parts[2]
            else:
                raise ValueError(f"Invalid user format: '{user}'. Expected 'user', 'account#user', or 'account#user#role'")

            final_user_string = user

        else:
            # User is just a username, combine with optional account/role
            if account is None and role is None:
                # No account/role provided, use user as-is
                final_account, final_user, final_role = "sys", user, None
                final_user_string = user
            elif account is None and role is not None:
                # Only role provided, use sys account
                final_account, final_user, final_role = "sys", user, role
                final_user_string = f"sys#{user}#{role}"
            elif account is not None and role is None:
                # Only account provided, no role
                final_account, final_user, final_role = account, user, None
                final_user_string = f"{account}#{user}"
            else:
                # Both account and role provided
                final_account, final_user, final_role = account, user, role
                final_user_string = f"{account}#{user}#{role}"

        parsed_info = {"account": final_account, "user": final_user, "role": final_role}

        return final_user_string, parsed_info
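
    # Illustrative sketch (editorial addition, not part of the released file):
    # expected results of _build_login_info under the rules documented above.
    #
    #     _build_login_info("root")
    #     # -> ("root", {"account": "sys", "user": "root", "role": None})
    #     _build_login_info("root", role="admin")
    #     # -> ("sys#root#admin", {"account": "sys", "user": "root", "role": "admin"})
    #     _build_login_info("acme#bob")
    #     # -> ("acme#bob", {"account": "acme", "user": "bob", "role": None})
    #     _build_login_info("acme#bob", account="acme")
    #     # -> raises ValueError (account supplied both inline and as a parameter)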

    def _execute_with_logging(
        self, connection: "Connection", sql: str, context: str = "SQL execution", override_sql_log_mode: str = None
    ):
        """
        Execute SQL with proper logging through the client's logger.

        This is an internal helper method used by all SDK components to ensure
        consistent SQL logging across vector operations, transactions, and other features.

        Args::

            connection: SQLAlchemy connection object
            sql: SQL query string
            context: Context description for error logging (default: "SQL execution")
            override_sql_log_mode: Temporarily override sql_log_mode for this query only

        Returns::

            SQLAlchemy result object

        Note:

            This method is used internally by VectorManager, TransactionWrapper,
            and other SDK components. External users should use execute() instead.
        """
        import time

        start_time = time.time()
        try:
            # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
            if hasattr(connection, 'exec_driver_sql'):
                # Escape % to %% for pymysql's format string handling
                escaped_sql = sql.replace('%', '%%')
                result = connection.exec_driver_sql(escaped_sql)
            else:
                # Fallback for testing or older SQLAlchemy versions
                from sqlalchemy import text

                result = connection.execute(text(sql))
            execution_time = time.time() - start_time

            # Try to get row count if available
            try:
                if result.returns_rows:
                    # For SELECT queries, we can't consume the result to count rows
                    # So we just log without row count
                    self.logger.log_query(
                        sql, execution_time, None, success=True, override_sql_log_mode=override_sql_log_mode
                    )
                else:
                    # For DML queries (INSERT/UPDATE/DELETE), we can get rowcount
                    self.logger.log_query(
                        sql, execution_time, result.rowcount, success=True, override_sql_log_mode=override_sql_log_mode
                    )
            except Exception:
                # Fallback: just log the query without row count
                self.logger.log_query(sql, execution_time, None, success=True, override_sql_log_mode=override_sql_log_mode)

            return result
        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.log_query(sql, execution_time, success=False, override_sql_log_mode=override_sql_log_mode)
            self.logger.log_error(e, context=context)
            raise
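
    # Note (editorial addition, not part of the released file): the % -> %%
    # escaping above appears to be needed because exec_driver_sql() hands the
    # string to the DBAPI driver, and drivers such as pymysql may apply %-style
    # parameter formatting, in which case a literal "LIKE 'a%'" must reach the
    # driver as "LIKE 'a%%'". A hypothetical round trip:
    #
    #     sql = "SELECT * FROM users WHERE name LIKE 'J%'"
    #     escaped = sql.replace('%', '%%')   # "... LIKE 'J%%'"
    #     conn.exec_driver_sql(escaped)      # driver collapses %% back to %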

    def execute(self, sql: str, params: Optional[Tuple] = None) -> "ResultSet":
        """
        Execute SQL query using connection pool.

        This is the primary method for executing SQL statements against the MatrixOne
        database. It supports both SELECT queries (returning data) and DML operations
        (INSERT, UPDATE, DELETE) that modify data.

        Args::

            sql (str): SQL query string. Can include '?' placeholders for parameter binding.
            params (Optional[Tuple]): Query parameters to replace '?' placeholders in the SQL.
                Parameters are automatically escaped to prevent SQL injection.

        Returns::

            ResultSet: Object containing query results with the following attributes:
                - columns: List of column names
                - rows: List of tuples containing row data
                - affected_rows: Number of rows affected (for DML operations)
                - fetchall(): Method to get all rows as a list
                - fetchone(): Method to get the next row
                - fetchmany(size): Method to get multiple rows

        Raises::

            ConnectionError: If not connected to database
            QueryError: If query execution fails

        Examples

            # SELECT query
            >>> result = client.execute("SELECT * FROM users WHERE age > ?", (25,))
            >>> for row in result.fetchall():
            ...     print(row)

            # INSERT query
            >>> result = client.execute("INSERT INTO users (name, age) VALUES (?, ?)", ("John", 30))
            >>> print(f"Inserted {result.affected_rows} rows")

            # UPDATE query
            >>> result = client.execute("UPDATE users SET age = ? WHERE name = ?", (31, "John"))
            >>> print(f"Updated {result.affected_rows} rows")
        """
        if not self._engine:
            raise ConnectionError("Not connected to database")

        import time

        start_time = time.time()

        try:
            # Handle parameter substitution for MatrixOne compatibility
            final_sql = self._substitute_parameters(sql, params)

            with self._engine.begin() as conn:
                # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
                # This prevents JSON strings like {"a":1} from being parsed as :1 bind params
                if hasattr(conn, 'exec_driver_sql'):
                    # Escape % to %% for pymysql's format string handling in exec_driver_sql
                    escaped_sql = final_sql.replace('%', '%%')
                    result = conn.exec_driver_sql(escaped_sql)
                else:
                    # Fallback for testing or older SQLAlchemy versions
                    from sqlalchemy import text

                    result = conn.execute(text(final_sql))

                execution_time = time.time() - start_time

                if result.returns_rows:
                    # SELECT query
                    columns = list(result.keys())
                    rows = result.fetchall()
                    result_set = ResultSet(columns, rows)
                    self.logger.log_query(sql, execution_time, len(rows), success=True)
                    return result_set
                else:
                    # INSERT/UPDATE/DELETE query
                    result_set = ResultSet([], [], affected_rows=result.rowcount)
                    self.logger.log_query(sql, execution_time, result.rowcount, success=True)
                    return result_set

        except Exception as e:
            execution_time = time.time() - start_time

            # Log error FIRST, before any error processing
            # Wrap in try-except to ensure logging failure doesn't hide the real error
            try:
                self.logger.log_query(sql, execution_time, success=False)
                self.logger.log_error(e, context="Query execution")
            except Exception as log_err:
                # If logging fails, print to stderr as fallback but continue with error handling
                import sys

                print(f"Warning: Error logging failed: {log_err}", file=sys.stderr)

            # Extract user-friendly error message
            error_msg = str(e)

            # Handle common database errors with helpful messages
            # Check for "does not exist" first before "syntax error"
            if (
                'does not exist' in error_msg.lower()
                or 'no such table' in error_msg.lower()
                or 'doesn\'t exist' in error_msg.lower()
            ):
                # Table doesn't exist
                import re

                match = re.search(r"(?:table|database)\s+[\"']?(\w+)[\"']?\s+does not exist", error_msg, re.IGNORECASE)
                if match:
                    obj_name = match.group(1)
                    raise QueryError(
                        f"Table or database '{obj_name}' does not exist. "
                        f"Create it first using client.create_table() or CREATE TABLE/DATABASE statement."
                    ) from None
                else:
                    raise QueryError(f"Object not found: {error_msg}") from None

            elif 'already exists' in error_msg.lower() and '1050' in error_msg:
                # Table already exists
                match = None
                if 'table' in error_msg.lower():
                    import re

                    match = re.search(r"table\s+(\w+)\s+already\s+exists", error_msg, re.IGNORECASE)
                if match:
                    table_name = match.group(1)
                    raise QueryError(
                        f"Table '{table_name}' already exists. "
                        f"Use DROP TABLE {table_name} or client.drop_table() to remove it first."
                    ) from None
                else:
                    raise QueryError(f"Object already exists: {error_msg}") from None

            elif 'duplicate' in error_msg.lower() and ('1062' in error_msg or '1061' in error_msg):
                # Duplicate key/entry
                raise QueryError(
                    f"Duplicate entry error: {error_msg}. "
                    f"Check for duplicate primary key or unique constraint violations."
                ) from None

            elif 'syntax error' in error_msg.lower() or '1064' in error_msg:
                # SQL syntax error
                sql_preview = sql[:200] + '...' if len(sql) > 200 else sql
                raise QueryError(f"SQL syntax error: {error_msg}\n" f"Query: {sql_preview}") from None

            elif 'column' in error_msg.lower() and ('unknown' in error_msg.lower() or 'not found' in error_msg.lower()):
                # Column doesn't exist
                raise QueryError(f"Column not found: {error_msg}. " f"Check your column names and table schema.") from None

            elif 'cannot be null' in error_msg.lower() or '1048' in error_msg:
                # NULL constraint violation
                raise QueryError(
                    f"NULL constraint violation: {error_msg}. " f"Some columns require non-NULL values."
                ) from None

            elif 'not supported' in error_msg.lower() and '20105' in error_msg:
                # MatrixOne-specific: feature not supported
                raise QueryError(
                    f"MatrixOne feature limitation: {error_msg}. "
                    f"This feature may require additional configuration or is not yet supported."
                ) from None

            elif 'bind parameter' in error_msg.lower() or 'InvalidRequestError' in error_msg:
                # SQLAlchemy bind parameter error (from JSON colons, etc.)
                raise QueryError(
                    f"Parameter binding error: {error_msg}. "
                    f"This might be caused by special characters in your data (colons in JSON, etc.)"
                ) from None

            else:
                # Generic error - still cleaner than full SQLAlchemy stack
                raise QueryError(f"Query execution failed: {error_msg}") from None

    def insert(self, table_name_or_model, data: dict) -> "ResultSet":
        """
        Insert a single row of data into a table.

        This method provides a convenient way to insert data using a dictionary
        where keys are column names and values are the data to insert. The method
        automatically handles SQL generation and parameter binding.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data (dict): Dictionary mapping column names to values. Example:

                {'name': 'John', 'age': 30, 'email': 'john@example.com'}

        Returns::

            ResultSet: Object containing insertion results with:
                - affected_rows: Number of rows inserted (should be 1)
                - columns: Empty list (no columns returned for INSERT)
                - rows: Empty list (no rows returned for INSERT)

        Raises::

            ConnectionError: If not connected to database
            QueryError: If insertion fails

        Examples

            # Insert a single user using table name
            >>> result = client.insert('users', {
            ...     'name': 'John Doe',
            ...     'age': 30,
            ...     'email': 'john@example.com'
            ... })
            >>> print(f"Inserted {result.affected_rows} row")

            # Insert using model class
            >>> from sqlalchemy import Column, Integer, String
            >>> from matrixone.orm import declarative_base
            >>> Base = declarative_base()
            >>> class User(Base):
            ...     __tablename__ = 'users'
            ...     id = Column(Integer, primary_key=True)
            ...     name = Column(String(50))
            ...     age = Column(Integer)
            >>> result = client.insert(User, {
            ...     'name': 'Jane Doe',
            ...     'age': 25
            ... })

            # Insert with NULL values
            >>> result = client.insert('products', {
            ...     'name': 'Product A',
            ...     'price': 99.99,
            ...     'description': None  # NULL value
            ... })
        """
        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            # It's a model class
            table_name = table_name_or_model.__tablename__
        else:
            # It's a table name string
            table_name = table_name_or_model

        executor = ClientExecutor(self)
        return executor.insert(table_name, data)

    def batch_insert(self, table_name_or_model, data_list: list) -> "ResultSet":
        """
        Batch insert multiple rows of data into a table.

        This method efficiently inserts multiple rows in a single operation,
        which is much faster than calling insert() multiple times. All rows
        must have the same column structure.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data_list (list): List of dictionaries, where each dictionary represents
                a row to insert. All dictionaries must have the same keys.
                Example: [
                    {'name': 'John', 'age': 30},
                    {'name': 'Jane', 'age': 25},
                    {'name': 'Bob', 'age': 35}
                ]

        Returns::

            ResultSet: Object containing insertion results with:
                - affected_rows: Number of rows inserted
                - columns: Empty list (no columns returned for INSERT)
                - rows: Empty list (no rows returned for INSERT)

        Raises::

            ConnectionError: If not connected to database
            QueryError: If batch insertion fails
            ValueError: If data_list is empty or has inconsistent column structure

        Examples

            # Insert multiple users
            >>> users = [
            ...     {'name': 'John Doe', 'age': 30, 'email': 'john@example.com'},
            ...     {'name': 'Jane Smith', 'age': 25, 'email': 'jane@example.com'},
            ...     {'name': 'Bob Johnson', 'age': 35, 'email': 'bob@example.com'}
            ... ]
            >>> result = client.batch_insert('users', users)
            >>> print(f"Inserted {result.affected_rows} rows")

            # Insert with some NULL values
            >>> products = [
            ...     {'name': 'Product A', 'price': 99.99, 'description': 'Great product'},
            ...     {'name': 'Product B', 'price': 149.99, 'description': None}
            ... ]
            >>> result = client.batch_insert('products', products)
        """
        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            # It's a model class
            table_name = table_name_or_model.__tablename__
        else:
            # It's a table name string
            table_name = table_name_or_model

        executor = ClientExecutor(self)
        return executor.batch_insert(table_name, data_list)

    def _substitute_parameters(self, sql: str, params=None) -> str:
        """
        Substitute ? placeholders or named parameters with actual values since MatrixOne
        doesn't support prepared statements

        Args::

            sql: SQL query string with ? placeholders or named parameters (:name)
            params: Tuple of parameter values or dict of named parameters

        Returns::

            SQL string with parameters substituted
        """
        if not params:
            return sql

        final_sql = sql

        # Handle named parameters (dict)
        if isinstance(params, dict):
            for key, value in params.items():
                placeholder = f":{key}"
                if placeholder in final_sql:
                    if isinstance(value, str):
                        # Escape single quotes in string values
                        escaped_value = value.replace("'", "''")
                        final_sql = final_sql.replace(placeholder, f"'{escaped_value}'")
                    elif value is None:
                        final_sql = final_sql.replace(placeholder, "NULL")
                    else:
                        final_sql = final_sql.replace(placeholder, str(value))
        # Handle positional parameters (tuple/list)
        elif isinstance(params, (tuple, list)):
            for param in params:
                # Skip empty lists that might come from CTE queries
                if isinstance(param, list) and len(param) == 0:
                    continue
                elif isinstance(param, str):
                    # Escape single quotes in string values
                    escaped_param = param.replace("'", "''")
                    # Handle both ? and %s placeholders
                    if "?" in final_sql:
                        final_sql = final_sql.replace("?", f"'{escaped_param}'", 1)
                    elif "%s" in final_sql:
                        final_sql = final_sql.replace("%s", f"'{escaped_param}'", 1)
                elif param is None:
                    # Handle both ? and %s placeholders
                    if "?" in final_sql:
                        final_sql = final_sql.replace("?", "NULL", 1)
                    elif "%s" in final_sql:
                        final_sql = final_sql.replace("%s", "NULL", 1)
                else:
                    # Handle both ? and %s placeholders
                    if "?" in final_sql:
                        final_sql = final_sql.replace("?", str(param), 1)
                    elif "%s" in final_sql:
                        final_sql = final_sql.replace("%s", str(param), 1)

        return final_sql
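
    # Illustrative sketch (editorial addition, not part of the released file):
    # what the literal substitution above produces, since the final SQL string is
    # executed as-is by MatrixOne.
    #
    #     _substitute_parameters("SELECT * FROM t WHERE name = ? AND age > ?",
    #                            ("O'Brien", 30))
    #     # -> "SELECT * FROM t WHERE name = 'O''Brien' AND age > 30"
    #
    #     _substitute_parameters("SELECT * FROM t WHERE name = :name", {"name": "x"})
    #     # -> "SELECT * FROM t WHERE name = 'x'"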

    def get_sqlalchemy_engine(self) -> Engine:
        """
        Get SQLAlchemy engine

        Returns::

            SQLAlchemy Engine
        """
        if not self._engine:
            raise ConnectionError("Not connected to database")
        return self._engine

    def query(self, *columns, snapshot: str = None):
        """Get MatrixOne query builder - SQLAlchemy style

        Args::

            *columns: Can be:
                - Single model class: query(Article) - returns all columns from model
                - Multiple columns: query(Article.id, Article.title) - returns specific columns
                - Mixed: query(Article, Article.id, some_expression.label('alias')) - model + additional columns
            snapshot: Optional snapshot name for snapshot queries

        Examples

            # Traditional model query (all columns)
            client.query(Article).filter(...).all()

            # Column-specific query
            client.query(Article.id, Article.title).filter(...).all()

            # With fulltext score
            client.query(Article.id, boolean_match("title", "content").must("python").label("score"))

            # Snapshot query
            client.query(Article, snapshot="my_snapshot").filter(...).all()

        Returns::

            MatrixOneQuery instance configured for the specified columns
        """
        from .orm import MatrixOneQuery

        if len(columns) == 1:
            # Traditional single model class usage
            column = columns[0]
            if isinstance(column, str):
                # String table name
                return MatrixOneQuery(column, self, snapshot=snapshot)
            elif hasattr(column, '__tablename__'):
                # This is a model class
                return MatrixOneQuery(column, self, snapshot=snapshot)
            elif hasattr(column, 'name') and hasattr(column, 'as_sql'):
                # This is a CTE object
                from .orm import CTE

                if isinstance(column, CTE):
                    query = MatrixOneQuery(None, self, snapshot=snapshot)
                    query._table_name = column.name
                    query._select_columns = ["*"]  # Default to select all from CTE
                    query._ctes = [column]  # Add the CTE to the query
                    return query
            else:
                # This is a single column/expression - need to handle specially
                # For now, we'll create a query that can handle column selections
                query = MatrixOneQuery(None, self, snapshot=snapshot)
                query._select_columns = [column]
                # Try to infer table name from column
                if hasattr(column, 'table') and hasattr(column.table, 'name'):
                    query._table_name = column.table.name
                return query
        else:
            # Multiple columns/expressions
            model_class = None
            select_columns = []

            for column in columns:
                if hasattr(column, '__tablename__'):
                    # This is a model class - use its table
                    model_class = column
                else:
                    # This is a column or expression
                    select_columns.append(column)

            if model_class:
                query = MatrixOneQuery(model_class, self, snapshot=snapshot)
                if select_columns:
                    # Add additional columns to the model's default columns
                    query._select_columns = select_columns
                return query
            else:
                # No model class provided, need to infer table from columns
                query = MatrixOneQuery(None, self, snapshot=snapshot)
                query._select_columns = select_columns

                # Try to infer table name from first column that has table info
                for col in select_columns:
                    if hasattr(col, 'table') and hasattr(col.table, 'name'):
                        query._table_name = col.table.name
                        break
                    elif isinstance(col, str) and '.' in col:
                        # String column like "table.column" - extract table name
                        parts = col.split('.')
                        if len(parts) >= 2:
                            # For "db.table.column" format, use "db.table"
                            # For "table.column" format, use "table"
                            table_name = '.'.join(parts[:-1])
                            query._table_name = table_name
                            break

                return query
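
    # Illustrative sketch (editorial addition, not part of the released file):
    # table-name inference when query() receives plain string columns, per the
    # parsing logic above. A single string argument is treated as a table name.
    #
    #     q = client.query("users")                          # table name, all columns
    #     q = client.query("users.id", "users.name")         # infers _table_name = "users"
    #     q = client.query("shop.orders.total", "shop.orders.qty")
    #     # "db.table.column" strings infer _table_name = "shop.orders"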
@contextmanager
|
1236
|
+
def snapshot(self, snapshot_name: str) -> Generator[SnapshotClient, None, None]:
|
1237
|
+
"""
|
1238
|
+
Snapshot context manager
|
1239
|
+
|
1240
|
+
Usage
|
1241
|
+
|
1242
|
+
with client.snapshot("daily_backup") as snapshot_client:
|
1243
|
+
result = snapshot_client.execute("SELECT * FROM users")
|
1244
|
+
"""
|
1245
|
+
if not self._engine:
|
1246
|
+
raise ConnectionError("Not connected to database")
|
1247
|
+
|
1248
|
+
# Create a snapshot client wrapper
|
1249
|
+
snapshot_client = SnapshotClient(self, snapshot_name)
|
1250
|
+
yield snapshot_client
|
1251
|
+
|
1252
|
+
@contextmanager
|
1253
|
+
def transaction(self) -> Generator[TransactionWrapper, None, None]:
|
1254
|
+
"""
|
1255
|
+
Transaction context manager
|
1256
|
+
|
1257
|
+
Usage
|
1258
|
+
|
1259
|
+
with client.transaction() as tx:
|
1260
|
+
# MatrixOne operations
|
1261
|
+
tx.execute("INSERT INTO users ...")
|
1262
|
+
tx.execute("UPDATE users ...")
|
1263
|
+
|
1264
|
+
# Snapshot and clone operations within transaction
|
1265
|
+
tx.snapshots.create("snap1", "table", database="db1", table="t1")
|
1266
|
+
tx.clone.clone_database("target_db", "source_db")
|
1267
|
+
|
1268
|
+
# SQLAlchemy operations in the same transaction
|
1269
|
+
session = tx.get_sqlalchemy_session()
|
1270
|
+
session.add(user)
|
1271
|
+
session.commit()
|
1272
|
+
"""
|
1273
|
+
if not self._engine:
|
1274
|
+
raise ConnectionError("Not connected to database")
|
1275
|
+
|
1276
|
+
tx_wrapper = None
|
1277
|
+
try:
|
1278
|
+
# Use engine's connection pool for transaction
|
1279
|
+
with self._engine.begin() as conn:
|
1280
|
+
tx_wrapper = TransactionWrapper(conn, self)
|
1281
|
+
yield tx_wrapper
|
1282
|
+
|
1283
|
+
# Commit SQLAlchemy session first
|
1284
|
+
tx_wrapper.commit_sqlalchemy()
|
1285
|
+
|
1286
|
+
except Exception as e:
|
1287
|
+
# Rollback SQLAlchemy session first
|
1288
|
+
if tx_wrapper:
|
1289
|
+
tx_wrapper.rollback_sqlalchemy()
|
1290
|
+
raise e
|
1291
|
+
finally:
|
1292
|
+
# Clean up SQLAlchemy resources
|
1293
|
+
if tx_wrapper:
|
1294
|
+
tx_wrapper.close_sqlalchemy()
|
1295
|
+
|
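Because the wrapper re-raises after rolling back, any failure inside the block undoes every statement issued through `tx`. A minimal sketch, assuming a connected `client` and an existing `users` table:

    try:
        with client.transaction() as tx:
            tx.execute("INSERT INTO users (name) VALUES ('alice')")
            raise RuntimeError("boom")  # forces rollback of the INSERT above
    except RuntimeError:
        pass  # the INSERT was rolled back together with the transaction
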
    @property
    def snapshots(self) -> Optional[SnapshotManager]:
        """Get snapshot manager"""
        return self._snapshots

    @property
    def clone(self) -> Optional[CloneManager]:
        """Get clone manager"""
        return self._clone

    @property
    def moctl(self) -> Optional[MoCtlManager]:
        """Get mo_ctl manager"""
        return self._moctl

    @property
    def restore(self) -> Optional[RestoreManager]:
        """Get restore manager"""
        return self._restore

    @property
    def pitr(self) -> Optional[PitrManager]:
        """Get PITR manager"""
        return self._pitr

    @property
    def pubsub(self) -> Optional[PubSubManager]:
        """Get publish-subscribe manager"""
        return self._pubsub

    @property
    def account(self) -> Optional[AccountManager]:
        """Get account manager"""
        return self._account

    @property
    def vector_ops(self) -> Optional["VectorManager"]:
        """Get unified vector operations manager for vector operations (index and data)"""
        return self._vector

    def get_pinecone_index(self, table_name_or_model, vector_column: str):
        """
        Get a PineconeCompatibleIndex object for vector search operations.

        This method creates a Pinecone-compatible vector search interface
        that automatically parses the table schema and vector index configuration.
        The primary key column is automatically detected, and all other columns
        except the vector column will be included as metadata.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            vector_column: Name of the vector column

        Returns::

            PineconeCompatibleIndex object with Pinecone-compatible API

        Example

            >>> index = client.get_pinecone_index("documents", "embedding")
            >>> index = client.get_pinecone_index(DocumentModel, "embedding")
            >>> results = index.query([0.1, 0.2, 0.3], top_k=5)
            >>> for match in results.matches:
            ...     print(f"ID: {match.id}, Score: {match.score}")
            ...     print(f"Metadata: {match.metadata}")
        """
        from .search_vector_index import PineconeCompatibleIndex

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        return PineconeCompatibleIndex(
            client=self,
            table_name=table_name,
            vector_column=vector_column,
        )

    @property
    def fulltext_index(self) -> Optional["FulltextIndexManager"]:
        """Get fulltext index manager for fulltext index operations"""
        return self._fulltext_index

    @property
    def metadata(self) -> Optional["MetadataManager"]:
        """Get metadata manager for table metadata operations"""
        return self._metadata

    # @property
    # def vector_query(self) -> Optional["VectorQueryManager"]:
    #     """Get vector query manager for vector query operations"""
    #     return self._vector_query
    # Removed - functionality moved to vector_ops

    def connected(self) -> bool:
        """Check if client is connected to database"""
        return self._engine is not None

    def version(self) -> str:
        """
        Get MatrixOne server version

        Returns::

            str: MatrixOne server version string

        Raises::

            ConnectionError: If not connected to MatrixOne
            QueryError: If version query fails

        Example

            >>> client = Client('localhost', 6001, 'root', '111', 'test')
            >>> version = client.version()
            >>> print(f"MatrixOne version: {version}")
        """
        if not self._engine:
            raise ConnectionError("Not connected to MatrixOne")

        try:
            result = self.execute("SELECT VERSION()")
            if result.rows:
                return result.rows[0][0]
            else:
                raise QueryError("Failed to get version information")
        except Exception as e:
            raise QueryError(f"Failed to get version: {e}")

    def git_version(self) -> str:
        """
        Get MatrixOne git version information

        Returns::

            str: MatrixOne git version string

        Raises::

            ConnectionError: If not connected to MatrixOne
            QueryError: If git version query fails

        Example

            >>> client = Client('localhost', 6001, 'root', '111', 'test')
            >>> git_version = client.git_version()
            >>> print(f"MatrixOne git version: {git_version}")
        """
        if not self._engine:
            raise ConnectionError("Not connected to MatrixOne")

        try:
            # Use MatrixOne's built-in git_version() function
            result = self.execute("SELECT git_version()")
            if result.rows:
                return result.rows[0][0]
            else:
                raise QueryError("Failed to get git version information")
        except Exception as e:
            raise QueryError(f"Failed to get git version: {e}")

    def _detect_backend_version(self) -> None:
        """
        Detect backend version and set it in version manager

        This method attempts to get the MatrixOne version from the backend
        and sets it in the version manager for compatibility checking.

        Handles two version formats:
        1. "8.0.30-MatrixOne-v" (development version, highest priority)
        2. "8.0.30-MatrixOne-v3.0.0" (release version)
        """
        try:
            # Try to get version using version() function
            result = self.execute("SELECT version()")
            if result.rows:
                version_string = result.rows[0][0]
                version = self._parse_matrixone_version(version_string)
                if version:
                    self.set_backend_version(version)
                    self.logger.info(f"Detected backend version: {version} (from: {version_string})")
                    return

            # Fallback: try git_version()
            result = self.execute("SELECT git_version()")
            if result.rows:
                git_version = result.rows[0][0]
                version = self._parse_matrixone_version(git_version)
                if version:
                    self.set_backend_version(version)
                    self.logger.info(f"Detected backend version from git: {version} (from: {git_version})")
                    return

        except Exception as e:
            self.logger.warning(f"Could not detect backend version: {e}")

    def _parse_matrixone_version(self, version_string: str) -> Optional[str]:
        """
        Parse MatrixOne version string to extract semantic version

        Handles formats:
        1. "8.0.30-MatrixOne-v" -> "999.0.0" (development version, highest)
        2. "8.0.30-MatrixOne-v3.0.0" -> "3.0.0" (release version)
        3. "MatrixOne 3.0.1" -> "3.0.1" (fallback format)

        Args::

            version_string: Raw version string from MatrixOne

        Returns::

            Semantic version string or None if parsing fails
        """
        import re

        if not version_string:
            return None

        # Pattern 1: Development version "8.0.30-MatrixOne-v" (nothing after the "v")
        dev_pattern = r"(\d+\.\d+\.\d+)-MatrixOne-v$"
        dev_match = re.search(dev_pattern, version_string.strip())
        if dev_match:
            # Development version - assign highest version number
            return "999.0.0"

        # Pattern 2: Release version "8.0.30-MatrixOne-v3.0.0" (version number after the "v")
        release_pattern = r"(\d+\.\d+\.\d+)-MatrixOne-v(\d+\.\d+\.\d+)"
        release_match = re.search(release_pattern, version_string.strip())
        if release_match:
            # Extract the semantic version part
            semantic_version = release_match.group(2)
            return semantic_version

        # Pattern 3: Fallback format "MatrixOne 3.0.1"
        fallback_pattern = r"(\d+\.\d+\.\d+)"
        fallback_match = re.search(fallback_pattern, version_string)
        if fallback_match:
            return fallback_match.group(1)

        self.logger.warning(f"Could not parse version string: {version_string}")
        return None

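The three regex branches map raw server strings to semantic versions as follows (illustrative values derived from the patterns above):

    # "8.0.30-MatrixOne-v"      -> "999.0.0"  (dev build, sorts above any release)
    # "8.0.30-MatrixOne-v3.0.0" -> "3.0.0"    (release build)
    # "MatrixOne 3.0.1"         -> "3.0.1"    (fallback: first x.y.z found)
    # "garbage"                 -> None       (logged as a warning)
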
    def set_backend_version(self, version: str) -> None:
        """
        Manually set the backend version

        Args::

            version: Version string in format "major.minor.patch" (e.g., "3.0.1")
        """
        self._version_manager.set_backend_version(version)
        self._backend_version = version
        self.logger.info(f"Backend version set to: {version}")

    def get_backend_version(self) -> Optional[str]:
        """
        Get current backend version

        Returns::

            Version string or None if not set
        """
        backend_version = self._version_manager.get_backend_version()
        return str(backend_version) if backend_version else None

    def is_feature_available(self, feature_name: str) -> bool:
        """
        Check if a feature is available in current backend version

        Args::

            feature_name: Name of the feature to check

        Returns::

            True if feature is available, False otherwise
        """
        return self._version_manager.is_feature_available(feature_name)

    def get_feature_info(self, feature_name: str) -> Optional[Dict[str, Any]]:
        """
        Get feature requirement information

        Args::

            feature_name: Name of the feature

        Returns::

            Feature information dictionary or None if not found
        """
        requirement = self._version_manager.get_feature_info(feature_name)
        if requirement:
            return {
                "feature_name": requirement.feature_name,
                "min_version": str(requirement.min_version) if requirement.min_version else None,
                "max_version": str(requirement.max_version) if requirement.max_version else None,
                "description": requirement.description,
                "alternative": requirement.alternative,
            }
        return None

    def check_version_compatibility(self, required_version: str, operator: str = ">=") -> bool:
        """
        Check if current backend version is compatible with required version

        Args::

            required_version: Required version string (e.g., "3.0.1")
            operator: Comparison operator (">=", ">", "<=", "<", "==", "!=")

        Returns::

            True if compatible, False otherwise
        """
        return self._version_manager.is_version_compatible(required_version, operator=operator)

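Taken together, these checks let calling code branch on backend capability before issuing version-sensitive SQL. A sketch; the feature name "fulltext_index" is a hypothetical registry key, not confirmed by this excerpt:

    if client.check_version_compatibility("3.0.0", ">="):
        pass  # safe to rely on 3.0+ behavior here
    if not client.is_feature_available("fulltext_index"):  # hypothetical feature name
        info = client.get_feature_info("fulltext_index")
        if info and info["alternative"]:
            print(f"Fallback: {info['alternative']}")
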
    def get_version_hint(self, feature_name: str, error_context: str = "") -> str:
        """
        Get helpful hint message for version-related errors

        Args::

            feature_name: Name of the feature
            error_context: Additional context for the error

        Returns::

            Helpful hint message
        """
        return self._version_manager.get_version_hint(feature_name, error_context)

    def is_development_version(self) -> bool:
        """
        Check if current backend is a development version

        Returns::

            True if backend is development version (999.x.x), False otherwise
        """
        return self._version_manager.is_development_version()

    def __enter__(self):
        return self

    def create_table(self, table_name_or_model, columns: dict = None, **kwargs) -> "Client":
        """
        Create a table with a simplified interface.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            columns: Dictionary mapping column names to their types (required if table_name_or_model is str)
                Supported formats:
                - 'id': 'bigint' (with primary_key=True if needed)
                - 'name': 'varchar(100)'
                - 'embedding': 'vecf32(128)' or 'vecf64(128)'
                - 'score': 'float'
                - 'created_at': 'datetime'
                - 'is_active': 'boolean'
            **kwargs: Additional table parameters

        Returns::

            Client: Self for chaining

        Example

            client.create_table("users", {
                'id': 'bigint',
                'name': 'varchar(100)',
                'email': 'varchar(255)',
                'embedding': 'vecf32(128)',
                'score': 'float',
                'created_at': 'datetime',
                'is_active': 'boolean'
            }, primary_key='id')
        """
        from .sqlalchemy_ext import VectorTableBuilder

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            # It's a model class
            model_class = table_name_or_model
            table_name = model_class.__tablename__
            table = model_class.__table__

            from .sqlalchemy_ext import FulltextIndex, VectorIndex
            from sqlalchemy.schema import CreateTable, CreateIndex

            with self.get_sqlalchemy_engine().begin() as conn:
                # Create table without indexes first
                # Build CREATE TABLE statement without indexes
                create_table_sql = str(CreateTable(table).compile(dialect=conn.dialect))

                # Helper to execute SQL with fallback for testing
                def _exec_sql(sql_str):
                    if hasattr(conn, 'exec_driver_sql'):
                        return conn.exec_driver_sql(sql_str)
                    else:
                        from sqlalchemy import text

                        return conn.execute(text(sql_str))

                # Execute CREATE TABLE with better error handling
                try:
                    _exec_sql(create_table_sql)
                    self.logger.info(f"✓ Created table '{table_name}'")
                except Exception as e:
                    # Log the error before processing
                    try:
                        self.logger.log_error(e, context=f"Creating table '{table_name}'")
                    except Exception as log_err:
                        import sys

                        print(f"Warning: Error logging failed: {log_err}", file=sys.stderr)

                    # Extract user-friendly error message
                    error_msg = str(e)

                    # Handle common errors with helpful messages
                    if 'already exists' in error_msg.lower() or '1050' in error_msg:
                        raise QueryError(
                            f"Table '{table_name}' already exists. "
                            f"Use client.drop_table({table_name_or_model.__name__}) to remove it first, "
                            f"or check if you need to reuse the existing table."
                        ) from None
                    elif 'duplicate' in error_msg.lower():
                        raise QueryError(
                            f"Duplicate key or index found when creating table '{table_name}'. "
                            f"Check your table definition for duplicate column or index names."
                        ) from None
                    elif 'syntax error' in error_msg.lower():
                        raise QueryError(f"SQL syntax error when creating table '{table_name}': {error_msg}") from None
                    else:
                        # Generic error with cleaner message
                        raise QueryError(f"Failed to create table '{table_name}': {error_msg}") from None

                # Create indexes separately with proper types
                for index in table.indexes:
                    if isinstance(index, FulltextIndex):
                        # Create fulltext index using custom DDL
                        columns_str = ", ".join(col.name for col in index.columns)
                        fulltext_sql = f"CREATE FULLTEXT INDEX {index.name} ON {table_name} ({columns_str})"
                        if hasattr(index, 'parser') and index.parser:
                            fulltext_sql += f" WITH PARSER {index.parser}"
                        try:
                            _exec_sql(fulltext_sql)
                        except Exception as e:
                            self.logger.warning(f"Failed to create fulltext index {index.name}: {e}")
                    elif isinstance(index, VectorIndex):
                        # Create vector index using custom method
                        try:
                            vector_sql = str(CreateIndex(index).compile(dialect=conn.dialect))
                            _exec_sql(vector_sql)
                        except Exception as e:
                            self.logger.warning(f"Failed to create vector index {index.name}: {e}")
                    else:
                        # Create regular index
                        try:
                            index_sql = str(CreateIndex(index).compile(dialect=conn.dialect))
                            _exec_sql(index_sql)
                        except Exception as e:
                            self.logger.warning(f"Failed to create index {index.name}: {e}")

            return self

        # It's a table name string
        table_name = table_name_or_model
        if columns is None:
            raise ValueError("columns parameter is required when table_name_or_model is a string")

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.lower().startswith("vecf32(") or column_def.lower().startswith("vecf64("):
                # Parse vecf32(64) or vecf64(64) format (case insensitive)
                import re

                match = re.match(r"vecf(\d+)\((\d+)\)", column_def.lower())
                if match:
                    precision = f"f{match.group(1)}"
                    dimension = int(match.group(2))
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vecf format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                import re

                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                import re

                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                import re

                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                import re

                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vecf32(n), vecf64(n)"
                )

        # Create table
        table = builder.build()
        table.create(self.get_sqlalchemy_engine())

        return self

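Since a duplicate table surfaces as a QueryError (see the 'already exists' branch above), the call can be made idempotent by dropping first. A minimal sketch, assuming a connected `client`; the QueryError import path is assumed from the package layout:

    from matrixone.exceptions import QueryError  # import path assumed

    client.drop_table("users")  # no-op if absent (DROP TABLE IF EXISTS)
    try:
        client.create_table("users", {'id': 'bigint', 'name': 'varchar(100)'}, primary_key='id')
    except QueryError as e:
        print(f"DDL failed: {e}")
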
    def create_table_in_transaction(self, table_name: str, columns: dict, connection, **kwargs) -> "Client":
        """
        Create a table with a simplified interface within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            columns: Dictionary mapping column names to their types (same format as create_table,
                except vector columns use the vector(n) / vector(n,f32) form)
            connection: SQLAlchemy connection object (required for transaction support)
            **kwargs: Additional table parameters

        Returns::

            Client: Self for chaining
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        from sqlalchemy.schema import CreateTable

        from .sqlalchemy_ext import VectorTableBuilder

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format (same logic as create_table)
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.startswith("vector("):
                # Parse vector type: vector(128,f32) or vector(128)
                import re

                match = re.match(r"vector\((\d+)(?:,(\w+))?\)", column_def)
                if match:
                    dimension = int(match.group(1))
                    precision = match.group(2) or "f32"
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vector format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                import re

                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                import re

                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                import re

                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                import re

                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vector(n), vector(n,f32)"
                )

        # Create table using the provided connection
        table = builder.build()
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=connection.dialect))
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self

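A sketch of transaction-scoped DDL, assuming a connected `client`; the connection comes from the client's own engine, so both statements commit or roll back together:

    engine = client.get_sqlalchemy_engine()
    with engine.begin() as conn:
        client.create_table_in_transaction(
            "events", {'id': 'bigint', 'payload': 'json'}, conn, primary_key='id'
        )
        client.drop_table_in_transaction("events_old", conn)  # defined below
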
    def drop_table(self, table_name_or_model) -> "Client":
        """
        Drop a table.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class

        Returns::

            Client: Self for chaining

        Example

            # Drop table by name
            client.drop_table("users")

            # Drop table by model class
            client.drop_table(UserModel)
        """
        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            # It's a model class
            table_name = table_name_or_model.__tablename__
        else:
            # It's a table name string
            table_name = table_name_or_model

        with self.get_sqlalchemy_engine().begin() as conn:
            sql = f"DROP TABLE IF EXISTS {table_name}"
            if hasattr(conn, 'exec_driver_sql'):
                conn.exec_driver_sql(sql)
            else:
                from sqlalchemy import text

                conn.execute(text(sql))

        return self

    def drop_table_in_transaction(self, table_name: str, connection) -> "Client":
        """
        Drop a table within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table to drop
            connection: SQLAlchemy connection object (required for transaction support)

        Returns::

            Client: Self for chaining
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        sql = f"DROP TABLE IF EXISTS {table_name}"
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self

    def create_table_with_index(self, table_name: str, columns: dict, indexes: list = None, **kwargs) -> "Client":
        """
        Create a table with vector indexes using a simplified interface.

        Args::

            table_name: Name of the table
            columns: Dictionary mapping column names to their types (same format as create_table,
                except vector columns use the vector(n) / vector(n,f32) form)
            indexes: List of index definitions, each containing:
                - 'name': Index name
                - 'column': Column name to index
                - 'type': Index type ('ivfflat' or 'hnsw')
                - 'params': Dictionary of index-specific parameters
            **kwargs: Additional table parameters

        Returns::

            Client: Self for chaining

        Example

            client.create_table_with_index("vector_docs", {
                'id': 'bigint',
                'title': 'varchar(200)',
                'embedding': 'vector(128,f32)'
            }, indexes=[
                {
                    'name': 'idx_hnsw',
                    'column': 'embedding',
                    'type': 'hnsw',
                    'params': {'m': 48, 'ef_construction': 64, 'ef_search': 64}
                }
            ], primary_key='id')
        """

        from .sqlalchemy_ext import VectorTableBuilder

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format (same logic as create_table)
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.startswith("vector("):
                # Parse vector type: vector(128,f32) or vector(128)
                import re

                match = re.match(r"vector\((\d+)(?:,(\w+))?\)", column_def)
                if match:
                    dimension = int(match.group(1))
                    precision = match.group(2) or "f32"
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vector format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                import re

                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                import re

                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                import re

                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                import re

                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vector(n), vector(n,f32)"
                )

        # Create table
        table = builder.build()
        table.create(self.get_sqlalchemy_engine())

        # Create indexes if specified
        if indexes:
            for index_def in indexes:
                index_name = index_def["name"]
                column_name = index_def["column"]
                index_type = index_def["type"]
                params = index_def.get("params", {})

                # Convert index type and enable indexing
                if index_type == "hnsw":
                    # Enable HNSW indexing
                    self.vector_index.enable_hnsw()
                elif index_type == "ivfflat":
                    # Enable IVF indexing
                    self.vector_index.enable_ivf()
                else:
                    raise ValueError(f"Unsupported index type: {index_type}")

                # Create the index using separated APIs
                if index_type == "ivfflat":
                    self.vector_index.create_ivf(table_name=table_name, name=index_name, column=column_name, **params)
                elif index_type == "hnsw":
                    self.vector_index.create_hnsw(table_name=table_name, name=index_name, column=column_name, **params)
                else:
                    raise ValueError(f"Unsupported index type: {index_type}")

        return self

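An ivfflat counterpart to the docstring's HNSW example; the 'lists' parameter name is an assumption about what create_ivf accepts and may differ in the released API:

    client.create_table_with_index("vector_docs_ivf", {
        'id': 'bigint',
        'embedding': 'vector(128,f32)'
    }, indexes=[
        {'name': 'idx_ivf', 'column': 'embedding', 'type': 'ivfflat',
         'params': {'lists': 100}}  # 'lists' is a hypothetical parameter name
    ], primary_key='id')
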
    def create_table_with_index_in_transaction(
        self, table_name: str, columns: dict, connection, indexes: list = None, **kwargs
    ) -> "Client":
        """
        Create a table with vector indexes within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            columns: Dictionary mapping column names to their types (same format as create_table,
                except vector columns use the vector(n) / vector(n,f32) form)
            connection: SQLAlchemy connection object (required for transaction support)
            indexes: List of index definitions (same format as create_table_with_index)
            **kwargs: Additional table parameters

        Returns::

            Client: Self for chaining
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        from sqlalchemy.schema import CreateTable

        from .sqlalchemy_ext import VectorTableBuilder

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format (same logic as create_table)
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.startswith("vector("):
                # Parse vector type: vector(128,f32) or vector(128)
                import re

                match = re.match(r"vector\((\d+)(?:,(\w+))?\)", column_def)
                if match:
                    dimension = int(match.group(1))
                    precision = match.group(2) or "f32"
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vector format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                import re

                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                import re

                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                import re

                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                import re

                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vector(n), vector(n,f32)"
                )

        # Create table using the provided connection
        table = builder.build()
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=connection.dialect))
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        # Create indexes if specified
        if indexes:
            for index_def in indexes:
                index_name = index_def["name"]
                column_name = index_def["column"]
                index_type = index_def["type"]
                params = index_def.get("params", {})

                # Create the index using separated APIs
                if index_type == "ivfflat":
                    self.vector_index.create_ivf(table_name=table_name, name=index_name, column=column_name, **params)
                elif index_type == "hnsw":
                    self.vector_index.create_hnsw(table_name=table_name, name=index_name, column=column_name, **params)
                else:
                    raise ValueError(f"Unsupported index type: {index_type}")

        return self

    def create_table_orm(self, table_name: str, *columns, **kwargs) -> "Client":
        """
        Create a table using SQLAlchemy ORM-style column definitions.
        Similar to SQLAlchemy Table() constructor but without metadata.

        Args::

            table_name: Name of the table
            *columns: SQLAlchemy Column objects and Index objects (including VectorIndex)
            **kwargs: Additional parameters (like enable_hnsw, enable_ivf)

        Returns::

            Client: Self for chaining

        Example::

            from sqlalchemy import Column, BigInteger, Integer
            from matrixone.sqlalchemy_ext import Vectorf32, VectorIndex, VectorIndexType, VectorOpType

            client.create_table_orm(
                'vector_docs_hnsw_demo',
                Column('a', BigInteger, primary_key=True),
                Column('b', Vectorf32(128)),
                Column('c', Integer),
                VectorIndex('idx_hnsw', 'b', index_type=VectorIndexType.HNSW,
                            m=48, ef_construction=64, ef_search=64,
                            op_type=VectorOpType.VECTOR_L2_OPS)
            )
        """
        from sqlalchemy import MetaData, Table

        # Create metadata and table
        metadata = MetaData()
        table = Table(table_name, metadata, *columns)

        # Check if we need to enable HNSW or IVF indexing
        enable_hnsw = kwargs.get("enable_hnsw", False)
        enable_ivf = kwargs.get("enable_ivf", False)

        # Check if table has vector indexes that need special handling
        has_hnsw_index = False
        has_ivf_index = False

        for item in table.indexes:
            if hasattr(item, "index_type"):
                # Check for HNSW index type (string comparison)
                if str(item.index_type).lower() == "hnsw":
                    has_hnsw_index = True
                elif str(item.index_type).lower() == "ivfflat":
                    has_ivf_index = True

        # Create table using SQLAlchemy engine with proper session handling
        engine = self.get_sqlalchemy_engine()

        # Enable appropriate indexing if needed and create table in same session
        if has_hnsw_index or enable_hnsw:
            with engine.begin() as conn:
                from .sqlalchemy_ext import create_hnsw_config

                hnsw_config = create_hnsw_config(self._engine)
                hnsw_config.enable_hnsw_indexing(conn)
                # Create table and indexes in the same session
                table.create(conn)
        elif has_ivf_index or enable_ivf:
            with engine.begin() as conn:
                from .sqlalchemy_ext import create_ivf_config

                ivf_config = create_ivf_config(self._engine)
                ivf_config.enable_ivf_indexing()
                # Create table and indexes in the same session
                table.create(conn)
        else:
            # No special indexing needed, create normally
            table.create(engine)

        return self

    def create_table_orm_in_transaction(self, table_name: str, connection, *columns, **kwargs) -> "Client":
        """
        Create a table using SQLAlchemy ORM-style definitions within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            connection: SQLAlchemy connection object (required for transaction support)
            *columns: SQLAlchemy Column objects and Index objects (including VectorIndex)
            **kwargs: Additional parameters (like enable_hnsw, enable_ivf)

        Returns::

            Client: Self for chaining
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        from sqlalchemy import MetaData, Table
        from sqlalchemy.schema import CreateTable

        # Create metadata and table
        metadata = MetaData()
        table = Table(table_name, metadata, *columns)

        # Check if we need to enable HNSW or IVF indexing
        enable_hnsw = kwargs.get("enable_hnsw", False)
        enable_ivf = kwargs.get("enable_ivf", False)

        # Check if table has vector indexes that need special handling
        has_hnsw_index = False
        has_ivf_index = False

        for item in table.indexes:
            if hasattr(item, "index_type"):
                # Check for HNSW index type (string comparison)
                if str(item.index_type).lower() == "hnsw":
                    has_hnsw_index = True
                elif str(item.index_type).lower() == "ivfflat":
                    has_ivf_index = True

        # Enable appropriate indexing if needed (within transaction)
        if has_hnsw_index or enable_hnsw:
            from .sqlalchemy_ext import create_hnsw_config

            hnsw_config = create_hnsw_config(self._engine)
            hnsw_config.enable_hnsw_indexing(connection)
        if has_ivf_index or enable_ivf:
            from .sqlalchemy_ext import create_ivf_config

            ivf_config = create_ivf_config(self._engine)
            ivf_config.enable_ivf_indexing()

        # Create table using the provided connection
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=connection.dialect))
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self

    def create_all(self, base_class=None):
        """
        Create all tables defined in the given base class or default Base.

        Args::

            base_class: SQLAlchemy declarative base class. If None, uses the default Base.
        """
        if base_class is None:
            from matrixone.orm import declarative_base

            base_class = declarative_base()

        base_class.metadata.create_all(self._engine)
        return self

    def drop_all(self, base_class=None):
        """
        Drop all tables defined in the given base class or default Base.

        Args::

            base_class: SQLAlchemy declarative base class. If None, uses the default Base.
        """
        if base_class is None:
            from matrixone.orm import declarative_base

            base_class = declarative_base()

        # Get all table names from the metadata
        table_names = list(base_class.metadata.tables.keys())

        # Drop each table individually using direct SQL for better compatibility
        for table_name in table_names:
            try:
                self.execute(f"DROP TABLE IF EXISTS {table_name}")
            except Exception as e:
                # Log the error but continue with other tables
                print(f"Warning: Failed to drop table {table_name}: {e}")

        return self

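A round-trip sketch with a declarative model; the model definition is illustrative, and the declarative_base import path follows the matrixone.orm usage shown above:

    from sqlalchemy import Column, Integer, String
    from matrixone.orm import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'demo_users'  # hypothetical table
        id = Column(Integer, primary_key=True)
        name = Column(String(100))

    client.create_all(Base)  # CREATE TABLE for every model registered on this Base
    client.drop_all(Base)    # DROP TABLE IF EXISTS for each; errors are logged, not raised
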
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disconnect()


class ResultSet:
    """
    Result set wrapper for query results from MatrixOne database operations.

    This class provides a convenient interface for accessing query results
    with methods similar to database cursor objects. It supports both
    SELECT queries (returning data) and DML operations (returning affected row counts).

    Key Features:

    - Iterator interface for row-by-row access
    - Bulk data access methods (fetchall, fetchmany)
    - Column name access and metadata
    - Affected row count for DML operations
    - Cursor-like positioning for result navigation

    Attributes::

        columns (List[str]): List of column names in the result set
        rows (List[Tuple[Any, ...]]): List of tuples containing row data
        affected_rows (int): Number of rows affected by DML operations

    Usage Examples:

        # SELECT query results
        >>> result = client.execute("SELECT id, name, age FROM users WHERE age > ?", (25,))
        >>> print(f"Found {len(result.rows)} users")
        >>> for row in result.fetchall():
        ...     print(f"ID: {row[0]}, Name: {row[1]}, Age: {row[2]}")

        # Access by column name
        >>> for row in result.rows:
        ...     user_id = row[result.columns.index('id')]
        ...     user_name = row[result.columns.index('name')]

        # DML operation results
        >>> result = client.execute("INSERT INTO users (name, age) VALUES (?, ?)", ("John", 30))
        >>> print(f"Inserted {result.affected_rows} rows")

        # Iterator interface
        >>> for row in result:
        ...     print(row)

    Note: This class is automatically created by the Client's execute() method
    and provides a consistent interface for all query results.
    """

    def __init__(self, columns: List[str], rows: List[Tuple[Any, ...]], affected_rows: int = 0):
        self.columns = columns
        self.rows = rows
        self.affected_rows = affected_rows
        self._cursor = 0  # Track current position in result set

    def fetchall(self) -> List[Tuple[Any, ...]]:
        """Fetch all remaining rows"""
        remaining_rows = self.rows[self._cursor :]
        self._cursor = len(self.rows)
        return remaining_rows

    def fetchone(self) -> Optional[Tuple[Any, ...]]:
        """Fetch one row"""
        if self._cursor < len(self.rows):
            row = self.rows[self._cursor]
            self._cursor += 1
            return row
        return None

    def fetchmany(self, size: int = 1) -> List[Tuple[Any, ...]]:
        """Fetch many rows"""
        start = self._cursor
        end = min(start + size, len(self.rows))
        rows = self.rows[start:end]
        self._cursor = end
        return rows

    def scalar(self) -> Any:
        """Get scalar value (first column of first row)"""
        if self.rows and self.columns:
            return self.rows[0][0]
        return None

    def keys(self):
        """Get column names"""
        return iter(self.columns)

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)


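Because fetchall, fetchone, and fetchmany share the internal `_cursor`, a ResultSet can be drained in fixed-size batches. A standalone sketch using the constructor directly:

    rs = ResultSet(columns=['id'], rows=[(1,), (2,), (3,), (4,), (5,)])
    while True:
        batch = rs.fetchmany(2)
        if not batch:
            break
        print(batch)  # [(1,), (2,)], then [(3,), (4,)], then [(5,)]
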
class SnapshotClient:
    """Snapshot client wrapper for executing queries with snapshot"""

    def __init__(self, client, snapshot_name: str):
        self.client = client
        self.snapshot_name = snapshot_name

    def execute(self, sql: str, params: Optional[Tuple] = None) -> ResultSet:
        """Execute SQL with snapshot"""
        # Insert snapshot hint after the first table name in FROM clause
        import re

        # Find the first table name after FROM and insert snapshot hint
        pattern = r"(\bFROM\s+)(\w+)(\s|$)"

        def replace_func(match):
            return f"{match.group(1)}{match.group(2)}{{snapshot = '{self.snapshot_name}'}}{match.group(3)}"

        snapshot_sql = re.sub(pattern, replace_func, sql, count=1)

        # Handle parameter substitution for MatrixOne compatibility
        final_sql = self.client._substitute_parameters(snapshot_sql, params)

        return self.client.execute(final_sql)


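The regex rewrite above turns an ordinary query into MatrixOne's snapshot-read syntax; only the first table after FROM receives the hint. For a SnapshotClient named "daily_backup":

    # Input:  SELECT * FROM users WHERE id > 10
    # Output: SELECT * FROM users{snapshot = 'daily_backup'} WHERE id > 10
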
class TransactionWrapper:
    """
    Transaction wrapper for executing queries within a MatrixOne transaction.

    This class provides a transaction context for executing multiple database
    operations atomically. It wraps a SQLAlchemy connection and provides
    access to all MatrixOne managers (snapshots, clone, restore, PITR, etc.)
    within the transaction context.

    Key Features:

    - Atomic transaction execution with automatic rollback on errors
    - Access to all MatrixOne managers within transaction context
    - SQLAlchemy session integration
    - Automatic commit/rollback handling
    - Support for nested transactions

    Available Managers:
    - snapshots: TransactionSnapshotManager for snapshot operations
    - clone: TransactionCloneManager for clone operations
    - restore: TransactionRestoreManager for restore operations
    - pitr: TransactionPitrManager for point-in-time recovery
    - pubsub: TransactionPubSubManager for pub/sub operations
    - account: TransactionAccountManager for account operations
    - vector: TransactionVectorManager for vector operations
    - fulltext_index: TransactionFulltextIndexManager for fulltext operations

    Usage Examples

    .. code-block:: python

        # Basic transaction usage
        with client.transaction() as tx:
            tx.execute("INSERT INTO users (name) VALUES (?)", ("John",))
            tx.execute("INSERT INTO orders (user_id, amount) VALUES (?, ?)", (1, 100.0))
            # Transaction commits automatically on success

        # Using managers within transaction
        with client.transaction() as tx:
            # Create snapshot within transaction
            tx.snapshots.create("backup", SnapshotLevel.DATABASE, database="mydb")

            # Clone database within transaction
            tx.clone.clone_database("new_db", "source_db")

            # Vector operations within transaction
            tx.vector.create_table("vectors", {"id": "int", "embedding": "vector(384,f32)"})

        # SQLAlchemy session integration
        with client.transaction() as tx:
            session = tx.get_sqlalchemy_session()
            user = User(name="John")
            session.add(user)
            session.commit()

    Note: This class is automatically created by the Client's transaction()
    context manager and should not be instantiated directly.
    """

    def __init__(self, connection, client):
        self.connection = connection
        self.client = client
        # Create snapshot, clone, restore, PITR, pubsub, account, and vector managers that use this transaction
        self.snapshots = TransactionSnapshotManager(client, self)
        self.clone = TransactionCloneManager(client, self)
        self.restore = TransactionRestoreManager(client, self)
        self.pitr = TransactionPitrManager(client, self)
        self.pubsub = TransactionPubSubManager(client, self)
        self.account = TransactionAccountManager(self)
        self.vector_ops = TransactionVectorIndexManager(client, self)
        self.fulltext_index = TransactionFulltextIndexManager(client, self)
        self.metadata = TransactionMetadataManager(client, self)
        # SQLAlchemy integration
        self._sqlalchemy_session = None

    def execute(self, sql: str, params: Optional[Tuple] = None) -> ResultSet:
        """Execute SQL within transaction"""
        import time

        start_time = time.time()

        try:
            # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
|
+
# This prevents JSON strings like {"a":1} from being parsed as :1 bind params
|
2788
|
+
if hasattr(self.connection, 'exec_driver_sql'):
|
2789
|
+
# Escape % to %% for pymysql's format string handling
|
2790
|
+
escaped_sql = sql.replace('%', '%%')
|
2791
|
+
result = self.connection.exec_driver_sql(escaped_sql)
|
2792
|
+
else:
|
2793
|
+
# Fallback for testing or older SQLAlchemy versions
|
2794
|
+
from sqlalchemy import text
|
2795
|
+
|
2796
|
+
result = self.connection.execute(text(sql), params or {})
|
2797
|
+
execution_time = time.time() - start_time
|
2798
|
+
|
2799
|
+
if result.returns_rows:
|
2800
|
+
columns = list(result.keys())
|
2801
|
+
rows = result.fetchall()
|
2802
|
+
self.client.logger.log_query(sql, execution_time, len(rows), success=True)
|
2803
|
+
return ResultSet(columns, rows)
|
2804
|
+
else:
|
2805
|
+
self.client.logger.log_query(sql, execution_time, result.rowcount, success=True)
|
2806
|
+
return ResultSet([], [], affected_rows=result.rowcount)
|
2807
|
+
|
2808
|
+
except Exception as e:
|
2809
|
+
execution_time = time.time() - start_time
|
2810
|
+
self.client.logger.log_query(sql, execution_time, success=False)
|
2811
|
+
self.client.logger.log_error(e, context="Transaction query execution")
|
2812
|
+
raise QueryError(f"Transaction query execution failed: {e}")
|
2813
|
+
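
    # Behavior sketch (illustrative, not in the original source): raw SQL is
    # routed through exec_driver_sql(), so any literal '%' in the statement is
    # doubled first, per the escaping above:
    #
    #     tx.execute("SELECT * FROM users WHERE name LIKE '%john%'")
    #     # the string handed to the driver contains '%%john%%'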

    def get_connection(self):
        """
        Get the underlying SQLAlchemy connection for direct use

        Returns::

            SQLAlchemy Connection instance bound to this transaction
        """
        return self.connection

    def get_sqlalchemy_session(self):
        """
        Get SQLAlchemy session that uses the same transaction

        Returns::

            SQLAlchemy Session instance bound to this transaction
        """
        if self._sqlalchemy_session is None:
            from sqlalchemy.orm import sessionmaker

            # Create session factory using the client's engine
            Session = sessionmaker(bind=self.client._engine)
            self._sqlalchemy_session = Session(bind=self.connection)

        return self._sqlalchemy_session

    def commit_sqlalchemy(self) -> None:
        """Commit SQLAlchemy session"""
        if self._sqlalchemy_session:
            self._sqlalchemy_session.commit()

    def rollback_sqlalchemy(self) -> None:
        """Rollback SQLAlchemy session"""
        if self._sqlalchemy_session:
            self._sqlalchemy_session.rollback()

    def close_sqlalchemy(self) -> None:
        """Close SQLAlchemy session"""
        if self._sqlalchemy_session:
            self._sqlalchemy_session.close()
            self._sqlalchemy_session = None
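
    # Lifecycle sketch (illustrative, not in the original source): the session
    # is created lazily and is bound to this transaction's connection, so its
    # work commits or rolls back with the surrounding transaction:
    #
    #     with client.transaction() as tx:
    #         session = tx.get_sqlalchemy_session()
    #         session.add(User(name="John"))
    #         tx.commit_sqlalchemy()   # flush session work onto the tx connection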

    def insert(self, table_name: str, data: dict[str, Any]) -> ResultSet:
        """
        Insert data into a table within transaction.

        Args::

            table_name: Name of the table
            data: Data to insert (dict with column names as keys)

        Returns::

            ResultSet object
        """
        sql = self.client._build_insert_sql(table_name, data)
        return self.execute(sql)

    def batch_insert(self, table_name: str, data_list: list[dict[str, Any]]) -> ResultSet:
        """
        Batch insert data into a table within transaction.

        Args::

            table_name: Name of the table
            data_list: List of data dictionaries to insert

        Returns::

            ResultSet object
        """
        if not data_list:
            return ResultSet([], [], affected_rows=0)

        sql = self.client._build_batch_insert_sql(table_name, data_list)
        return self.execute(sql)
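
    # SQL-shape sketch (illustrative, not in the original source; assumes
    # _build_batch_insert_sql renders one multi-row INSERT):
    #
    #     tx.batch_insert("users", [{"name": "a"}, {"name": "b"}])
    #     # -> INSERT INTO users (name) VALUES ('a'), ('b')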

    def query(self, *columns, snapshot: str = None):
        """Get MatrixOne query builder within transaction - SQLAlchemy style

        Args::

            *columns: Can be:
                - Single model class: query(Article) - returns all columns from model
                - Multiple columns: query(Article.id, Article.title) - returns specific columns
                - Mixed: query(Article, Article.id, some_expression.label('alias')) - model + additional columns
            snapshot: Optional snapshot name for snapshot queries

        Returns::

            MatrixOneQuery instance configured for the specified columns within transaction
        """
        from .orm import MatrixOneQuery

        if len(columns) == 1:
            # Traditional single model class usage
            column = columns[0]
            if isinstance(column, str):
                # String table name
                return MatrixOneQuery(column, self.client, transaction_wrapper=self, snapshot=snapshot)
            elif hasattr(column, '__tablename__'):
                # This is a model class
                return MatrixOneQuery(column, self.client, transaction_wrapper=self, snapshot=snapshot)
            elif hasattr(column, 'name') and hasattr(column, 'as_sql'):
                # This is a CTE object
                from .orm import CTE

                if isinstance(column, CTE):
                    query = MatrixOneQuery(None, self.client, transaction_wrapper=self, snapshot=snapshot)
                    query._table_name = column.name
                    query._select_columns = ["*"]  # Default to select all from CTE
                    query._ctes = [column]  # Add the CTE to the query
                    return query
            else:
                # This is a single column/expression - need to handle specially
                query = MatrixOneQuery(None, self.client, transaction_wrapper=self, snapshot=snapshot)
                query._select_columns = [column]
                # Try to infer table name from column
                if hasattr(column, 'table') and hasattr(column.table, 'name'):
                    query._table_name = column.table.name
                return query
        else:
            # Multiple columns/expressions
            model_class = None
            select_columns = []

            for column in columns:
                if hasattr(column, '__tablename__'):
                    # This is a model class - use its table
                    model_class = column
                else:
                    # This is a column or expression
                    select_columns.append(column)

            if model_class:
                query = MatrixOneQuery(model_class, self.client, transaction_wrapper=self, snapshot=snapshot)
                if select_columns:
                    # Add additional columns to the model's default columns
                    query._select_columns = select_columns
                return query
            else:
                # No model class provided, need to infer table from columns
                query = MatrixOneQuery(None, self.client, transaction_wrapper=self, snapshot=snapshot)
                query._select_columns = select_columns

                # Try to infer table name from first column that has table info
                for col in select_columns:
                    if hasattr(col, 'table') and hasattr(col.table, 'name'):
                        query._table_name = col.table.name
                        break
                    elif isinstance(col, str) and '.' in col:
                        # String column like "table.column" - extract table name
                        parts = col.split('.')
                        if len(parts) >= 2:
                            # For "db.table.column" format, use "db.table"
                            # For "table.column" format, use "table"
                            table_name = '.'.join(parts[:-1])
                            query._table_name = table_name
                            break

                return query
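
    # Dispatch sketch (illustrative, not in the original source): query() routes
    # on argument shape, per the branches above:
    #
    #     tx.query(Article)                      # model class -> all mapped columns
    #     tx.query(Article.id, Article.title)    # columns -> explicit select list
    #     tx.query("db1.users.name")             # string -> table inferred as "db1.users"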

    def create_table(self, table_name: str, columns: dict, **kwargs) -> "TransactionWrapper":
        """
        Create a table within MatrixOne transaction.

        Args::

            table_name: Name of the table
            columns: Dictionary mapping column names to their types (same format as client.create_table)
            **kwargs: Additional table parameters

        Returns::

            TransactionWrapper: Self for chaining
        """
        import re

        from sqlalchemy.schema import CreateTable

        from .sqlalchemy_ext import VectorTableBuilder

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format (same logic as client.create_table)
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.startswith("vector("):
                # Parse vector type: vector(128,f32) or vector(128)
                match = re.match(r"vector\((\d+)(?:,(\w+))?\)", column_def)
                if match:
                    dimension = int(match.group(1))
                    precision = match.group(2) or "f32"
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vector format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vector(n), vector(n,f32|f64)"
                )

        # Create table using transaction wrapper's execute method
        table = builder.build()
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=self.client.get_sqlalchemy_engine().dialect))
        self.execute(sql)

        return self
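
    # Type mini-language sketch (illustrative, not in the original source),
    # matching the parser above:
    #
    #     tx.create_table(
    #         "documents",
    #         {"id": "int", "title": "varchar(200)", "embedding": "vector(384,f32)"},
    #         primary_key="id",
    #     )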

    def drop_table(self, table_name: str) -> "TransactionWrapper":
        """
        Drop a table within MatrixOne transaction.

        Args::

            table_name: Name of the table to drop

        Returns::

            TransactionWrapper: Self for chaining
        """
        sql = f"DROP TABLE IF EXISTS {table_name}"
        self.execute(sql)
        return self

    def create_table_with_index(
        self, table_name: str, columns: dict, indexes: list = None, **kwargs
    ) -> "TransactionWrapper":
        """
        Create a table with vector indexes within MatrixOne transaction.

        Args::

            table_name: Name of the table
            columns: Dictionary mapping column names to their types (same format as client.create_table)
            indexes: List of index definitions (same format as client.create_table_with_index)
            **kwargs: Additional table parameters

        Returns::

            TransactionWrapper: Self for chaining
        """
        import re

        from sqlalchemy.schema import CreateTable

        from .sqlalchemy_ext import VectorTableBuilder

        # Parse primary key from kwargs
        primary_key = kwargs.get("primary_key", None)

        # Create table using VectorTableBuilder
        builder = VectorTableBuilder(table_name)

        # Add columns based on simplified format (same logic as client.create_table)
        for column_name, column_def in columns.items():
            is_primary = primary_key == column_name

            if column_def.startswith("vector("):
                # Parse vector type: vector(128,f32) or vector(128)
                match = re.match(r"vector\((\d+)(?:,(\w+))?\)", column_def)
                if match:
                    dimension = int(match.group(1))
                    precision = match.group(2) or "f32"
                    builder.add_vector_column(column_name, dimension, precision)
                else:
                    raise ValueError(f"Invalid vector format: {column_def}")

            elif column_def.startswith("varchar("):
                # Parse varchar type: varchar(100)
                match = re.match(r"varchar\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid varchar format: {column_def}")

            elif column_def.startswith("char("):
                # Parse char type: char(10)
                match = re.match(r"char\((\d+)\)", column_def)
                if match:
                    length = int(match.group(1))
                    builder.add_string_column(column_name, length)
                else:
                    raise ValueError(f"Invalid char format: {column_def}")

            elif column_def.startswith("decimal("):
                # Parse decimal type: decimal(10,2)
                match = re.match(r"decimal\((\d+),(\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    scale = int(match.group(2))
                    builder.add_numeric_column(column_name, "decimal", precision, scale)
                else:
                    raise ValueError(f"Invalid decimal format: {column_def}")

            elif column_def.startswith("float("):
                # Parse float type: float(10)
                match = re.match(r"float\((\d+)\)", column_def)
                if match:
                    precision = int(match.group(1))
                    builder.add_numeric_column(column_name, "float", precision)
                else:
                    raise ValueError(f"Invalid float format: {column_def}")

            elif column_def in ("int", "integer"):
                builder.add_int_column(column_name, primary_key=is_primary)
            elif column_def in ("bigint", "bigint unsigned"):
                builder.add_bigint_column(column_name, primary_key=is_primary)
            elif column_def in ("smallint", "tinyint"):
                if column_def == "smallint":
                    builder.add_smallint_column(column_name, primary_key=is_primary)
                else:
                    builder.add_tinyint_column(column_name, primary_key=is_primary)
            elif column_def in ("text", "longtext", "mediumtext", "tinytext"):
                builder.add_text_column(column_name)
            elif column_def in ("float", "double"):
                builder.add_numeric_column(column_name, column_def)
            elif column_def in ("date", "datetime", "timestamp", "time"):
                builder.add_datetime_column(column_name, column_def)
            elif column_def in ("boolean", "bool"):
                builder.add_boolean_column(column_name)
            elif column_def in ("json", "jsonb"):
                builder.add_json_column(column_name)
            elif column_def in (
                "blob",
                "longblob",
                "mediumblob",
                "tinyblob",
                "binary",
                "varbinary",
            ):
                builder.add_binary_column(column_name, column_def)
            else:
                raise ValueError(
                    f"Unsupported column type '{column_def}' for column '{column_name}'. "
                    f"Supported types: int, bigint, smallint, tinyint, varchar(n), char(n), "
                    f"text, float, double, decimal(p,s), date, datetime, timestamp, time, "
                    f"boolean, json, blob, vector(n), vector(n,f32|f64)"
                )

        # Create table using transaction wrapper's execute method
        table = builder.build()
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=self.client.get_sqlalchemy_engine().dialect))
        self.execute(sql)

        # Create indexes if specified
        if indexes:
            for index_def in indexes:
                index_name = index_def["name"]
                column_name = index_def["column"]
                index_type = index_def["type"]
                params = index_def.get("params", {})

                # Create the index using the transaction wrapper's vector_ops with separated APIs
                if index_type == "ivfflat":
                    self.vector_ops.create_ivf(table_name=table_name, name=index_name, column=column_name, **params)
                elif index_type == "hnsw":
                    self.vector_ops.create_hnsw(table_name=table_name, name=index_name, column=column_name, **params)
                else:
                    raise ValueError(f"Unsupported index type: {index_type}")

        return self
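
    # Index-spec sketch (illustrative, not in the original source): each entry
    # names the index, the column, the type, and optional type-specific params:
    #
    #     tx.create_table_with_index(
    #         "documents",
    #         {"id": "int", "embedding": "vector(128,f32)"},
    #         indexes=[{"name": "idx_emb", "column": "embedding",
    #                   "type": "ivfflat", "params": {"lists": 64}}],
    #         primary_key="id",
    #     )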

    def create_table_orm(self, table_name: str, *columns, **kwargs) -> "TransactionWrapper":
        """
        Create a table using SQLAlchemy ORM-style definitions within MatrixOne transaction.

        Args::

            table_name: Name of the table
            *columns: SQLAlchemy Column objects and Index objects (including VectorIndex)
            **kwargs: Additional parameters (like enable_hnsw, enable_ivf)

        Returns::

            TransactionWrapper: Self for chaining
        """
        from sqlalchemy import MetaData, Table
        from sqlalchemy.schema import CreateTable

        # Create metadata and table
        metadata = MetaData()
        table = Table(table_name, metadata, *columns)

        # Check if we need to enable HNSW or IVF indexing
        enable_hnsw = kwargs.get("enable_hnsw", False)
        enable_ivf = kwargs.get("enable_ivf", False)

        # Check if table has vector indexes that need special handling
        has_hnsw_index = False
        has_ivf_index = False

        for item in table.indexes:
            if hasattr(item, "index_type"):
                # Check for HNSW index type (string comparison)
                if str(item.index_type).lower() == "hnsw":
                    has_hnsw_index = True
                elif str(item.index_type).lower() == "ivfflat":
                    has_ivf_index = True

        # Enable appropriate indexing if needed (within transaction)
        if has_hnsw_index or enable_hnsw:
            from .sqlalchemy_ext import create_hnsw_config

            hnsw_config = create_hnsw_config(self.client._engine)
            hnsw_config.enable_hnsw_indexing()
        if has_ivf_index or enable_ivf:
            from .sqlalchemy_ext import create_ivf_config

            ivf_config = create_ivf_config(self.client._engine)
            ivf_config.enable_ivf_indexing()

        # Create table using transaction wrapper's execute method
        create_sql = CreateTable(table)
        sql = str(create_sql.compile(dialect=self.client.get_sqlalchemy_engine().dialect))
        self.execute(sql)

        return self
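

# Usage sketch (illustrative, not part of the original module): create_table_orm
# accepts plain SQLAlchemy constructs; the table and columns below are made up.
def _example_create_table_orm(client) -> None:
    from sqlalchemy import Column, Integer, String

    with client.transaction() as tx:
        tx.create_table_orm(
            "articles_orm",
            Column("id", Integer, primary_key=True),
            Column("title", String(200)),
        )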


class TransactionSnapshotManager(SnapshotManager):
    """Snapshot manager that executes operations within a transaction"""

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    def create(
        self,
        name: str,
        level: Union[str, SnapshotLevel],
        database: Optional[str] = None,
        table: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Snapshot:
        """Create snapshot within transaction"""
        return super().create(name, level, database, table, description, self.transaction_wrapper)

    def get(self, name: str) -> Snapshot:
        """Get snapshot within transaction"""
        return super().get(name, self.transaction_wrapper)

    def delete(self, name: str) -> None:
        """Delete snapshot within transaction"""
        return super().delete(name, self.transaction_wrapper)


class TransactionCloneManager(CloneManager):
    """Clone manager that executes operations within a transaction"""

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    def clone_database(
        self,
        target_db: str,
        source_db: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone database within transaction"""
        return super().clone_database(target_db, source_db, snapshot_name, if_not_exists, self.transaction_wrapper)

    def clone_table(
        self,
        target_table: str,
        source_table: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone table within transaction"""
        return super().clone_table(target_table, source_table, snapshot_name, if_not_exists, self.transaction_wrapper)

    def clone_database_with_snapshot(
        self, target_db: str, source_db: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Clone database with snapshot within transaction"""
        return super().clone_database_with_snapshot(
            target_db, source_db, snapshot_name, if_not_exists, self.transaction_wrapper
        )

    def clone_table_with_snapshot(
        self, target_table: str, source_table: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Clone table with snapshot within transaction"""
        return super().clone_table_with_snapshot(
            target_table, source_table, snapshot_name, if_not_exists, self.transaction_wrapper
        )
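

# Usage sketch (illustrative, not part of the original module): these managers
# route their SQL through the wrapping transaction, so snapshot creation and a
# snapshot-based clone commit or roll back together; the names are made up.
def _example_transactional_snapshot_clone(client) -> None:
    with client.transaction() as tx:
        tx.snapshots.create("nightly", SnapshotLevel.DATABASE, database="mydb")
        tx.clone.clone_database_with_snapshot("mydb_copy", "mydb", "nightly")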


class VectorManager:
    """
    Unified vector manager for MatrixOne vector operations and chain operations.

    This class provides comprehensive vector functionality including vector table
    creation, vector indexing, vector data operations, and vector similarity search.
    It supports both IVF (Inverted File) and HNSW (Hierarchical Navigable Small World)
    indexing algorithms for efficient vector similarity search.

    Key Features:

    - Vector table creation with configurable dimensions and precision
    - Vector index creation and management (IVF, HNSW)
    - Vector data insertion and batch operations
    - Vector similarity search with multiple distance metrics
    - Vector range search for distance-based filtering
    - Integration with MatrixOne's vector capabilities
    - Support for both f32 and f64 vector precision

    Supported Index Types:

    - IVF (Inverted File): Good for large datasets, requires training
    - HNSW: Good for high-dimensional vectors, no training required

    Supported Distance Metrics:

    - L2 (Euclidean) distance: Standard Euclidean distance
    - Cosine similarity: Cosine of the angle between vectors
    - Inner product: Dot product of vectors

    Supported Operations:

    - Vector table creation with various column types
    - Vector index creation with configurable parameters
    - Vector data insertion and batch operations
    - Vector similarity search and distance calculations
    - Vector range search for distance-based filtering
    - Vector index management and optimization

    Usage Examples::

        # Initialize vector manager
        vector_ops = client.vector_ops

        # Create IVF index
        vector_ops.create_ivf(
            table_name="documents",
            name="idx_embedding_ivf",
            column="embedding",
            lists=100
        )

        # Create HNSW index
        vector_ops.create_hnsw(
            table_name="documents",
            name="idx_embedding_hnsw",
            column="embedding",
            m=16,
            ef_construction=200
        )

        # Similarity search
        results = vector_ops.similarity_search(
            table_name="documents",
            vector_column="embedding",
            query_vector=[0.1, 0.2, 0.3, ...],
            limit=10
        )

    Note: Vector operations require appropriate vector data and indexing strategies.
    Vector dimensions and precision must match your embedding model requirements.
    """

    def __init__(self, client):
        self.client = client

    def _get_ivf_index_table_names(
        self,
        database: str,
        table_name: str,
        column_name: str,
        connection: Connection,
    ) -> Dict[str, str]:
        """
        Get the table names of the IVF index tables.
        """
        sql = (
            f"SELECT i.algo_table_type, i.index_table_name "
            f"FROM `mo_catalog`.`mo_indexes` AS i "
            f"JOIN `mo_catalog`.`mo_tables` AS t ON i.table_id = t.rel_id "
            f"AND i.column_name = '{column_name}' "
            f"AND t.relname = '{table_name}' "
            f"AND t.reldatabase = '{database}' "
            f"AND i.algo='ivfflat'"
        )
        result = self.client._execute_with_logging(connection, sql, context="Get IVF index table names")
        # +-----------------+-----------------------------------------------------------+
        # | algo_table_type | index_table_name                                          |
        # +-----------------+-----------------------------------------------------------+
        # | metadata        | __mo_index_secondary_01999b6b-414c-71dc-ab7f-8399ab06cb64 |
        # | centroids       | __mo_index_secondary_01999b6b-414c-7311-98d2-9de6a0f591ee |
        # | entries         | __mo_index_secondary_01999b6b-414c-7324-90f0-695d791d574f |
        # +-----------------+-----------------------------------------------------------+
        return {row[0]: row[1] for row in result}

    def _get_ivf_buckets_distribution(
        self,
        database: str,
        table_name: str,
        connection: Connection,
    ) -> Dict[str, List[int]]:
        """
        Get the buckets distribution of the IVF index tables.
        """
        sql = (
            f"SELECT "
            f"  COUNT(*) AS centroid_count, "
            f"  __mo_index_centroid_fk_id AS centroid_id, "
            f"  __mo_index_centroid_fk_version AS centroid_version "
            f"FROM `{database}`.`{table_name}` "
            f"GROUP BY `__mo_index_centroid_fk_id`, `__mo_index_centroid_fk_version`"
        )
        result = self.client._execute_with_logging(connection, sql, context="Get IVF buckets distribution")
        rows = result.fetchall()
        # +----------------+-------------+------------------+
        # | centroid_count | centroid_id | centroid_version |
        # +----------------+-------------+------------------+
        # |             51 |           0 |                0 |
        # |             32 |           1 |                0 |
        # |             62 |           2 |                0 |
        # |             40 |           3 |                0 |
        # |             60 |           4 |                0 |
        # +----------------+-------------+------------------+

        # Output:
        # {
        #     "centroid_count": [51, 32, 62, 40, 60],
        #     "centroid_id": [0, 1, 2, 3, 4],
        #     "centroid_version": [0, 0, 0, 0, 0]
        # }

        return {
            "centroid_count": [row[0] for row in rows],
            "centroid_id": [row[1] for row in rows],
            "centroid_version": [row[2] for row in rows],
        }

    def create_ivf(
        self,
        table_name_or_model,
        name: str,
        column: str,
        lists: int = 100,
        op_type: VectorOpType = None,
    ) -> "VectorManager":
        """
        Create an IVFFLAT vector index using chain operations.

        IVFFLAT (Inverted File with Flat Compression) is a vector index type
        that provides good performance for similarity search on large datasets.
        It supports insert, update, and delete operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Name of the index
            column: Vector column to index
            lists: Number of lists for IVFFLAT (default: 100). Use more lists for larger datasets
            op_type: Vector operation type (VectorOpType enum, default: VectorOpType.VECTOR_L2_OPS)

        Returns::

            VectorManager: Self for chaining

        Example::

            # Create IVF index by table name
            client.vector_ops.create_ivf("documents", "idx_embedding", "embedding", lists=50)

            # Create IVF index by model class
            client.vector_ops.create_ivf(DocumentModel, "idx_embedding", "embedding", lists=100)
        """
        from .sqlalchemy_ext import IVFVectorIndex, VectorOpType

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Use default if not provided
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        success = IVFVectorIndex.create_index(
            engine=self.client.get_sqlalchemy_engine(),
            table_name=table_name,
            name=name,
            column=column,
            lists=lists,
            op_type=op_type,
        )

        if not success:
            raise Exception(f"Failed to create IVFFLAT vector index {name} on table {table_name}")

        return self

    def create_hnsw(
        self,
        table_name_or_model,
        name: str,
        column: str,
        m: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
        op_type: VectorOpType = None,
    ) -> "VectorManager":
        """
        Create an HNSW vector index using chain operations.

        HNSW (Hierarchical Navigable Small World) is a vector index type
        that provides excellent search performance but is read-only.
        It does not support insert, update, or delete operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Name of the index
            column: Vector column to index
            m: Number of bi-directional links for HNSW (default: 16)
            ef_construction: Size of dynamic candidate list for HNSW construction (default: 200)
            ef_search: Size of dynamic candidate list for HNSW search (default: 50)
            op_type: Vector operation type (VectorOpType enum, default: VectorOpType.VECTOR_L2_OPS)

        Returns::

            VectorManager: Self for chaining

        Example::

            # Create HNSW index by table name
            client.vector_ops.create_hnsw("documents", "idx_embedding", "embedding", m=32)

            # Create HNSW index by model class
            client.vector_ops.create_hnsw(DocumentModel, "idx_embedding", "embedding", m=16)
        """
        from .sqlalchemy_ext import HnswVectorIndex, VectorOpType

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Use default if not provided
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        success = HnswVectorIndex.create_index(
            engine=self.client.get_sqlalchemy_engine(),
            table_name=table_name,
            name=name,
            column=column,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            op_type=op_type,
        )

        if not success:
            raise Exception(f"Failed to create HNSW vector index {name} on table {table_name}")

        return self

    def create_ivf_in_transaction(
        self,
        table_name: str,
        name: str,
        column: str,
        connection,
        lists: int = 100,
        op_type: str = "vector_l2_ops",
    ) -> "VectorManager":
        """
        Create an IVFFLAT vector index within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            name: Name of the index
            column: Vector column to index
            connection: SQLAlchemy connection object (required for transaction support)
            lists: Number of lists for IVFFLAT (default: 100)
            op_type: Vector operation type (default: vector_l2_ops)

        Returns::

            VectorManager: Self for chaining

        Raises::

            ValueError: If connection is not provided
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        # Enable IVF indexing if needed
        from .sqlalchemy_ext import create_ivf_config

        ivf_config = create_ivf_config(self.client._engine)
        ivf_config.enable_ivf_indexing()
        ivf_config.set_probe_limit(1)

        # Build CREATE INDEX statement
        sql = f"CREATE INDEX {name} USING ivfflat ON {table_name}({column}) LISTS {lists} op_type '{op_type}'"

        # Create the index
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self

    def create_hnsw_in_transaction(
        self,
        table_name: str,
        name: str,
        column: str,
        connection,
        m: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
        op_type: str = "vector_l2_ops",
    ) -> "VectorManager":
        """
        Create an HNSW vector index within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            name: Name of the index
            column: Vector column to index
            connection: SQLAlchemy connection object (required for transaction support)
            m: Number of bi-directional links for HNSW (default: 16)
            ef_construction: Size of dynamic candidate list for HNSW construction (default: 200)
            ef_search: Size of dynamic candidate list for HNSW search (default: 50)
            op_type: Vector operation type (default: vector_l2_ops)

        Returns::

            VectorManager: Self for chaining

        Raises::

            ValueError: If connection is not provided
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        # Enable HNSW indexing if needed
        from .sqlalchemy_ext import create_hnsw_config

        hnsw_config = create_hnsw_config(self.client._engine)
        hnsw_config.enable_hnsw_indexing(connection)

        # Build CREATE INDEX statement
        sql = (
            f"CREATE INDEX {name} USING hnsw ON {table_name}({column}) "
            f"M {m} EF_CONSTRUCTION {ef_construction} EF_SEARCH {ef_search} "
            f"op_type '{op_type}'"
        )

        # Create the index
        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self
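
    # Usage sketch (illustrative, not in the original source): the
    # *_in_transaction variants take an open SQLAlchemy connection so index DDL
    # joins an outer transaction:
    #
    #     with client.get_sqlalchemy_engine().begin() as conn:
    #         client.vector_ops.create_ivf_in_transaction(
    #             "documents", "idx_embedding", "embedding", conn, lists=100
    #         )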

    def drop(self, table_name_or_model, name: str) -> "VectorManager":
        """
        Drop a vector index using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Name of the index to drop

        Returns::

            VectorManager: Self for chaining
        """
        from .sqlalchemy_ext import VectorIndex

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        success = VectorIndex.drop_index(engine=self.client.get_sqlalchemy_engine(), table_name=table_name, name=name)

        if not success:
            raise Exception(f"Failed to drop vector index {name} from table {table_name}")

        return self

    def enable_ivf(self, probe_limit: int = 1) -> "VectorManager":
        """
        Enable IVF indexing with chain operations.

        Args::

            probe_limit: Probe limit for IVF search

        Returns::

            VectorManager: Self for chaining
        """
        from .sqlalchemy_ext import create_ivf_config

        ivf_config = create_ivf_config(self.client.get_sqlalchemy_engine())
        if not ivf_config.is_ivf_supported():
            raise Exception("IVF indexing is not supported in this MatrixOne version")

        if not ivf_config.enable_ivf_indexing():
            raise Exception("Failed to enable IVF indexing")

        if not ivf_config.set_probe_limit(probe_limit):
            raise Exception("Failed to set probe limit")

        return self

    def disable_ivf(self) -> "VectorManager":
        """
        Disable IVF indexing with chain operations.

        Returns::

            VectorManager: Self for chaining
        """
        from .sqlalchemy_ext import create_ivf_config

        ivf_config = create_ivf_config(self.client.get_sqlalchemy_engine())
        if not ivf_config.disable_ivf_indexing():
            raise Exception("Failed to disable IVF indexing")

        return self

    def enable_hnsw(self) -> "VectorManager":
        """
        Enable HNSW indexing with chain operations.

        Returns::

            VectorManager: Self for chaining
        """
        from .sqlalchemy_ext import create_hnsw_config

        hnsw_config = create_hnsw_config(self.client.get_sqlalchemy_engine())
        if not hnsw_config.enable_hnsw_indexing():
            raise Exception("Failed to enable HNSW indexing")

        return self

    def disable_hnsw(self) -> "VectorManager":
        """
        Disable HNSW indexing with chain operations.

        Returns::

            VectorManager: Self for chaining
        """
        from .sqlalchemy_ext import create_hnsw_config

        hnsw_config = create_hnsw_config(self.client.get_sqlalchemy_engine())
        if not hnsw_config.disable_hnsw_indexing():
            raise Exception("Failed to disable HNSW indexing")

        return self

    # Data operations
    def insert(self, table_name: str, data: dict) -> "VectorManager":
        """
        Insert vector data using chain operations.

        Args::

            table_name: Name of the table
            data: Data to insert (dict with column names as keys)

        Returns::

            VectorManager: Self for chaining
        """
        self.client.insert(table_name, data)
        return self

    def insert_in_transaction(self, table_name: str, data: dict, connection) -> "VectorManager":
        """
        Insert vector data within an existing SQLAlchemy transaction.

        Args::

            table_name: Name of the table
            data: Data to insert (dict with column names as keys)
            connection: SQLAlchemy connection object (required for transaction support)

        Returns::

            VectorManager: Self for chaining

        Raises::

            ValueError: If connection is not provided
        """
        if connection is None:
            raise ValueError("connection parameter is required for transaction operations")

        # Build INSERT statement
        columns = list(data.keys())
        values = list(data.values())

        # Convert vectors to string format
        formatted_values = []
        for value in values:
            if isinstance(value, list):
                formatted_values.append("[" + ",".join(map(str, value)) + "]")
            else:
                # Escape single quotes
                formatted_values.append(str(value).replace("'", "''"))

        columns_str = ", ".join(columns)
        values_str = ", ".join([f"'{v}'" for v in formatted_values])

        sql = f"INSERT INTO {table_name} ({columns_str}) VALUES ({values_str})"

        if hasattr(connection, 'exec_driver_sql'):
            connection.exec_driver_sql(sql)
        else:
            from sqlalchemy import text

            connection.execute(text(sql))

        return self
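
    # Formatting sketch (illustrative, not in the original source): list values
    # are serialized as MatrixOne vector literals, and every value is quoted:
    #
    #     {"id": 1, "embedding": [0.1, 0.2]}
    #     # -> INSERT INTO t (id, embedding) VALUES ('1', '[0.1,0.2]')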

    def batch_insert(self, table_name: str, data_list: list) -> "VectorManager":
        """
        Batch insert vector data using chain operations.

        Args::

            table_name: Name of the table
            data_list: List of data dictionaries to insert

        Returns::

            VectorManager: Self for chaining
        """
        self.client.batch_insert(table_name, data_list)
        return self

    def similarity_search(
        self,
        table_name_or_model,
        vector_column: str,
        query_vector: list,
        limit: int = 10,
        distance_type: str = "l2",
        select_columns: list = None,
        where_conditions: list = None,
        where_params: list = None,
        connection=None,
        _log_mode: str = None,
    ) -> list:
        """
        Perform similarity search using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            vector_column: Name of the vector column
            query_vector: Query vector as list
            limit: Number of results to return
            distance_type: Type of distance calculation (l2, cosine, inner_product)
            select_columns: List of columns to select (None means all columns)
            where_conditions: List of WHERE conditions
            where_params: List of parameters for WHERE conditions
            connection: Optional existing database connection (for transaction support)
            _log_mode: Override SQL logging mode for this operation ('off', 'auto', 'simple', 'full')

        Returns::

            List of search results

        Example::

            # Basic similarity search
            results = client.vector_ops.similarity_search(
                "documents", "embedding", [0.1, 0.2, 0.3], limit=5
            )

            # Search with filtering
            results = client.vector_ops.similarity_search(
                "documents", "embedding", [0.1, 0.2, 0.3], limit=5,
                where_conditions=["category = ?"], where_params=["AI"]
            )
        """
        from .sql_builder import DistanceFunction, build_vector_similarity_query

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Convert distance type to enum
        if distance_type == "l2":
            distance_func = DistanceFunction.L2_SQ
        elif distance_type == "cosine":
            distance_func = DistanceFunction.COSINE
        elif distance_type == "inner_product":
            distance_func = DistanceFunction.INNER_PRODUCT
        else:
            raise ValueError(f"Unsupported distance type: {distance_type}")

        # Build query using unified SQL builder
        sql = build_vector_similarity_query(
            table_name=table_name,
            vector_column=vector_column,
            query_vector=query_vector,
            distance_func=distance_func,
            limit=limit,
            select_columns=select_columns,
            where_conditions=where_conditions,
            where_params=where_params,
        )

        if connection is not None:
            # Use existing connection (for transaction support)
            result = self.client._execute_with_logging(
                connection, sql, context="Vector similarity search", override_sql_log_mode=_log_mode
            )
            return result.fetchall()
        else:
            # Create new connection
            with self.client.get_sqlalchemy_engine().begin() as conn:
                result = self.client._execute_with_logging(
                    conn, sql, context="Vector similarity search", override_sql_log_mode=_log_mode
                )
                return result.fetchall()

    def range_search(
        self,
        table_name: str,
        vector_column: str,
        query_vector: list,
        max_distance: float,
        distance_type: str = "l2",
        select_columns: list = None,
        connection=None,
    ) -> list:
        """
        Perform range search using chain operations.

        Args::

            table_name: Name of the table
            vector_column: Name of the vector column
            query_vector: Query vector as list
            max_distance: Maximum distance threshold
            distance_type: Type of distance calculation
            select_columns: List of columns to select (None means all columns)
            connection: Optional existing database connection (for transaction support)

        Returns::

            List of search results within range
        """
        # Convert vector to string format
        vector_str = "[" + ",".join(map(str, query_vector)) + "]"

        # Build distance function based on type
        if distance_type == "l2":
            distance_func = "l2_distance"
        elif distance_type == "cosine":
            distance_func = "cosine_distance"
        elif distance_type == "inner_product":
            distance_func = "inner_product"
        else:
            raise ValueError(f"Unsupported distance type: {distance_type}")

        # Build SELECT clause
        if select_columns is None:
            select_clause = "*"
        else:
            # Ensure vector_column is included for distance calculation
            columns_to_select = list(select_columns)
            if vector_column not in columns_to_select:
                columns_to_select.append(vector_column)
            select_clause = ", ".join(columns_to_select)

        # Build SQL query
        sql = f"""
        SELECT {select_clause}, {distance_func}({vector_column}, '{vector_str}') as distance
        FROM {table_name}
        WHERE {distance_func}({vector_column}, '{vector_str}') <= {max_distance}
        ORDER BY distance
        """

        if connection is not None:
            # Use existing connection (for transaction support)
            result = self.client._execute_with_logging(connection, sql, context="Vector range search")
            return result.fetchall()
        else:
            # Create new connection
            with self.client.get_sqlalchemy_engine().begin() as conn:
                result = self.client._execute_with_logging(conn, sql, context="Vector range search")
                return result.fetchall()
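
    # Usage sketch (illustrative, not in the original source): range_search
    # returns every row within max_distance, nearest first, with a computed
    # `distance` column appended:
    #
    #     rows = client.vector_ops.range_search(
    #         "documents", "embedding", [0.1, 0.2, 0.3], max_distance=0.5
    #     )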
|
4123
|
+
    def get_ivf_stats(self, table_name_or_model, column_name: str = None) -> Dict[str, Any]:
        """
        Get IVF index statistics for monitoring and optimization.

        This method provides critical insights into IVF index health and performance.
        It helps evaluate whether the current IVF index configuration is optimal
        and whether the index needs to be rebuilt.

        Key Use Cases:

        - **Index Health Monitoring**: Check if centroid count matches expected lists parameter
        - **Load Balancing Analysis**: Evaluate if vectors are evenly distributed across centroids
        - **Performance Optimization**: Identify when to rebuild the index for better performance
        - **Capacity Planning**: Understand data distribution patterns

        Critical Metrics to Monitor:

        - **Centroid Count**: Should match the 'lists' parameter used during index creation
        - **Load Distribution**: Each centroid should have roughly equal numbers of vectors
        - **Centroid Versions**: Should be consistent (usually all 0 for stable indexes)

        When to Rebuild Index:

        - Centroid count doesn't match the expected lists parameter
        - Significant imbalance in centroid load distribution (>2x difference between min/max)
        - Performance degradation in similarity search queries
        - After major data changes (bulk inserts, updates, deletes)

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            column_name: Name of the vector column (optional, will be inferred if not provided)

        Returns::

            Dict containing IVF index statistics including:
            - index_tables: Dictionary mapping table types to table names
            - distribution: Dictionary containing:
                - centroid_count: List of row counts per centroid
                - centroid_id: List of centroid identifiers
                - centroid_version: List of centroid versions
            - database: Database name
            - table_name: Table name
            - column_name: Vector column name

        Raises::

            Exception: If IVF index is not found or if there are errors retrieving stats

        Examples::

            # Monitor index health
            stats = client.vector_ops.get_ivf_stats("documents", "embedding")

            # Check centroid distribution
            centroid_counts = stats['distribution']['centroid_count']
            total_centroids = len(centroid_counts)
            min_count = min(centroid_counts)
            max_count = max(centroid_counts)

            print(f"Total centroids: {total_centroids}")
            print(f"Min vectors per centroid: {min_count}")
            print(f"Max vectors per centroid: {max_count}")
            print(f"Load balance ratio: {max_count/min_count:.2f}")

            # Check if index needs rebuilding
            expected_centroids = 100  # Original lists parameter
            if total_centroids != expected_centroids:
                print(f"⚠️ Centroid count mismatch! Expected: {expected_centroids}, Actual: {total_centroids}")

            if max_count / min_count > 2.0:
                print("⚠️ Poor load balance! Consider rebuilding index.")

            # Get stats using model class
            stats = client.vector_ops.get_ivf_stats(MyModel, "vector_col")
        """
        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Get database name from connection params
        database = self.client._connection_params.get('database')
        if not database:
            raise Exception("No database connection found. Please connect to a database first.")

        # If column_name is not provided, try to infer it
        if not column_name:
            # Query the table schema to find vector columns
            with self.client.get_sqlalchemy_engine().begin() as conn:
                schema_sql = (
                    f"SELECT column_name, data_type "
                    f"FROM information_schema.columns "
                    f"WHERE table_schema = '{database}' "
                    f"AND table_name = '{table_name}' "
                    f"AND (data_type LIKE '%VEC%' OR data_type LIKE '%vec%')"
                )
                result = self.client._execute_with_logging(conn, schema_sql, context="Auto-detect vector column")
                vector_columns = result.fetchall()

                if not vector_columns:
                    raise Exception(f"No vector columns found in table {table_name}")
                elif len(vector_columns) == 1:
                    column_name = vector_columns[0][0]
                else:
                    # Multiple vector columns found, raise error asking user to specify
                    column_names = [col[0] for col in vector_columns]
                    raise Exception(
                        f"Multiple vector columns found in table {table_name}: {column_names}. "
                        f"Please specify the column_name parameter."
                    )

        # Get IVF index table names
        with self.client.get_sqlalchemy_engine().begin() as conn:
            index_tables = self._get_ivf_index_table_names(database, table_name, column_name, conn)

            if not index_tables:
                raise Exception(f"No IVF index found for table {table_name}, column {column_name}")

            # Get the entries table name for distribution analysis
            entries_table = index_tables.get('entries')
            if not entries_table:
                raise Exception("No entries table found in IVF index")

            # Get bucket distribution
            distribution = self._get_ivf_buckets_distribution(database, entries_table, conn)

            return {
                'index_tables': index_tables,
                'distribution': distribution,
                'database': database,
                'table_name': table_name,
                'column_name': column_name,
            }

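The rebuild heuristics spelled out in the docstring above can be wrapped in a small helper. This is an illustrative sketch only (the needs_rebuild name is hypothetical), built purely on the documented return shape of get_ivf_stats:

    def needs_rebuild(stats: dict, expected_lists: int, max_skew: float = 2.0) -> bool:
        """Hypothetical helper applying the rebuild heuristics documented above."""
        counts = stats['distribution']['centroid_count']
        if len(counts) != expected_lists:
            return True  # centroid count drifted from the original 'lists' parameter
        # >2x min/max imbalance is the docstring's threshold for poor load balance
        return min(counts) > 0 and max(counts) / min(counts) > max_skew

    # stats = client.vector_ops.get_ivf_stats("documents", "embedding")
    # if needs_rebuild(stats, expected_lists=100):
    #     ...  # drop and recreate the IVF index
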
class TransactionVectorIndexManager(VectorManager):
    """Vector index manager that executes operations within a transaction"""

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    def execute(self, sql: str, params: Optional[Tuple] = None) -> ResultSet:
        """Execute SQL within transaction"""
        return self.transaction_wrapper.execute(sql, params)

    def create_ivf(
        self,
        table_name: str,
        name: str,
        column: str,
        lists: int = 100,
        op_type: VectorOpType = None,
    ) -> "TransactionVectorIndexManager":
        """Create an IVFFLAT vector index within transaction"""
        from .sqlalchemy_ext import VectorIndex, VectorIndexType, VectorOpType

        # Use default if not provided
        index_type = VectorIndexType.IVFFLAT
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        # Create index using transaction wrapper's execute method
        index = VectorIndex(name, column, index_type, lists, op_type)

        try:
            # Enable IVF indexing within transaction
            self.transaction_wrapper.execute("SET experimental_ivf_index = 1")
            self.transaction_wrapper.execute("SET probe_limit = 1")

            sql = index.create_sql(table_name)
            self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create IVFFLAT vector index {name} on table {table_name} in transaction: {e}")

    def create_hnsw(
        self,
        table_name: str,
        name: str,
        column: str,
        m: int = 16,
        ef_construction: int = 200,
        ef_search: int = 50,
        op_type: VectorOpType = None,
    ) -> "TransactionVectorIndexManager":
        """Create an HNSW vector index within transaction"""
        from .sqlalchemy_ext import VectorIndex, VectorIndexType, VectorOpType

        # Use default if not provided
        index_type = VectorIndexType.HNSW
        if op_type is None:
            op_type = VectorOpType.VECTOR_L2_OPS

        # Create index using transaction wrapper's execute method
        index = VectorIndex(name, column, index_type, None, op_type, m, ef_construction, ef_search)

        try:
            # Enable HNSW indexing within transaction
            self.transaction_wrapper.execute("SET experimental_hnsw_index = 1")

            sql = index.create_sql(table_name)
            self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create HNSW vector index {name} on table {table_name} in transaction: {e}")

    def drop(self, table_name: str, name: str) -> "TransactionVectorIndexManager":
        """Drop a vector index within transaction"""
        # Drop index using transaction wrapper's execute method
        sql = f"DROP INDEX {name} ON {table_name}"

        try:
            self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to drop vector index {name} from table {table_name} in transaction: {e}")

    def get_ivf_stats(self, table_name_or_model, column_name: str = None) -> Dict[str, Any]:
        """
        Get IVF index statistics for a table within transaction.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            column_name: Name of the vector column (optional, will be inferred if not provided)

        Returns::

            Dict containing IVF index statistics including:
            - index_tables: Dictionary mapping table types to table names
            - distribution: Dictionary containing bucket distribution data
            - database: Database name
            - table_name: Table name
            - column_name: Vector column name

        Raises::

            Exception: If IVF index is not found or if there are errors retrieving stats

        Examples::

            # Get stats for a table with vector column within transaction
            with client.transaction() as tx:
                stats = tx.vector_ops.get_ivf_stats("my_table", "embedding")
                print(f"Index tables: {stats['index_tables']}")
                print(f"Distribution: {stats['distribution']}")
        """
        from sqlalchemy import text

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Get database name from connection params
        database = self.client._connection_params.get('database')
        if not database:
            raise Exception("No database connection found. Please connect to a database first.")

        # If column_name is not provided, try to infer it
        if not column_name:
            # Query the table schema to find vector columns using transaction connection
            schema_sql = text(
                f"""
                SELECT column_name, data_type
                FROM information_schema.columns
                WHERE table_schema = '{database}'
                AND table_name = '{table_name}'
                AND (data_type LIKE '%VEC%' OR data_type LIKE '%vec%')
                """
            )
            result = self.transaction_wrapper.execute(schema_sql)
            vector_columns = result.fetchall()

            if not vector_columns:
                raise Exception(f"No vector columns found in table {table_name}")
            elif len(vector_columns) == 1:
                column_name = vector_columns[0][0]
            else:
                # Multiple vector columns found, raise error asking user to specify
                column_names = [col[0] for col in vector_columns]
                raise Exception(
                    f"Multiple vector columns found in table {table_name}: {column_names}. "
                    f"Please specify the column_name parameter."
                )

        # Get connection from transaction wrapper
        connection = self.transaction_wrapper.get_connection()

        # Get IVF index table names
        index_tables = self._get_ivf_index_table_names(database, table_name, column_name, connection)

        if not index_tables:
            raise Exception(f"No IVF index found for table {table_name}, column {column_name}")

        # Get the entries table name for distribution analysis
        entries_table = index_tables.get('entries')
        if not entries_table:
            raise Exception("No entries table found in IVF index")

        # Get bucket distribution
        distribution = self._get_ivf_buckets_distribution(database, entries_table, connection)

        return {
            'index_tables': index_tables,
            'distribution': distribution,
            'database': database,
            'table_name': table_name,
            'column_name': column_name,
        }

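A hedged sketch of how the transactional manager above is meant to be used, assuming client.transaction() yields a wrapper exposing vector_ops (the same access path the get_ivf_stats docstring shows); the table and index names are hypothetical:

    # Create an IVFFLAT index and inspect it atomically; if anything raises,
    # the surrounding transaction rolls back.
    with client.transaction() as tx:
        tx.vector_ops.create_ivf("documents", name="idx_embedding", column="embedding", lists=100)
        stats = tx.vector_ops.get_ivf_stats("documents", "embedding")
        print(stats['index_tables'])
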
class FulltextIndexManager:
    """
    Fulltext index manager for MatrixOne fulltext search operations.

    This class provides comprehensive fulltext indexing functionality for
    enabling fast text search capabilities in MatrixOne databases. It supports
    various fulltext algorithms and provides chain operations for efficient
    index management.

    Key Features:

    - Fulltext index creation and management
    - Support for multiple fulltext algorithms (TF-IDF, BM25)
    - Multi-column fulltext indexing
    - Index optimization and maintenance
    - Integration with MatrixOne's fulltext search capabilities
    - Chain operations for efficient index management

    Supported Algorithms:

    - TF-IDF: Term Frequency-Inverse Document Frequency (default)
    - BM25: Best Matching 25 algorithm for improved relevance scoring

    Supported Operations:

    - Create fulltext indexes on single or multiple columns
    - Drop fulltext indexes
    - List and query existing fulltext indexes
    - Index optimization and maintenance
    - Integration with fulltext search queries

    Usage Examples::

        # Initialize fulltext index manager
        fulltext = client.fulltext_index

        # Create fulltext index on single column
        fulltext.create(
            "documents",
            name="idx_content",
            columns="content",
            algorithm="BM25"
        )

        # Create fulltext index on multiple columns
        fulltext.create(
            "articles",
            name="idx_title_content",
            columns=["title", "content"],
            algorithm="TF-IDF"
        )

    Note: Fulltext indexes improve text search performance but require additional storage space.
    """

    def __init__(self, client: "Client"):
        """Initialize fulltext index manager"""
        self.client = client

    def create(
        self, table_name_or_model, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
    ) -> "FulltextIndexManager":
        """
        Create a fulltext index using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Index name
            columns: Column(s) to index
            algorithm: Fulltext algorithm type (TF-IDF or BM25)

        Returns::

            FulltextIndexManager: Self for chaining

        Examples::

            # Create fulltext index by table name
            client.fulltext_index.create("articles", "idx_content", ["title", "content"])

            # Create fulltext index by model class
            client.fulltext_index.create(ArticleModel, "idx_content", ["title", "content"])

            # Create with BM25 algorithm
            client.fulltext_index.create("articles", "idx_bm25", "content", algorithm="BM25")
        """
        from .sqlalchemy_ext import FulltextIndex

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        success = FulltextIndex.create_index(
            engine=self.client.get_sqlalchemy_engine(),
            table_name=table_name,
            name=name,
            columns=columns,
            algorithm=algorithm,
        )

        if not success:
            raise Exception(f"Failed to create fulltext index {name} on table {table_name}")

        return self

    def create_in_transaction(
        self,
        transaction_wrapper,
        table_name: str,
        name: str,
        columns: Union[str, List[str]],
        algorithm: str = "TF-IDF",
    ) -> "FulltextIndexManager":
        """
        Create a fulltext index within an existing transaction.

        Args::

            transaction_wrapper: Transaction wrapper
            table_name: Target table name
            name: Index name
            columns: Column(s) to index
            algorithm: Fulltext algorithm type

        Returns::

            FulltextIndexManager: Self for chaining
        """
        from .sqlalchemy_ext import FulltextIndex

        success = FulltextIndex.create_index_in_transaction(
            connection=transaction_wrapper.connection,
            table_name=table_name,
            name=name,
            columns=columns,
            algorithm=algorithm,
        )

        if not success:
            raise Exception(f"Failed to create fulltext index {name} on table {table_name} in transaction")

        return self

    def drop(self, table_name_or_model, name: str) -> "FulltextIndexManager":
        """
        Drop a fulltext index using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Index name

        Returns::

            FulltextIndexManager: Self for chaining
        """
        from .sqlalchemy_ext import FulltextIndex

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        success = FulltextIndex.drop_index(engine=self.client.get_sqlalchemy_engine(), table_name=table_name, name=name)

        if not success:
            raise Exception(f"Failed to drop fulltext index {name} from table {table_name}")

        return self

    def enable_fulltext(self) -> "FulltextIndexManager":
        """
        Enable fulltext indexing with chain operations.

        Returns::

            FulltextIndexManager: Self for chaining
        """
        try:
            self.client.execute("SET experimental_fulltext_index = 1")
            return self
        except Exception as e:
            raise Exception(f"Failed to enable fulltext indexing: {e}")

    def disable_fulltext(self) -> "FulltextIndexManager":
        """
        Disable fulltext indexing with chain operations.

        Returns::

            FulltextIndexManager: Self for chaining
        """
        try:
            self.client.execute("SET experimental_fulltext_index = 0")
            return self
        except Exception as e:
            raise Exception(f"Failed to disable fulltext indexing: {e}")

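A short end-to-end sketch of the chain operations above (table and index names are hypothetical; enable_fulltext runs first because the index type is gated behind the experimental_fulltext_index setting, and both calls chain because each returns self):

    # Enable the experimental feature, then create a BM25 index in one chain.
    client.fulltext_index.enable_fulltext().create(
        "articles",
        name="idx_content",
        columns=["title", "content"],
        algorithm="BM25",
    )
    # ... run fulltext queries against the index ...
    client.fulltext_index.drop("articles", "idx_content")
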
class TransactionFulltextIndexManager(FulltextIndexManager):
    """Fulltext index manager that executes operations within a transaction"""

    def __init__(self, client: "Client", transaction_wrapper):
        """Initialize transaction fulltext index manager"""
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    def create(
        self, table_name: str, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
    ) -> "TransactionFulltextIndexManager":
        """Create a fulltext index within transaction"""
        try:
            if isinstance(columns, str):
                columns = [columns]

            columns_str = ", ".join(columns)
            sql = f"CREATE FULLTEXT INDEX {name} ON {table_name} ({columns_str})"

            self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create fulltext index {name} on table {table_name} in transaction: {e}")

    def drop(self, table_name: str, name: str) -> "TransactionFulltextIndexManager":
        """Drop a fulltext index within transaction"""
        try:
            sql = f"DROP INDEX {name} ON {table_name}"
            self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to drop fulltext index {name} from table {table_name} in transaction: {e}")
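And a matching transactional sketch. The tx.fulltext_index attribute name is an assumption here (inferred from the client-level client.fulltext_index convention above, not confirmed by this diff); note also that the transactional create builds a plain CREATE FULLTEXT INDEX statement, so the algorithm argument does not appear in the generated SQL:

    with client.transaction() as tx:
        tx.fulltext_index.create("articles", name="idx_content", columns=["title", "content"])
        # further DML in the same transaction commits or rolls back atomically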