matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,3913 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
MatrixOne Async Client - Asynchronous implementation
|
17
|
+
"""
|
18
|
+
|
19
|
+
try:
|
20
|
+
import aiomysql
|
21
|
+
except ImportError:
|
22
|
+
aiomysql = None
|
23
|
+
|
24
|
+
try:
|
25
|
+
from sqlalchemy import text
|
26
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
27
|
+
except ImportError:
|
28
|
+
create_async_engine = None
|
29
|
+
AsyncEngine = None
|
30
|
+
text = None
|
31
|
+
|
32
|
+
from contextlib import asynccontextmanager
|
33
|
+
from datetime import datetime
|
34
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
35
|
+
|
36
|
+
from .account import Account, User
|
37
|
+
from .async_vector_index_manager import AsyncVectorManager
|
38
|
+
from .base_client import BaseMatrixOneClient, BaseMatrixOneExecutor
|
39
|
+
from .connection_hooks import ConnectionHook, ConnectionAction, create_connection_hook
|
40
|
+
from .exceptions import (
|
41
|
+
AccountError,
|
42
|
+
ConnectionError,
|
43
|
+
MoCtlError,
|
44
|
+
PitrError,
|
45
|
+
PubSubError,
|
46
|
+
QueryError,
|
47
|
+
RestoreError,
|
48
|
+
SnapshotError,
|
49
|
+
)
|
50
|
+
from .logger import MatrixOneLogger, create_default_logger
|
51
|
+
from .async_metadata_manager import AsyncMetadataManager
|
52
|
+
from .pitr import Pitr
|
53
|
+
from .pubsub import Publication, Subscription
|
54
|
+
from .snapshot import Snapshot, SnapshotLevel
|
55
|
+
|
56
|
+
|
57
|
+
class AsyncResultSet:
    """Cursor-like wrapper around the rows returned by an async query.

    Stores the column names, the full list of row tuples, and the affected
    row count, and exposes DB-API style ``fetch*`` accessors that share a
    single internal read position.
    """

    def __init__(self, columns: List[str], rows: List[Tuple], affected_rows: int = 0):
        self.columns = columns
        self.rows = rows
        self.affected_rows = affected_rows
        # Read position shared by fetchone / fetchmany / fetchall.
        self._cursor = 0

    def fetchall(self) -> List[Tuple]:
        """Return every row not yet consumed and exhaust the cursor."""
        pending = self.rows[self._cursor:]
        self._cursor = len(self.rows)
        return pending

    def fetchone(self) -> Optional[Tuple]:
        """Return the next row, or ``None`` once all rows are consumed."""
        if self._cursor >= len(self.rows):
            return None
        current = self.rows[self._cursor]
        self._cursor += 1
        return current

    def fetchmany(self, size: int = 1) -> List[Tuple]:
        """Return up to ``size`` rows starting at the current position."""
        begin = self._cursor
        stop = min(begin + size, len(self.rows))
        self._cursor = stop
        return self.rows[begin:stop]

    def scalar(self) -> Any:
        """Return the first column of the first row, or ``None`` when the
        result set has no rows or no columns."""
        if not (self.rows and self.columns):
            return None
        return self.rows[0][0]

    def keys(self):
        """Return an iterator over the column names."""
        return iter(self.columns)

    def __iter__(self):
        """Iterate over *all* rows; the fetch cursor is not consulted."""
        return iter(self.rows)

    def __len__(self):
        """Total number of rows in the result set."""
        return len(self.rows)
|
103
|
+
|
104
|
+
|
105
|
+
class AsyncSnapshotManager:
    """Async manager for MatrixOne snapshots.

    Issues ``CREATE/DROP SNAPSHOT`` statements and reads snapshot metadata
    from ``mo_catalog.mo_snapshots`` through ``self.client.execute``.

    NOTE(review): snapshot/database/table names and the description are
    interpolated directly into the SQL text without escaping, so callers
    must not pass untrusted input to these methods.
    """

    def __init__(self, client):
        self.client = client

    async def create(
        self,
        name: str,
        level: Union[str, SnapshotLevel],
        database: Optional[str] = None,
        table: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Snapshot:
        """Create a snapshot asynchronously.

        Args:
            name: Snapshot name.
            level: Snapshot level (cluster / account / database / table),
                given either as a string or a ``SnapshotLevel`` enum value.
            database: Required for database- and table-level snapshots.
            table: Required for table-level snapshots.
            description: Optional COMMENT attached to the snapshot.

        Returns:
            A ``Snapshot`` object describing the snapshot just created
            (timestamped locally, not re-read from the server).

        Raises:
            SnapshotError: If a required name is missing or the level is invalid.
        """
        # Accept plain strings for convenience; normalise to the enum.
        if isinstance(level, str):
            level = SnapshotLevel(level.lower())

        if level == SnapshotLevel.CLUSTER:
            sql = f"CREATE SNAPSHOT {name} FOR CLUSTER"
        elif level == SnapshotLevel.ACCOUNT:
            sql = f"CREATE SNAPSHOT {name} FOR ACCOUNT"
        elif level == SnapshotLevel.DATABASE:
            if not database:
                raise SnapshotError("Database name is required for database level snapshot")
            sql = f"CREATE SNAPSHOT {name} FOR DATABASE {database}"
        elif level == SnapshotLevel.TABLE:
            if not database or not table:
                raise SnapshotError("Database and table names are required for table level snapshot")
            sql = f"CREATE SNAPSHOT {name} FOR TABLE {database} {table}"
        else:
            raise SnapshotError(f"Invalid snapshot level: {level}")

        if description:
            sql += f" COMMENT '{description}'"

        await self.client.execute(sql)

        # Use the module-level ``datetime`` import; the original re-imported
        # the datetime module locally, shadowing the top-level name.
        return Snapshot(name, level, datetime.now(), description, database, table)

    def _row_to_snapshot(self, row: Tuple) -> Snapshot:
        """Convert a ``mo_catalog.mo_snapshots`` row into a ``Snapshot``.

        Shared by :meth:`get` and :meth:`list` (the original duplicated this
        conversion in both methods).
        """
        # ``ts`` is stored in nanoseconds since the epoch.
        timestamp = datetime.fromtimestamp(row[1] / 1000000000)

        level_str = row[2]
        try:
            level = SnapshotLevel(level_str.lower())
        except ValueError:
            level = level_str  # Fallback to string for backward compatibility

        return Snapshot(row[0], level, timestamp, None, row[4], row[5])

    async def get(self, name: str) -> Snapshot:
        """Get a snapshot by name asynchronously.

        Raises:
            SnapshotError: If no snapshot with this name exists.
        """
        sql = """
        SELECT sname, ts, level, account_name, database_name, table_name
        FROM mo_catalog.mo_snapshots
        WHERE sname = :name
        """
        result = await self.client.execute(sql, {"name": name})

        if not result.rows:
            raise SnapshotError(f"Snapshot '{name}' not found")

        return self._row_to_snapshot(result.rows[0])

    async def list(self) -> List[Snapshot]:
        """List all snapshots asynchronously, newest first."""
        sql = """
        SELECT sname, ts, level, account_name, database_name, table_name
        FROM mo_catalog.mo_snapshots
        ORDER BY ts DESC
        """
        result = await self.client.execute(sql)
        return [self._row_to_snapshot(row) for row in result.rows]

    async def delete(self, name: str) -> None:
        """Delete a snapshot by name asynchronously."""
        await self.client.execute(f"DROP SNAPSHOT {name}")

    async def exists(self, name: str) -> bool:
        """Return True if a snapshot with this name exists."""
        try:
            await self.get(name)
            return True
        except SnapshotError:
            return False
|
216
|
+
|
217
|
+
|
218
|
+
class AsyncCloneManager:
    """Async helper that issues ``CREATE DATABASE/TABLE ... CLONE`` statements."""

    def __init__(self, client):
        self.client = client

    async def clone_database(
        self,
        target_db: str,
        source_db: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone ``source_db`` into ``target_db``, optionally from a snapshot."""
        guard = "IF NOT EXISTS" if if_not_exists else ""
        statement = f"CREATE DATABASE {target_db} {guard} CLONE {source_db}"
        if snapshot_name:
            statement += f" FOR SNAPSHOT '{snapshot_name}'"
        await self.client.execute(statement)

    async def clone_table(
        self,
        target_table: str,
        source_table: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone ``source_table`` into ``target_table``, optionally from a snapshot."""
        guard = "IF NOT EXISTS" if if_not_exists else ""
        statement = f"CREATE TABLE {target_table} {guard} CLONE {source_table}"
        if snapshot_name:
            statement += f" FOR SNAPSHOT '{snapshot_name}'"
        await self.client.execute(statement)

    async def clone_database_with_snapshot(
        self, target_db: str, source_db: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Convenience wrapper: clone a database from a named snapshot."""
        await self.clone_database(target_db, source_db, snapshot_name, if_not_exists)

    async def clone_table_with_snapshot(
        self, target_table: str, source_table: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Convenience wrapper: clone a table from a named snapshot."""
        await self.clone_table(target_table, source_table, snapshot_name, if_not_exists)
|
271
|
+
|
272
|
+
|
273
|
+
class AsyncPubSubManager:
    """Async manager for MatrixOne publish-subscribe operations.

    Wraps ``CREATE/ALTER/DROP PUBLICATION``, ``SHOW PUBLICATIONS``,
    ``SHOW SUBSCRIPTIONS`` and subscription creation (``CREATE DATABASE ...
    FROM ... PUBLICATION ...``).  Every method funnels SQL through
    ``self.client.execute`` and wraps any failure in ``PubSubError``.

    NOTE(review): the broad ``except Exception`` handlers also re-wrap
    ``PubSubError`` instances raised inside the ``try`` blocks, so error
    messages can end up doubled (e.g. "Failed to get ...: Publication ...
    not found").
    """

    def __init__(self, client):
        self.client = client

    async def create_database_publication(self, name: str, database: str, account: str) -> Publication:
        """Create a database-level publication asynchronously.

        Args:
            name: Publication name.
            database: Database to publish.
            account: Account allowed to subscribe.

        Returns:
            The publication, re-read via :meth:`get_publication`.

        Raises:
            PubSubError: If the statement fails or the publication cannot be read back.
        """
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create database publication '{name}'")

            # Re-read so the returned object reflects server-side state.
            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to create database publication '{name}': {e}")

    async def create_table_publication(self, name: str, database: str, table: str, account: str) -> Publication:
        """Create a table-level publication asynchronously.

        Same contract as :meth:`create_database_publication`, restricted to a
        single table within ``database``.
        """
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"TABLE {self.client._escape_identifier(table)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create table publication '{name}'")

            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to create table publication '{name}': {e}")

    async def get_publication(self, name: str) -> Publication:
        """Get a publication by name asynchronously.

        Raises:
            PubSubError: If the publication does not exist or the query fails.
        """
        try:
            # SHOW PUBLICATIONS doesn't support WHERE clause, so we need to list all and filter
            sql = "SHOW PUBLICATIONS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            # Find publication with matching name
            for row in result.rows:
                if row[0] == name:  # publication name is in first column
                    return self._row_to_publication(row)

            raise PubSubError(f"Publication '{name}' not found")

        except Exception as e:
            raise PubSubError(f"Failed to get publication '{name}': {e}")

    async def list_publications(self, account: Optional[str] = None, database: Optional[str] = None) -> List[Publication]:
        """List publications, optionally filtered by subscriber account and/or database.

        Filtering happens client-side over the full SHOW PUBLICATIONS output.
        """
        try:
            # SHOW PUBLICATIONS doesn't support WHERE clause, so we need to list all and filter
            sql = "SHOW PUBLICATIONS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                return []

            publications = []
            for row in result.rows:
                pub = self._row_to_publication(row)

                # Apply filters.  NOTE(review): the account filter is a
                # membership test on pub.sub_account — presumably that field
                # can name several accounts; confirm substring matching is
                # the intended semantics.
                if account and account not in pub.sub_account:
                    continue
                if database and pub.database != database:
                    continue

                publications.append(pub)

            return publications

        except Exception as e:
            raise PubSubError(f"Failed to list publications: {e}")

    async def alter_publication(
        self,
        name: str,
        account: Optional[str] = None,
        database: Optional[str] = None,
        table: Optional[str] = None,
    ) -> Publication:
        """Alter a publication's account/database/table asynchronously.

        Only the clauses for which a value is supplied are emitted.
        """
        try:
            # Build ALTER PUBLICATION statement clause by clause.
            parts = [f"ALTER PUBLICATION {self.client._escape_identifier(name)}"]

            if account:
                parts.append(f"ACCOUNT {self.client._escape_identifier(account)}")
            if database:
                parts.append(f"DATABASE {self.client._escape_identifier(database)}")
            if table:
                parts.append(f"TABLE {self.client._escape_identifier(table)}")

            sql = " ".join(parts)
            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to alter publication '{name}'")

            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to alter publication '{name}': {e}")

    async def drop_publication(self, name: str) -> bool:
        """Drop a publication asynchronously; True when the statement executed."""
        try:
            sql = f"DROP PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.client.execute(sql)
            return result is not None

        except Exception as e:
            raise PubSubError(f"Failed to drop publication '{name}': {e}")

    async def show_create_publication(self, name: str) -> str:
        """Return the CREATE PUBLICATION statement for an existing publication."""
        try:
            sql = f"SHOW CREATE PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            # The result should contain the CREATE statement
            # Assuming the CREATE statement is in the first column
            return result.rows[0][0]

        except Exception as e:
            raise PubSubError(f"Failed to show create publication '{name}': {e}")

    async def create_subscription(
        self, subscription_name: str, publication_name: str, publisher_account: str
    ) -> Subscription:
        """Create a subscription from a publication asynchronously.

        In MatrixOne, subscribing means creating a database FROM the
        publisher account's publication; ``subscription_name`` becomes the
        local database name.
        """
        try:
            sql = (
                f"CREATE DATABASE {self.client._escape_identifier(subscription_name)} "
                f"FROM {self.client._escape_identifier(publisher_account)} "
                f"PUBLICATION {self.client._escape_identifier(publication_name)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create subscription '{subscription_name}'")

            return await self.get_subscription(subscription_name)

        except Exception as e:
            raise PubSubError(f"Failed to create subscription '{subscription_name}': {e}")

    async def get_subscription(self, name: str) -> Subscription:
        """Get a subscription by its local name asynchronously."""
        try:
            # SHOW SUBSCRIPTIONS doesn't support WHERE clause, so we need to list all and filter
            sql = "SHOW SUBSCRIPTIONS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Subscription '{name}' not found")

            # Find subscription with matching name
            for row in result.rows:
                if row[6] == name:  # sub_name is in 7th column (index 6)
                    return self._row_to_subscription(row)

            raise PubSubError(f"Subscription '{name}' not found")

        except Exception as e:
            raise PubSubError(f"Failed to get subscription '{name}': {e}")

    async def list_subscriptions(
        self, pub_account: Optional[str] = None, pub_database: Optional[str] = None
    ) -> List[Subscription]:
        """List subscriptions, optionally filtered by publisher account/database.

        NOTE(review): unlike :meth:`get_subscription` (whose comment states
        SHOW SUBSCRIPTIONS does not support WHERE), this method appends a
        WHERE clause server-side — verify which behavior the server actually
        supports.
        """
        try:
            conditions = []

            if pub_account:
                conditions.append(f"pub_account = {self.client._escape_string(pub_account)}")
            if pub_database:
                conditions.append(f"pub_database = {self.client._escape_string(pub_database)}")

            if conditions:
                where_clause = " WHERE " + " AND ".join(conditions)
            else:
                where_clause = ""

            sql = f"SHOW SUBSCRIPTIONS{where_clause}"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                return []

            return [self._row_to_subscription(row) for row in result.rows]

        except Exception as e:
            raise PubSubError(f"Failed to list subscriptions: {e}")

    def _row_to_publication(self, row: tuple) -> Publication:
        """Convert a SHOW PUBLICATIONS row to a Publication object.

        Expected columns: publication, database, tables, sub_account,
        subscribed_accounts, create_time, update_time, comments — per the
        MatrixOne documentation:
        https://docs.matrixorigin.cn/en/v25.2.2.2/MatrixOne/Reference/SQL-Reference/Other/SHOW-Statements/show-publications/
        Trailing columns are treated as optional to tolerate older servers.
        """
        return Publication(
            name=row[0],  # publication
            database=row[1],  # database
            tables=row[2],  # tables
            sub_account=row[3],  # sub_account
            subscribed_accounts=row[4],  # subscribed_accounts
            created_time=row[5] if len(row) > 5 else None,  # create_time
            update_time=row[6] if len(row) > 6 else None,  # update_time
            comments=row[7] if len(row) > 7 else None,  # comments
        )

    def _row_to_subscription(self, row: tuple) -> Subscription:
        """Convert a SHOW SUBSCRIPTIONS row to a Subscription object.

        Expected columns: pub_name, pub_account, pub_database, pub_tables,
        pub_comment, pub_time, sub_name, sub_time, status.  Trailing columns
        are optional; status defaults to 0 when absent.
        """
        return Subscription(
            pub_name=row[0],
            pub_account=row[1],
            pub_database=row[2],
            pub_tables=row[3],
            pub_comment=row[4] if len(row) > 4 else None,
            pub_time=row[5] if len(row) > 5 else None,
            sub_name=row[6] if len(row) > 6 else None,
            sub_time=row[7] if len(row) > 7 else None,
            status=row[8] if len(row) > 8 else 0,
        )
|
518
|
+
|
519
|
+
|
520
|
+
class AsyncPitrManager:
|
521
|
+
"""Async manager for PITR operations"""
|
522
|
+
|
523
|
+
def __init__(self, client):
|
524
|
+
self.client = client
|
525
|
+
|
526
|
+
async def create_cluster_pitr(self, name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
    """Create a cluster-level PITR (point-in-time-recovery window) asynchronously.

    Args:
        name: PITR name; escaped via the client's identifier escaper.
        range_value: Window length, 1..100 (checked by _validate_range).
        range_unit: Window unit: 'h', 'd', 'mo' or 'y'.

    Returns:
        The created Pitr, re-read from the server via ``get(name)``.

    Raises:
        PitrError: On invalid range or any underlying failure.  Note the
            broad handler also re-wraps PitrError raised inside the try,
            doubling the message prefix.
    """
    try:
        self._validate_range(range_value, range_unit)

        sql = f"CREATE PITR {self.client._escape_identifier(name)} " f"FOR CLUSTER RANGE {range_value} '{range_unit}'"

        result = await self.client.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create cluster PITR '{name}'")

        return await self.get(name)

    except Exception as e:
        raise PitrError(f"Failed to create cluster PITR '{name}': {e}")
|
541
|
+
|
542
|
+
async def create_account_pitr(
    self,
    name: str,
    account_name: Optional[str] = None,
    range_value: int = 1,
    range_unit: str = "d",
) -> Pitr:
    """Create an account-level PITR asynchronously.

    When ``account_name`` is omitted, the PITR targets the current account;
    otherwise it targets the named account explicitly.

    Raises:
        PitrError: On invalid range or any underlying failure.
    """
    try:
        self._validate_range(range_value, range_unit)

        if account_name:
            sql = (
                f"CREATE PITR {self.client._escape_identifier(name)} "
                f"FOR ACCOUNT {self.client._escape_identifier(account_name)} "
                f"RANGE {range_value} '{range_unit}'"
            )
        else:
            # No account clause: the statement applies to the current account.
            sql = (
                f"CREATE PITR {self.client._escape_identifier(name)} " f"FOR ACCOUNT RANGE {range_value} '{range_unit}'"
            )

        result = await self.client.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create account PITR '{name}'")

        return await self.get(name)

    except Exception as e:
        raise PitrError(f"Failed to create account PITR '{name}': {e}")
|
572
|
+
|
573
|
+
async def create_database_pitr(self, name: str, database_name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
    """Create a database-level PITR asynchronously.

    Args:
        name: PITR name.
        database_name: Database the recovery window applies to.
        range_value: Window length, 1..100.
        range_unit: 'h', 'd', 'mo' or 'y'.

    Raises:
        PitrError: On invalid range or any underlying failure.
    """
    try:
        self._validate_range(range_value, range_unit)

        sql = (
            f"CREATE PITR {self.client._escape_identifier(name)} "
            f"FOR DATABASE {self.client._escape_identifier(database_name)} "
            f"RANGE {range_value} '{range_unit}'"
        )

        result = await self.client.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create database PITR '{name}'")

        return await self.get(name)

    except Exception as e:
        raise PitrError(f"Failed to create database PITR '{name}': {e}")
|
592
|
+
|
593
|
+
async def create_table_pitr(
    self,
    name: str,
    database_name: str,
    table_name: str,
    range_value: int = 1,
    range_unit: str = "d",
) -> Pitr:
    """Create a table-level PITR asynchronously.

    Args:
        name: PITR name.
        database_name: Database containing the table.
        table_name: Table the recovery window applies to.
        range_value: Window length, 1..100.
        range_unit: 'h', 'd', 'mo' or 'y'.

    Raises:
        PitrError: On invalid range or any underlying failure.
    """
    try:
        self._validate_range(range_value, range_unit)

        # FOR TABLE takes the database and table as two space-separated
        # identifiers (no dot), matching the MatrixOne PITR grammar.
        sql = (
            f"CREATE PITR {self.client._escape_identifier(name)} "
            f"FOR TABLE {self.client._escape_identifier(database_name)} "
            f"{self.client._escape_identifier(table_name)} "
            f"RANGE {range_value} '{range_unit}'"
        )

        result = await self.client.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create table PITR '{name}'")

        return await self.get(name)

    except Exception as e:
        raise PitrError(f"Failed to create table PITR '{name}': {e}")
|
620
|
+
|
621
|
+
async def get(self, name: str) -> Pitr:
    """Look up a single PITR by exact name asynchronously.

    Raises:
        PitrError: If the PITR does not exist or the query fails (the
            "not found" error is re-wrapped by the outer handler).
    """
    try:
        sql = f"SHOW PITR WHERE pitr_name = {self.client._escape_string(name)}"
        result = await self.client.execute(sql)

        if not result or not result.rows:
            raise PitrError(f"PITR '{name}' not found")

        # Name is unique, so only the first row matters.
        row = result.rows[0]
        return self._row_to_pitr(row)

    except Exception as e:
        raise PitrError(f"Failed to get PITR '{name}': {e}")
|
635
|
+
|
636
|
+
async def list(
    self,
    level: Optional[str] = None,
    account_name: Optional[str] = None,
    database_name: Optional[str] = None,
    table_name: Optional[str] = None,
) -> List[Pitr]:
    """List PITRs asynchronously, optionally filtered.

    Each supplied filter adds an equality condition to the SHOW PITR
    WHERE clause; values are escaped via the client's string escaper.

    Returns:
        A list of Pitr objects; empty when nothing matches.

    Raises:
        PitrError: If the query fails.
    """
    try:
        conditions = []

        if level:
            conditions.append(f"pitr_level = {self.client._escape_string(level)}")
        if account_name:
            conditions.append(f"account_name = {self.client._escape_string(account_name)}")
        if database_name:
            conditions.append(f"database_name = {self.client._escape_string(database_name)}")
        if table_name:
            conditions.append(f"table_name = {self.client._escape_string(table_name)}")

        if conditions:
            where_clause = " WHERE " + " AND ".join(conditions)
        else:
            where_clause = ""

        sql = f"SHOW PITR{where_clause}"
        result = await self.client.execute(sql)

        if not result or not result.rows:
            return []

        return [self._row_to_pitr(row) for row in result.rows]

    except Exception as e:
        raise PitrError(f"Failed to list PITRs: {e}")
|
671
|
+
|
672
|
+
async def alter(self, name: str, range_value: int, range_unit: str) -> Pitr:
    """Alter a PITR's retention range asynchronously.

    Args::

        name: Name of the PITR to alter
        range_value: New range value (1-100)
        range_unit: New range unit ('h', 'd', 'mo' or 'y')

    Returns::

        Pitr: The PITR re-read from the server after the change

    Raises::

        PitrError: On invalid range parameters or execution failure
    """
    try:
        self._validate_range(range_value, range_unit)

        sql = f"ALTER PITR {self.client._escape_identifier(name)} " f"RANGE {range_value} '{range_unit}'"

        result = await self.client.execute(sql)
        if result is None:
            raise PitrError(f"Failed to alter PITR '{name}'")

        return await self.get(name)

    except PitrError:
        # Validation and "result is None" errors are already PitrErrors;
        # re-wrapping them would duplicate the message prefix.
        raise
    except Exception as e:
        raise PitrError(f"Failed to alter PITR '{name}': {e}")
|
687
|
+
|
688
|
+
async def delete(self, name: str) -> bool:
    """Drop a PITR by name asynchronously.

    Returns::

        bool: True when the server acknowledged the DROP statement
    """
    drop_stmt = f"DROP PITR {self.client._escape_identifier(name)}"
    try:
        outcome = await self.client.execute(drop_stmt)
    except Exception as exc:
        raise PitrError(f"Failed to delete PITR '{name}': {exc}")
    return outcome is not None
|
697
|
+
|
698
|
+
def _validate_range(self, range_value: int, range_unit: str) -> None:
    """Validate PITR range parameters; raise PitrError on invalid input."""
    # Server-side PITR retention is limited to 1..100 units.
    if range_value < 1 or range_value > 100:
        raise PitrError("Range value must be between 1 and 100")

    valid_units = ("h", "d", "mo", "y")
    if range_unit not in valid_units:
        raise PitrError(f"Range unit must be one of: {', '.join(valid_units)}")
|
706
|
+
|
707
|
+
def _row_to_pitr(self, row: tuple) -> Pitr:
    """Convert a SHOW PITR result row to a Pitr object.

    Expected column order: pitr_name, created_time, modified_time,
    pitr_level, account_name, database_name, table_name, pitr_length,
    pitr_unit.
    """
    return Pitr(
        name=row[0],
        created_time=row[1],
        modified_time=row[2],
        level=row[3],
        # The server reports "*" as a wildcard/unset marker for scope
        # columns; normalize it to None in the Python object.
        account_name=row[4] if row[4] != "*" else None,
        database_name=row[5] if row[5] != "*" else None,
        table_name=row[6] if row[6] != "*" else None,
        range_value=row[7],
        range_unit=row[8],
    )
|
722
|
+
|
723
|
+
|
724
|
+
class AsyncRestoreManager:
    """Async manager for restore operations"""

    def __init__(self, client):
        self.client = client

    def _build_restore_sql(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
        to_account: Optional[str] = None,
    ) -> str:
        """Assemble a RESTORE ACCOUNT ... FROM SNAPSHOT statement.

        Optional DATABASE / TABLE / TO ACCOUNT clauses are appended only
        when the corresponding argument is provided.
        """
        esc = self.client._escape_identifier
        fragments = [f"RESTORE ACCOUNT {esc(account_name)}"]
        if database_name:
            fragments.append(f"DATABASE {esc(database_name)}")
        if table_name:
            fragments.append(f"TABLE {esc(table_name)}")
        fragments.append(f"FROM SNAPSHOT {esc(snapshot_name)}")
        if to_account:
            fragments.append(f"TO ACCOUNT {esc(to_account)}")
        return " ".join(fragments)

    async def restore_cluster(self, snapshot_name: str) -> bool:
        """Restore entire cluster from snapshot asynchronously"""
        try:
            stmt = f"RESTORE CLUSTER FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
            outcome = await self.client.execute(stmt)
            return outcome is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore cluster from snapshot '{snapshot_name}': {e}")

    async def restore_tenant(self, snapshot_name: str, account_name: str, to_account: Optional[str] = None) -> bool:
        """Restore tenant from snapshot asynchronously"""
        try:
            stmt = self._build_restore_sql(snapshot_name, account_name, to_account=to_account)
            outcome = await self.client.execute(stmt)
            return outcome is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore tenant '{account_name}' from snapshot '{snapshot_name}': {e}")

    async def restore_database(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore database from snapshot asynchronously"""
        try:
            stmt = self._build_restore_sql(
                snapshot_name,
                account_name,
                database_name=database_name,
                to_account=to_account,
            )
            outcome = await self.client.execute(stmt)
            return outcome is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore database '{database_name}' from snapshot '{snapshot_name}': {e}")

    async def restore_table(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        table_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore table from snapshot asynchronously"""
        try:
            stmt = self._build_restore_sql(
                snapshot_name,
                account_name,
                database_name=database_name,
                table_name=table_name,
                to_account=to_account,
            )
            outcome = await self.client.execute(stmt)
            return outcome is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore table '{table_name}' from snapshot '{snapshot_name}': {e}")
|
817
|
+
|
818
|
+
|
819
|
+
class AsyncMoCtlManager:
    """Async mo_ctl manager.

    Thin wrapper around the server-side ``mo_ctl()`` SQL function for
    maintenance operations (flush, checkpoints).
    """

    def __init__(self, client):
        self.client = client

    async def _execute_moctl(self, method: str, target: str, params: str = "") -> Dict[str, Any]:
        """Execute a mo_ctl command asynchronously and return the parsed JSON.

        Args::

            method: mo_ctl method (e.g. 'dn')
            target: mo_ctl target operation (e.g. 'flush', 'checkpoint')
            params: Optional operation parameters (empty string for none)

        Returns::

            Dict[str, Any]: Parsed JSON result from mo_ctl

        Raises::

            MoCtlError: If the command returns no rows, reports a failure,
                or its result cannot be parsed
        """
        import json

        # NOTE(review): method/target/params are interpolated unescaped; they
        # originate from internal callers only — do not pass untrusted input.
        # A single format covers both cases: an empty `params` yields ''.
        sql = f"SELECT mo_ctl('{method}', '{target}', '{params}')"

        try:
            result = await self.client.execute(sql)

            if not result.rows:
                raise MoCtlError(f"mo_ctl command returned no results: {sql}")

            # mo_ctl returns a single JSON string in the first column.
            result_str = result.rows[0][0]
            parsed_result = json.loads(result_str)

            # The payload reports per-operation status in result[*].returnStr.
            if "result" in parsed_result and parsed_result["result"]:
                first_result = parsed_result["result"][0]
                if "returnStr" in first_result and first_result["returnStr"] != "OK":
                    raise MoCtlError(f"mo_ctl operation failed: {first_result['returnStr']}")

            return parsed_result

        except MoCtlError:
            # Propagate domain errors unchanged instead of re-wrapping them
            # (previously they were double-wrapped by the generic handler).
            raise
        except json.JSONDecodeError as e:
            raise MoCtlError(f"Failed to parse mo_ctl result: {e}")
        except Exception as e:
            raise MoCtlError(f"mo_ctl operation failed: {e}")

    async def flush_table(self, database: str, table: str) -> Dict[str, Any]:
        """Force flush table asynchronously"""
        table_ref = f"{database}.{table}"
        return await self._execute_moctl("dn", "flush", table_ref)

    async def increment_checkpoint(self) -> Dict[str, Any]:
        """Force incremental checkpoint asynchronously"""
        return await self._execute_moctl("dn", "checkpoint", "")

    async def global_checkpoint(self) -> Dict[str, Any]:
        """Force global checkpoint asynchronously"""
        return await self._execute_moctl("dn", "globalcheckpoint", "")
|
871
|
+
|
872
|
+
|
873
|
+
class AsyncAccountManager:
    """Async manager for MatrixOne account operations (accounts and users)."""

    def __init__(self, client):
        self.client = client

    async def create_account(
        self,
        account_name: str,
        admin_name: str,
        password: str,
        comment: Optional[str] = None,
        admin_comment: Optional[str] = None,
        admin_host: str = "%",
        admin_identified_by: Optional[str] = None,
    ) -> Account:
        """Create a new account asynchronously.

        Args::

            account_name: Name of the account to create
            admin_name: Administrator user name
            password: Administrator password
            comment: Optional account comment
            admin_comment: Optional administrator comment
            admin_host: Administrator host pattern (default '%')
            admin_identified_by: Optional extra admin auth string

        Returns::

            Account: The created account as read back from the server
        """
        try:
            sql_parts = [f"CREATE ACCOUNT {self.client._escape_identifier(account_name)}"]
            sql_parts.append(f"ADMIN_NAME {self.client._escape_string(admin_name)}")
            sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

            # '%' is the server default, so only emit ADMIN_HOST when it differs.
            if admin_host != "%":
                sql_parts.append(f"ADMIN_HOST {self.client._escape_string(admin_host)}")

            if comment:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

            if admin_comment:
                sql_parts.append(f"ADMIN_COMMENT {self.client._escape_string(admin_comment)}")

            if admin_identified_by:
                sql_parts.append(f"ADMIN_IDENTIFIED BY {self.client._escape_string(admin_identified_by)}")

            sql = " ".join(sql_parts)
            await self.client.execute(sql)

            return await self.get_account(account_name)

        except AccountError:
            # Keep domain errors (e.g. from get_account) unwrapped to avoid
            # duplicated "Failed to ..." prefixes.
            raise
        except Exception as e:
            raise AccountError(f"Failed to create account '{account_name}': {e}")

    async def drop_account(self, account_name: str) -> None:
        """Drop an account asynchronously.

        Raises::

            AccountError: If the DROP statement fails
        """
        try:
            sql = f"DROP ACCOUNT {self.client._escape_identifier(account_name)}"
            await self.client.execute(sql)
        except Exception as e:
            raise AccountError(f"Failed to drop account '{account_name}': {e}")

    async def alter_account(
        self,
        account_name: str,
        comment: Optional[str] = None,
        suspend: Optional[bool] = None,
        suspend_reason: Optional[str] = None,
    ) -> Account:
        """Alter an account asynchronously.

        Args::

            account_name: Name of the account to alter
            comment: New comment (None to leave unchanged)
            suspend: True to suspend, False to re-open, None to leave unchanged
            suspend_reason: Optional reason recorded with SUSPEND

        Returns::

            Account: The account re-read from the server after the change
        """
        try:
            sql_parts = [f"ALTER ACCOUNT {self.client._escape_identifier(account_name)}"]

            if comment is not None:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

            if suspend is not None:
                if suspend:
                    if suspend_reason:
                        sql_parts.append(f"SUSPEND COMMENT {self.client._escape_string(suspend_reason)}")
                    else:
                        sql_parts.append("SUSPEND")
                else:
                    sql_parts.append("OPEN")

            sql = " ".join(sql_parts)
            await self.client.execute(sql)

            return await self.get_account(account_name)

        except AccountError:
            raise
        except Exception as e:
            raise AccountError(f"Failed to alter account '{account_name}': {e}")

    async def get_account(self, account_name: str) -> Account:
        """Get an account by name asynchronously.

        SHOW ACCOUNTS has no WHERE support here, so the rows are scanned
        client-side for the matching name.

        Raises::

            AccountError: If the account does not exist or the query fails
        """
        try:
            sql = "SHOW ACCOUNTS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise AccountError(f"Account '{account_name}' not found")

            for row in result.rows:
                if row[0] == account_name:
                    return self._row_to_account(row)

            raise AccountError(f"Account '{account_name}' not found")

        except AccountError:
            # Re-raise "not found" as-is instead of double-wrapping it as
            # "Failed to get account 'x': Account 'x' not found".
            raise
        except Exception as e:
            raise AccountError(f"Failed to get account '{account_name}': {e}")

    async def list_accounts(self) -> List[Account]:
        """List all accounts asynchronously."""
        try:
            sql = "SHOW ACCOUNTS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                return []

            return [self._row_to_account(row) for row in result.rows]

        except Exception as e:
            raise AccountError(f"Failed to list accounts: {e}")

    async def create_user(self, user_name: str, password: str, comment: Optional[str] = None) -> User:
        """
        Create a new user asynchronously according to MatrixOne CREATE USER syntax:
        CREATE USER [IF NOT EXISTS] user auth_option [, user auth_option] ...
        [DEFAULT ROLE rolename] [COMMENT 'comment_string' | ATTRIBUTE 'json_object']

        Args::

            user_name: Name of the user to create
            password: Password for the user
            comment: Comment for the user (not supported in MatrixOne)

        Returns::

            User: Created user object
        """
        try:
            # MatrixOne syntax: CREATE USER user_name IDENTIFIED BY 'password'
            sql_parts = [f"CREATE USER {self.client._escape_identifier(user_name)}"]
            sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

            # Note: MatrixOne doesn't support COMMENT or ATTRIBUTE clauses in
            # CREATE USER, so `comment` is only recorded on the returned object.
            sql = " ".join(sql_parts)
            await self.client.execute(sql)

            # The server is not queried back here; return a synthesized User
            # with default host/account values.
            return User(
                name=user_name,
                host="%",  # Default host
                account="sys",  # Default account
                created_time=datetime.now(),
                status="ACTIVE",
                comment=comment,
            )

        except Exception as e:
            raise AccountError(f"Failed to create user '{user_name}': {e}")

    async def drop_user(self, user_name: str, if_exists: bool = False) -> None:
        """
        Drop a user asynchronously according to MatrixOne DROP USER syntax:
        DROP USER [IF EXISTS] user [, user] ...

        Args::

            user_name: Name of the user to drop
            if_exists: If True, add IF EXISTS clause to avoid errors when user doesn't exist
        """
        try:
            sql_parts = ["DROP USER"]
            if if_exists:
                sql_parts.append("IF EXISTS")

            sql_parts.append(self.client._escape_identifier(user_name))
            sql = " ".join(sql_parts)
            await self.client.execute(sql)

        except Exception as e:
            raise AccountError(f"Failed to drop user '{user_name}': {e}")

    async def alter_user(
        self,
        user_name: str,
        password: Optional[str] = None,
        comment: Optional[str] = None,
        lock: Optional[bool] = None,
        lock_reason: Optional[str] = None,
    ) -> User:
        """Alter a user asynchronously.

        Args::

            user_name: Name of the user to alter
            password: New password (None to leave unchanged)
            comment: New comment (None to leave unchanged)
            lock: True to lock, False to unlock, None to leave unchanged
            lock_reason: Optional reason recorded with ACCOUNT LOCK

        Returns::

            User: The user re-read from the server after the change
        """
        try:
            sql_parts = [f"ALTER USER {self.client._escape_identifier(user_name)}"]

            if password is not None:
                sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

            if comment is not None:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

            if lock is not None:
                if lock:
                    if lock_reason:
                        sql_parts.append(f"ACCOUNT LOCK COMMENT {self.client._escape_string(lock_reason)}")
                    else:
                        sql_parts.append("ACCOUNT LOCK")
                else:
                    sql_parts.append("ACCOUNT UNLOCK")

            sql = " ".join(sql_parts)
            await self.client.execute(sql)

            return await self.get_user(user_name)

        except AccountError:
            raise
        except Exception as e:
            raise AccountError(f"Failed to alter user '{user_name}': {e}")

    async def get_user(self, user_name: str) -> User:
        """Get a user by name asynchronously.

        NOTE(review): this scans SHOW GRANTS output and assumes the first
        column is the user name — confirm against the server's result shape.

        Raises::

            AccountError: If the user does not exist or the query fails
        """
        try:
            sql = "SHOW GRANTS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise AccountError(f"User '{user_name}' not found")

            for row in result.rows:
                if row[0] == user_name:
                    return self._row_to_user(row)

            raise AccountError(f"User '{user_name}' not found")

        except AccountError:
            # Avoid double-wrapping the "not found" error.
            raise
        except Exception as e:
            raise AccountError(f"Failed to get user '{user_name}': {e}")

    async def list_users(self, account_name: Optional[str] = None) -> List[User]:
        """List users asynchronously, optionally filtered by account."""
        try:
            sql = "SHOW GRANTS"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                return []

            users = [self._row_to_user(row) for row in result.rows]

            if account_name:
                users = [user for user in users if user.account == account_name]

            return users

        except Exception as e:
            raise AccountError(f"Failed to list users: {e}")

    def _row_to_account(self, row: tuple) -> Account:
        """Convert a SHOW ACCOUNTS row to an Account object.

        Trailing columns are optional, so each is guarded by a length check.
        """
        return Account(
            name=row[0],
            admin_name=row[1],
            created_time=row[2] if len(row) > 2 else None,
            status=row[3] if len(row) > 3 else None,
            comment=row[4] if len(row) > 4 else None,
            suspended_time=row[5] if len(row) > 5 else None,
            suspended_reason=row[6] if len(row) > 6 else None,
        )

    def _row_to_user(self, row: tuple) -> User:
        """Convert a user row to a User object.

        Trailing columns are optional, so each is guarded by a length check.
        """
        return User(
            name=row[0],
            host=row[1],
            account=row[2],
            created_time=row[3] if len(row) > 3 else None,
            status=row[4] if len(row) > 4 else None,
            comment=row[5] if len(row) > 5 else None,
            locked_time=row[6] if len(row) > 6 else None,
            locked_reason=row[7] if len(row) > 7 else None,
        )
|
1149
|
+
|
1150
|
+
|
1151
|
+
class AsyncFulltextIndexManager:
    """Async fulltext index manager for client chain operations"""

    def __init__(self, client: "AsyncClient"):
        """Initialize async fulltext index manager"""
        self.client = client

    async def create(
        self, table_name_or_model, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
    ) -> "AsyncFulltextIndexManager":
        """
        Create a fulltext index using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Index name
            columns: Column(s) to index
            algorithm: Fulltext algorithm type (TF-IDF or BM25)

        Returns::

            AsyncFulltextIndexManager: Self for chaining

        Example

            # Create fulltext index by table name
            await client.fulltext_index.create("articles", name="idx_content", columns=["title", "content"])

            # Create fulltext index by model class
            await client.fulltext_index.create(ArticleModel, name="idx_content", columns=["title", "content"])
        """
        from sqlalchemy import text

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        try:
            # Normalize a single column name to a one-element list.
            if isinstance(columns, str):
                columns = [columns]

            columns_str = ", ".join(columns)
            # NOTE(review): `algorithm` is accepted but not used in the
            # generated SQL — confirm whether it should influence the
            # statement or a session variable.
            # NOTE(review): index/table names are interpolated unescaped here,
            # unlike other managers which use _escape_identifier — verify.
            sql = f"CREATE FULLTEXT INDEX {name} ON {table_name} ({columns_str})"

            # Execute via the SQLAlchemy async engine in its own transaction.
            async with self.client.get_sqlalchemy_engine().begin() as conn:
                await conn.execute(text(sql))

            return self
        except Exception as e:
            raise Exception(f"Failed to create fulltext index {name} on table {table_name}: {e}")

    async def drop(self, table_name_or_model, name: str) -> "AsyncFulltextIndexManager":
        """
        Drop a fulltext index using chain operations.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            name: Index name

        Returns::

            AsyncFulltextIndexManager: Self for chaining

        Example

            # Drop fulltext index by table name
            await client.fulltext_index.drop("articles", "idx_content")

            # Drop fulltext index by model class
            await client.fulltext_index.drop(ArticleModel, "idx_content")
        """
        from sqlalchemy import text

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        try:
            # Fulltext indexes are dropped with a plain DROP INDEX statement.
            sql = f"DROP INDEX {name} ON {table_name}"

            async with self.client.get_sqlalchemy_engine().begin() as conn:
                await conn.execute(text(sql))

            return self
        except Exception as e:
            raise Exception(f"Failed to drop fulltext index {name} from table {table_name}: {e}")

    async def enable_fulltext(self) -> "AsyncFulltextIndexManager":
        """
        Enable fulltext indexing with chain operations.

        Sets the server's experimental_fulltext_index session variable to 1.

        Returns::

            AsyncFulltextIndexManager: Self for chaining
        """
        try:
            await self.client.execute("SET experimental_fulltext_index = 1")
            return self
        except Exception as e:
            raise Exception(f"Failed to enable fulltext indexing: {e}")

    async def disable_fulltext(self) -> "AsyncFulltextIndexManager":
        """
        Disable fulltext indexing with chain operations.

        Sets the server's experimental_fulltext_index session variable to 0.

        Returns::

            AsyncFulltextIndexManager: Self for chaining
        """
        try:
            await self.client.execute("SET experimental_fulltext_index = 0")
            return self
        except Exception as e:
            raise Exception(f"Failed to disable fulltext indexing: {e}")
|
1271
|
+
|
1272
|
+
|
1273
|
+
class AsyncTransactionFulltextIndexManager(AsyncFulltextIndexManager):
    """Async fulltext index manager that executes operations within a transaction"""

    def __init__(self, client: "AsyncClient", transaction_wrapper):
        """Initialize async transaction fulltext index manager.

        Args::

            client: Owning AsyncClient instance
            transaction_wrapper: Active transaction object used to execute SQL
        """
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    async def create(
        self, table_name: str, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
    ) -> "AsyncTransactionFulltextIndexManager":
        """Create a fulltext index within transaction.

        Accepts either a table name or a SQLAlchemy model class, matching the
        non-transactional parent manager's behavior.
        """
        # Parity fix: the parent accepts model classes (via __tablename__);
        # this override previously only accepted plain strings.
        if hasattr(table_name, '__tablename__'):
            table_name = table_name.__tablename__

        try:
            if isinstance(columns, str):
                columns = [columns]

            columns_str = ", ".join(columns)
            # NOTE(review): `algorithm` is accepted but unused in the SQL,
            # mirroring the parent implementation — confirm intent.
            sql = f"CREATE FULLTEXT INDEX {name} ON {table_name} ({columns_str})"

            await self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to create fulltext index {name} on table {table_name} in transaction: {e}")

    async def drop(self, table_name: str, name: str) -> "AsyncTransactionFulltextIndexManager":
        """Drop a fulltext index within transaction.

        Accepts either a table name or a SQLAlchemy model class.
        """
        if hasattr(table_name, '__tablename__'):
            table_name = table_name.__tablename__

        try:
            sql = f"DROP INDEX {name} ON {table_name}"
            await self.transaction_wrapper.execute(sql)
            return self
        except Exception as e:
            raise Exception(f"Failed to drop fulltext index {name} from table {table_name} in transaction: {e}")
|
1305
|
+
|
1306
|
+
|
1307
|
+
class AsyncClientExecutor(BaseMatrixOneExecutor):
    """Async client executor that uses AsyncClient's execute method"""

    def __init__(self, client):
        super().__init__(client)
        # NOTE(review): the base class presumably stores `client` as
        # `self.base_client` (used by insert/batch_insert below) — confirm.
        self.client = client

    async def _execute(self, sql: str):
        # Delegate raw SQL execution to the async client.
        return await self.client.execute(sql)

    def _get_empty_result(self):
        # Empty result placeholder returned when there is nothing to execute.
        return AsyncResultSet([], [], affected_rows=0)

    async def insert(self, table_name: str, data: dict):
        """Async insert method"""
        sql = self.base_client._build_insert_sql(table_name, data)
        return await self._execute(sql)

    async def batch_insert(self, table_name: str, data_list: list):
        """Async batch insert method"""
        # An empty batch is a no-op: return an empty result instead of
        # issuing an invalid INSERT statement.
        if not data_list:
            return self._get_empty_result()

        sql = self.base_client._build_batch_insert_sql(table_name, data_list)
        return await self._execute(sql)
|
1332
|
+
|
1333
|
+
|
1334
|
+
class AsyncClient(BaseMatrixOneClient):
|
1335
|
+
"""
|
1336
|
+
MatrixOne Async Client - Asynchronous interface for MatrixOne database operations.
|
1337
|
+
|
1338
|
+
This class provides a comprehensive asynchronous interface for connecting to and
|
1339
|
+
interacting with MatrixOne databases. It supports modern async/await patterns
|
1340
|
+
including table creation, data insertion, querying, vector operations, and
|
1341
|
+
transaction management.
|
1342
|
+
|
1343
|
+
Key Features:
|
1344
|
+
|
1345
|
+
- Asynchronous connection management with connection pooling
|
1346
|
+
- High-level table operations (create_table, drop_table, insert, batch_insert)
|
1347
|
+
- Query builder interface for complex async queries
|
1348
|
+
- Vector operations (similarity search, range search, indexing)
|
1349
|
+
- Async transaction management with context managers
|
1350
|
+
- Snapshot and restore operations
|
1351
|
+
- Account and user management
|
1352
|
+
- Fulltext search capabilities
|
1353
|
+
- Non-blocking I/O operations
|
1354
|
+
|
1355
|
+
Supported Operations:
|
1356
|
+
|
1357
|
+
- Async connection and disconnection
|
1358
|
+
- Async query execution (SELECT, INSERT, UPDATE, DELETE)
|
1359
|
+
- Async batch operations
|
1360
|
+
- Async transaction management
|
1361
|
+
- Async table creation and management
|
1362
|
+
- Async vector and fulltext operations
|
1363
|
+
- Async snapshot and restore operations
|
1364
|
+
|
1365
|
+
Usage Examples::
|
1366
|
+
|
1367
|
+
Basic async usage::
|
1368
|
+
|
1369
|
+
async def main():
|
1370
|
+
client = AsyncClient()
|
1371
|
+
await client.connect('localhost', 6001, 'root', '111', 'test')
|
1372
|
+
|
1373
|
+
# Create table using high-level API
|
1374
|
+
await client.create_table("users", {
|
1375
|
+
"id": "int primary key",
|
1376
|
+
"name": "varchar(100)",
|
1377
|
+
"email": "varchar(255)"
|
1378
|
+
})
|
1379
|
+
|
1380
|
+
# Insert data
|
1381
|
+
await client.insert("users", {"id": 1, "name": "John", "email": "john@example.com"})
|
1382
|
+
|
1383
|
+
# Query data
|
1384
|
+
result = await client.query("users").where("id = ?", 1).all()
|
1385
|
+
print(result.rows)
|
1386
|
+
|
1387
|
+
await client.disconnect()
|
1388
|
+
|
1389
|
+
Vector operations::
|
1390
|
+
|
1391
|
+
async def vector_example():
|
1392
|
+
client = AsyncClient()
|
1393
|
+
await client.connect('localhost', 6001, 'root', '111', 'test')
|
1394
|
+
|
1395
|
+
# Create vector table
|
1396
|
+
await client.create_table("documents", {
|
1397
|
+
"id": "int primary key",
|
1398
|
+
"content": "text",
|
1399
|
+
"embedding": "vecf32(384)"
|
1400
|
+
})
|
1401
|
+
|
1402
|
+
# Vector similarity search
|
1403
|
+
results = await client.vector_ops.similarity_search(
|
1404
|
+
"documents",
|
1405
|
+
vector_column="embedding",
|
1406
|
+
query_vector=[0.1, 0.2, 0.3, ...], # 384-dimensional vector
|
1407
|
+
limit=10,
|
1408
|
+
distance_type="l2"
|
1409
|
+
)
|
1410
|
+
|
1411
|
+
await client.disconnect()
|
1412
|
+
|
1413
|
+
Async transaction usage::
|
1414
|
+
|
1415
|
+
async def transaction_example():
|
1416
|
+
client = AsyncClient()
|
1417
|
+
await client.connect('localhost', 6001, 'root', '111', 'test')
|
1418
|
+
|
1419
|
+
async with client.transaction() as tx:
|
1420
|
+
await tx.execute("INSERT INTO users (name) VALUES (?)", ("John",))
|
1421
|
+
await tx.execute("INSERT INTO orders (user_id, amount) VALUES (?, ?)", (1, 100.0))
|
1422
|
+
# Transaction commits automatically on success
|
1423
|
+
|
1424
|
+
Note: This class requires asyncio and async database drivers. Use the synchronous Client class
|
1425
|
+
for blocking operations or when async support is not needed.
|
1426
|
+
"""
|
1427
|
+
|
1428
|
+
def __init__(
    self,
    connection_timeout: int = 30,
    query_timeout: int = 300,
    auto_commit: bool = True,
    charset: str = "utf8mb4",
    logger: Optional[MatrixOneLogger] = None,
    sql_log_mode: str = "auto",
    slow_query_threshold: float = 1.0,
    max_sql_display_length: int = 500,
):
    """
    Initialize MatrixOne async client

    Args::

        connection_timeout: Connection timeout in seconds
        query_timeout: Query timeout in seconds
        auto_commit: Enable auto-commit mode
        charset: Character set for connection
        logger: Custom logger instance. If None, creates a default logger
        sql_log_mode: SQL logging mode ('off', 'auto', 'simple', 'full')
            - 'off': No SQL logging
            - 'auto': Smart logging - short SQL shown fully, long SQL summarized (default)
            - 'simple': Show operation summary only
            - 'full': Show complete SQL regardless of length
        slow_query_threshold: Threshold in seconds for slow query warnings (default: 1.0)
        max_sql_display_length: Maximum SQL length in auto mode before summarizing (default: 500)
    """
    self.connection_timeout = connection_timeout
    self.query_timeout = query_timeout
    self.auto_commit = auto_commit
    self.charset = charset

    # Initialize logger: a caller-supplied logger wins; otherwise build one
    # from the SQL-logging knobs above.
    if logger is not None:
        self.logger = logger
    else:
        self.logger = create_default_logger(
            sql_log_mode=sql_log_mode,
            slow_query_threshold=slow_query_threshold,
            max_sql_display_length=max_sql_display_length,
        )

    # Connection management - using SQLAlchemy async engine instead of direct aiomysql connection
    self._engine = None
    self._connection = None  # Keep for backward compatibility, but will be managed by engine
    self._connection_params = {}
    self._login_info = None

    # Initialize sub-managers; each receives this client for SQL execution.
    self._snapshots = AsyncSnapshotManager(self)
    self._clone = AsyncCloneManager(self)
    self._moctl = AsyncMoCtlManager(self)
    self._restore = AsyncRestoreManager(self)
    self._pitr = AsyncPitrManager(self)
    self._pubsub = AsyncPubSubManager(self)
    self._account = AsyncAccountManager(self)
    # Lazily-initialized managers (created on first use elsewhere).
    self._fulltext_index = None
    self._metadata = None
|
1488
|
+
|
1489
|
+
async def connect(
    self,
    host: str,
    port: int,
    user: str,
    password: str,
    database: Optional[str] = None,
    account: Optional[str] = None,
    role: Optional[str] = None,
    charset: str = "utf8mb4",
    connection_timeout: int = 30,
    auto_commit: bool = True,
    on_connect: Optional[Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]] = None,
):
    """
    Connect to MatrixOne database asynchronously

    Args::

        host: Database host
        port: Database port
        user: Username or login info in format "user", "account#user", or "account#user#role"
        password: Password
        database: Database name
        account: Optional account name (will be combined with user if user doesn't contain '#')
        role: Optional role name (will be combined with user if user doesn't contain '#')
        charset: Character set for the connection (default: utf8mb4)
        connection_timeout: Connection timeout in seconds (default: 30)
        auto_commit: Enable autocommit (default: True)
        on_connect: Connection hook to execute after successful connection.
            Can be:
            - ConnectionHook instance
            - List of ConnectionAction or string action names
            - Custom callback function (async or sync)

    Raises::

        ConnectionError: If engine creation or the connectivity test query fails
            (any underlying exception is wrapped, including login-info conflicts)

    Examples::

        # Enable all features after connection
        await client.connect(host, port, user, password, database,
                             on_connect=[ConnectionAction.ENABLE_ALL])

        # Enable only vector operations with custom charset
        await client.connect(host, port, user, password, database,
                             charset="utf8mb4",
                             on_connect=[ConnectionAction.ENABLE_VECTOR])

        # Custom async callback
        async def my_callback(client):
            print(f"Connected to {client._connection_params['host']}")

        await client.connect(host, port, user, password, database,
                             on_connect=my_callback)
    """
    # Pre-bind so the except handler can always log a user name: previously,
    # if _build_login_info() raised, the handler hit an UnboundLocalError on
    # final_user and masked the real failure.
    final_user = user
    try:
        # Build final login info based on user parameter and optional account/role
        final_user, parsed_info = self._build_login_info(user, account, role)

        # Store parsed info for later use
        self._login_info = parsed_info

        # Store connection parameters for engine creation
        self._connection_params = {
            "host": host,
            "port": port,
            "user": final_user,
            "password": password,
            "database": database,
            "charset": charset,
            "autocommit": auto_commit,
            "connect_timeout": connection_timeout,
        }

        # Create SQLAlchemy async engine instead of a direct aiomysql connection
        self._engine = self._create_async_engine()

        # Test the connection by executing a simple query
        async with self._engine.begin() as conn:
            await conn.execute(text("SELECT 1"))

        # Initialize vector managers after successful connection
        self._initialize_vector_managers()

        # Initialize metadata manager after successful connection
        self._metadata = AsyncMetadataManager(self)

        self.logger.log_connection(host, port, final_user, database or "default", success=True)

        # Setup connection hook if provided
        if on_connect:
            self._setup_connection_hook(on_connect)
            # Execute the hook once immediately for the initial connection
            # (attach_to_engine only covers connections created later)
            await self._execute_connection_hook_immediately(on_connect)

    except Exception as e:
        self.logger.log_connection(host, port, final_user, database or "default", success=False)
        self.logger.log_error(e, context="Async connection")
        raise ConnectionError(f"Failed to connect to MatrixOne: {e}")
|
1586
|
+
|
1587
|
+
def _setup_connection_hook(
    self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
) -> None:
    """Attach a connection hook to the engine so it runs on each new connection.

    Accepts a ready-made ConnectionHook, a list of actions, or a bare callable;
    anything else is logged as a warning and ignored. Setup failures never
    propagate — they are downgraded to warnings.
    """
    try:
        if isinstance(on_connect, ConnectionHook):
            hook = on_connect
        elif isinstance(on_connect, list):
            hook = create_connection_hook(actions=on_connect)
        elif callable(on_connect):
            hook = create_connection_hook(custom_hook=on_connect)
        else:
            self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
            return

        # Wire the hook to this client and register it on the engine so it
        # fires for every pooled connection created from now on.
        hook.set_client(self)
        hook.attach_to_engine(self._engine)

    except Exception as e:
        self.logger.warning(f"Connection hook setup failed: {e}")
|
1611
|
+
|
1612
|
+
async def _execute_connection_hook_immediately(
    self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
) -> None:
    """Run the connection hook once, right away, for the initial connection.

    Resolves ``on_connect`` to a ConnectionHook exactly like
    :meth:`_setup_connection_hook` and executes it asynchronously. Any failure
    is reported as a warning rather than raised.
    """
    try:
        if isinstance(on_connect, ConnectionHook):
            hook = on_connect
        elif isinstance(on_connect, list):
            hook = create_connection_hook(actions=on_connect)
        elif callable(on_connect):
            hook = create_connection_hook(custom_hook=on_connect)
        else:
            self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
            return

        # Fire the hook now instead of waiting for the next pooled connection.
        await hook.execute_async(self)

    except Exception as e:
        self.logger.warning(f"Immediate connection hook execution failed: {e}")
|
1635
|
+
|
1636
|
+
@classmethod
def from_engine(cls, engine: AsyncEngine, **kwargs) -> "AsyncClient":
    """
    Create AsyncClient instance from existing SQLAlchemy AsyncEngine

    Args::

        engine: SQLAlchemy AsyncEngine instance (must use MySQL driver)
        **kwargs: Additional client configuration options

    Returns::

        AsyncClient: Configured async client instance

    Raises::

        ConnectionError: If engine doesn't use MySQL driver

    Examples

    Basic usage::

        from sqlalchemy.ext.asyncio import create_async_engine
        from matrixone import AsyncClient

        engine = create_async_engine("mysql+aiomysql://user:pass@host:port/db")
        client = AsyncClient.from_engine(engine)

    With custom configuration::

        engine = create_async_engine("mysql+aiomysql://user:pass@host:port/db")
        client = AsyncClient.from_engine(
            engine,
            sql_log_mode='auto',
            slow_query_threshold=0.5
        )
    """
    # MatrixOne speaks the MySQL wire protocol, so reject anything else up front.
    if not cls._is_mysql_async_engine(engine):
        raise ConnectionError(
            "MatrixOne AsyncClient only supports MySQL drivers. "
            "Please use mysql+aiomysql:// connection strings. "
            f"Current engine uses: {engine.dialect.name}"
        )

    # Build a client with default settings, then adopt the caller's engine.
    instance = cls(**kwargs)
    instance._engine = engine

    # Vector/fulltext managers require a live engine, so initialize them now.
    instance._initialize_vector_managers()

    return instance
|
1691
|
+
|
1692
|
+
@staticmethod
def _is_mysql_async_engine(engine: AsyncEngine) -> bool:
    """
    Determine whether an async engine is backed by a MySQL driver.

    Args::

        engine: SQLAlchemy AsyncEngine instance

    Returns::

        bool: True if engine uses MySQL driver, False otherwise
    """
    # A MySQL dialect name is conclusive on its own.
    if engine.dialect.name.lower() == "mysql":
        return True

    # Otherwise fall back to scanning the URL for known MySQL async drivers.
    known_async_drivers = (
        "mysql+aiomysql",
        "mysql+asyncmy",
        "mysql+aiopg",  # Note: aiopg is PostgreSQL, but included for completeness
    )
    url_lower = str(engine.url).lower()
    return any(prefix in url_lower for prefix in known_async_drivers)
|
1721
|
+
|
1722
|
+
def _create_async_engine(self) -> AsyncEngine:
    """Create SQLAlchemy async engine with connection pooling.

    Returns::

        AsyncEngine built from the parameters stored by connect()

    Raises::

        ConnectionError: If SQLAlchemy async support is not installed
    """
    if not create_async_engine:
        raise ConnectionError("SQLAlchemy async engine not available. Please install sqlalchemy[asyncio]")

    from urllib.parse import quote_plus

    # URL-encode the credentials: the user part may legitimately contain '#'
    # (the "account#user#role" login format built by _build_login_info) and
    # passwords may contain '@', ':' or '/', any of which would otherwise
    # corrupt the URL. SQLAlchemy decodes the percent-escapes before handing
    # the values to the driver, so the server sees the original strings.
    quoted_user = quote_plus(self._connection_params['user'])
    quoted_password = quote_plus(self._connection_params['password'])

    # Build connection string for async engine
    connection_string = (
        f"mysql+aiomysql://{quoted_user}:"
        f"{quoted_password}@"
        f"{self._connection_params['host']}:"
        f"{self._connection_params['port']}/"
        f"{self._connection_params['database'] or ''}"
        f"?charset={self._connection_params['charset']}"
    )

    # Create async engine with connection pooling
    engine = create_async_engine(
        connection_string,
        pool_size=5,  # Smaller pool size for testing
        max_overflow=10,  # Smaller max overflow
        pool_timeout=30,  # Default pool timeout
        pool_recycle=3600,  # Recycle connections after 1 hour
        pool_pre_ping=True,  # Verify connections before use
        pool_reset_on_return="commit",  # Reset connections on return
        echo=False,  # Set to True for SQL logging
    )

    return engine
|
1750
|
+
|
1751
|
+
def _initialize_vector_managers(self) -> None:
    """Initialize vector managers after successful connection.

    Sets ``self._vector`` and ``self._fulltext_index``; if the optional
    vector module is missing, both are left as None so the rest of the
    client keeps working without vector support.
    """
    try:
        # Imported lazily so the client works without the vector extension.
        from .async_vector_index_manager import AsyncVectorManager

        self._vector = AsyncVectorManager(self)
        self._fulltext_index = AsyncFulltextIndexManager(self)
    except ImportError:
        # Vector managers not available — degrade gracefully to disabled.
        self._vector = None
        self._fulltext_index = None
|
1762
|
+
|
1763
|
+
async def disconnect(self):
    """Disconnect from MatrixOne database asynchronously.

    Closes the pooled connections (best-effort, via the engine's private
    pool), disposes the engine, and always clears the engine/connection/
    manager references in ``finally`` so the client ends up disconnected
    even when teardown raises.
    """
    if self._engine:
        try:
            # First, try to close any active connections in the pool.
            # NOTE(review): _pool is a private SQLAlchemy attribute; this is
            # best-effort and may be absent on some engine implementations.
            if hasattr(self._engine, "_pool") and self._engine._pool is not None:
                # Close all connections in the pool
                await self._engine._pool.close()
                # Wait for all connections to be properly closed
                await self._engine._pool.wait_closed()

            # Then dispose of the engine - this closes all connections in the pool
            await self._engine.dispose()

            # Force garbage collection to ensure connections are cleaned up
            import gc

            gc.collect()

            self.logger.log_disconnection(success=True)
        except Exception as e:
            # Teardown errors are logged but never propagated to the caller.
            self.logger.log_disconnection(success=False)
            self.logger.log_error(e, context="Async disconnection")
        finally:
            # Ensure all references are cleared
            self._engine = None
            self._connection = None
            # Clear any cached managers
            self._fulltext_index = None
|
1792
|
+
|
1793
|
+
def disconnect_sync(self):
    """Synchronous disconnect for cleanup when event loop is closed.

    Used where the async ``disconnect()`` cannot run (e.g. the asyncio loop
    is already gone during test teardown). Disposes the engine via its sync
    facade when available and always clears cached references in ``finally``.
    """
    if self._engine:
        try:
            # Try to close the connection pool synchronously
            # SQLAlchemy AsyncEngine has a sync_engine property that can be disposed
            if hasattr(self._engine, "sync_engine") and self._engine.sync_engine is not None:
                # Close all connections in the pool synchronously
                self._engine.sync_engine.dispose()
            elif hasattr(self._engine, "_pool") and self._engine._pool is not None:
                # Direct access to connection pool for cleanup
                # NOTE(review): _pool is private API; errors are swallowed
                # deliberately because this is a last-resort cleanup path.
                try:
                    self._engine._pool.close()
                except Exception:
                    pass

            self.logger.log_disconnection(success=True)
        except Exception as e:
            self.logger.log_disconnection(success=False)
            self.logger.log_error(e, context="Sync disconnection")
        finally:
            # Ensure all references are cleared regardless of success/failure
            self._engine = None
            self._connection = None
            # Clear any cached managers
            self._fulltext_index = None
|
1819
|
+
|
1820
|
+
def __del__(self):
    """Cleanup when object is garbage collected.

    Intentionally a no-op: disposing the async engine needs a running event
    loop, which may already be closed at GC/interpreter-shutdown time.
    Callers (or test fixtures) must invoke disconnect()/disconnect_sync().
    """
    pass
|
1825
|
+
|
1826
|
+
def get_sqlalchemy_engine(self) -> AsyncEngine:
    """
    Get SQLAlchemy async engine

    Returns::

        SQLAlchemy AsyncEngine

    Raises::

        ConnectionError: If connect() has not been called yet
    """
    engine = self._engine
    if not engine:
        raise ConnectionError("Not connected to database")
    return engine
|
1837
|
+
|
1838
|
+
async def create_all(self, base_class=None):
    """
    Create all tables defined in the given base class or default Base.

    Args::

        base_class: SQLAlchemy declarative base class. If None, uses the default Base.
    """
    if base_class is None:
        from matrixone.orm import declarative_base

        base_class = declarative_base()

    # run_sync bridges the synchronous metadata.create_all onto the async connection.
    async with self._engine.begin() as connection:
        await connection.run_sync(base_class.metadata.create_all)
    return self
|
1854
|
+
|
1855
|
+
async def drop_all(self, base_class=None):
    """
    Drop all tables defined in the given base class or default Base.

    Args::

        base_class: SQLAlchemy declarative base class. If None, uses the default Base.

    Returns::

        self, to allow call chaining
    """
    if base_class is None:
        from matrixone.orm import declarative_base

        base_class = declarative_base()

    # Get all table names from the metadata
    table_names = list(base_class.metadata.tables.keys())

    # Drop each table individually using direct SQL for better compatibility
    for table_name in table_names:
        try:
            await self.execute(f"DROP TABLE IF EXISTS {table_name}")
        except Exception as e:
            # Route the warning through the client logger (previously a bare
            # print(), inconsistent with every other failure path in this
            # class) and continue with the remaining tables.
            self.logger.warning(f"Failed to drop table {table_name}: {e}")

    return self
|
1880
|
+
|
1881
|
+
async def _execute_with_logging(
    self, connection, sql: str, context: str = "Async SQL execution", override_sql_log_mode: str = None
):
    """
    Execute SQL asynchronously with proper logging through the client's logger.

    This is an internal helper method used by all SDK components to ensure
    consistent SQL logging across async vector operations, transactions, and other features.

    Args::

        connection: SQLAlchemy async connection object
        sql: SQL query string
        context: Context description for error logging (default: "Async SQL execution")
        override_sql_log_mode: Temporarily override sql_log_mode for this query only

    Returns::

        SQLAlchemy result object

    Raises::

        Re-raises whatever the underlying execute() raised, after logging it.

    Note:

        This method is used internally by AsyncVectorManager, AsyncTransactionWrapper,
        and other SDK components. External users should use execute() instead.
    """
    import time
    from sqlalchemy import text

    start_time = time.time()
    try:
        result = await connection.execute(text(sql))
        execution_time = time.time() - start_time

        # Try to get row count if available
        try:
            if result.returns_rows:
                # For SELECT queries, we can't consume the result to count rows
                # So we just log without row count
                self.logger.log_query(
                    sql, execution_time, None, success=True, override_sql_log_mode=override_sql_log_mode
                )
            else:
                # For DML queries (INSERT/UPDATE/DELETE), we can get rowcount
                self.logger.log_query(
                    sql, execution_time, result.rowcount, success=True, override_sql_log_mode=override_sql_log_mode
                )
        except Exception:
            # Fallback: just log the query without row count
            self.logger.log_query(sql, execution_time, None, success=True, override_sql_log_mode=override_sql_log_mode)

        return result
    except Exception as e:
        execution_time = time.time() - start_time
        # Error path: log the failed query (no row count) and the exception,
        # then re-raise unchanged so callers see the original error.
        self.logger.log_query(sql, execution_time, success=False, override_sql_log_mode=override_sql_log_mode)
        self.logger.log_error(e, context=context)
        raise
|
1937
|
+
|
1938
|
+
async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
    """
    Execute SQL query asynchronously using SQLAlchemy async engine

    Args::

        sql: SQL query string (use ? placeholders for parameters)
        params: Query parameters, substituted client-side because MatrixOne
            does not support prepared statements (see _substitute_parameters)

    Returns::

        AsyncResultSet with query results

    Raises::

        ConnectionError: If not connected to the database
        QueryError: On any execution failure, with a user-friendly message
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    import time

    start_time = time.time()
    # Pre-bind so the error path can log even when parameter substitution
    # itself raises (previously this hit an UnboundLocalError on final_sql).
    final_sql = sql

    try:
        # Handle parameter substitution for MatrixOne compatibility
        final_sql = self._substitute_parameters(sql, params)

        async with self._engine.begin() as conn:
            # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
            # This prevents JSON strings like {"a":1} from being parsed as :1 bind params
            if hasattr(conn, 'exec_driver_sql'):
                # Escape % to %% for pymysql's format string handling
                escaped_sql = final_sql.replace('%', '%%')
                result = await conn.exec_driver_sql(escaped_sql)
            else:
                # Fallback for testing or older SQLAlchemy versions
                from sqlalchemy import text

                result = await conn.execute(text(final_sql))

        execution_time = time.time() - start_time

        if result.returns_rows:
            rows = result.fetchall()
            columns = list(result.keys()) if hasattr(result, "keys") else []
            async_result = AsyncResultSet(columns, rows)
            self.logger.log_query(final_sql, execution_time, len(rows), success=True)
            return async_result
        else:
            async_result = AsyncResultSet([], [], affected_rows=result.rowcount)
            self.logger.log_query(final_sql, execution_time, result.rowcount, success=True)
            return async_result

    except Exception as e:
        execution_time = time.time() - start_time

        # Log error FIRST, before any error processing.
        # Wrap in try-except so a logging failure can't hide the real error.
        try:
            self.logger.log_query(final_sql, execution_time, success=False)
            self.logger.log_error(e, context="Async query execution")
        except Exception as log_err:
            # If logging fails, print to stderr as fallback but continue with error handling
            import sys

            print(f"Warning: Error logging failed: {log_err}", file=sys.stderr)

        # Map the raw exception onto a user-friendly QueryError (always raises).
        self._translate_query_error(e, final_sql)

def _translate_query_error(self, error: Exception, final_sql: str):
    """Translate a raw driver/SQLAlchemy exception into a user-friendly QueryError.

    Always raises QueryError. The original exception context is suppressed
    (``from None``) to keep tracebacks short and readable for SDK users.

    Args::

        error: The exception raised during execution
        final_sql: The fully-substituted SQL, used for syntax-error previews
    """
    import re

    error_msg = str(error)
    lowered = error_msg.lower()

    # Check for "does not exist" first before "syntax error"
    if 'does not exist' in lowered or 'no such table' in lowered or 'doesn\'t exist' in lowered:
        # Table or database doesn't exist
        match = re.search(r"(?:table|database)\s+[\"']?(\w+)[\"']?\s+does not exist", error_msg, re.IGNORECASE)
        if match:
            obj_name = match.group(1)
            raise QueryError(
                f"Table or database '{obj_name}' does not exist. "
                f"Create it first using client.create_table() or CREATE TABLE/DATABASE statement."
            ) from None
        else:
            raise QueryError(f"Object not found: {error_msg}") from None

    elif 'already exists' in lowered and '1050' in error_msg:
        # Table already exists (MySQL error 1050)
        match = re.search(r"table\s+(\w+)\s+already\s+exists", error_msg, re.IGNORECASE)
        if match:
            table_name = match.group(1)
            raise QueryError(
                f"Table '{table_name}' already exists. "
                f"Use DROP TABLE {table_name} or client.drop_table() to remove it first."
            ) from None
        else:
            raise QueryError(f"Object already exists: {error_msg}") from None

    elif 'duplicate' in lowered and ('1062' in error_msg or '1061' in error_msg):
        # Duplicate key/entry (MySQL errors 1062/1061)
        raise QueryError(
            f"Duplicate entry error: {error_msg}. "
            f"Check for duplicate primary key or unique constraint violations."
        ) from None

    elif 'syntax error' in lowered or '1064' in error_msg:
        # SQL syntax error — include a bounded preview of the offending query
        sql_preview = final_sql[:200] + '...' if len(final_sql) > 200 else final_sql
        raise QueryError(f"SQL syntax error: {error_msg}\n" f"Query: {sql_preview}") from None

    elif 'column' in lowered and ('unknown' in lowered or 'not found' in lowered):
        # Column doesn't exist
        raise QueryError(f"Column not found: {error_msg}. " f"Check your column names and table schema.") from None

    elif 'cannot be null' in lowered or '1048' in error_msg:
        # NULL constraint violation
        raise QueryError(
            f"NULL constraint violation: {error_msg}. " f"Some columns require non-NULL values."
        ) from None

    elif 'not supported' in lowered and '20105' in error_msg:
        # MatrixOne-specific: feature not supported
        raise QueryError(
            f"MatrixOne feature limitation: {error_msg}. "
            f"This feature may require additional configuration or is not yet supported."
        ) from None

    elif 'bind parameter' in lowered or 'InvalidRequestError' in error_msg:
        # SQLAlchemy bind parameter error
        raise QueryError(
            f"Parameter binding error: {error_msg}. "
            f"This might be caused by special characters in your data (colons in JSON, etc.)"
        ) from None

    else:
        # Generic error - cleaner message without full SQLAlchemy stack
        raise QueryError(f"Query execution failed: {error_msg}") from None
|
2078
|
+
|
2079
|
+
def _substitute_parameters(self, sql: str, params: Optional[Tuple] = None) -> str:
|
2080
|
+
"""
|
2081
|
+
Substitute ? placeholders with actual values since MatrixOne doesn't support prepared statements
|
2082
|
+
|
2083
|
+
Args::
|
2084
|
+
|
2085
|
+
sql: SQL query string with ? placeholders
|
2086
|
+
params: Tuple of parameter values
|
2087
|
+
|
2088
|
+
Returns::
|
2089
|
+
|
2090
|
+
SQL string with parameters substituted
|
2091
|
+
"""
|
2092
|
+
if not params:
|
2093
|
+
return sql
|
2094
|
+
|
2095
|
+
final_sql = sql
|
2096
|
+
for param in params:
|
2097
|
+
if isinstance(param, str):
|
2098
|
+
# Escape single quotes in string values
|
2099
|
+
escaped_param = param.replace("'", "''")
|
2100
|
+
final_sql = final_sql.replace("?", f"'{escaped_param}'", 1)
|
2101
|
+
elif param is None:
|
2102
|
+
final_sql = final_sql.replace("?", "NULL", 1)
|
2103
|
+
else:
|
2104
|
+
final_sql = final_sql.replace("?", str(param), 1)
|
2105
|
+
|
2106
|
+
return final_sql
|
2107
|
+
|
2108
|
+
def _build_login_info(self, user: str, account: Optional[str] = None, role: Optional[str] = None) -> tuple[str, dict]:
|
2109
|
+
"""
|
2110
|
+
Build final login info based on user parameter and optional account/role
|
2111
|
+
|
2112
|
+
Args::
|
2113
|
+
|
2114
|
+
user: Username or login info in format "user", "account#user", or "account#user#role"
|
2115
|
+
account: Optional account name
|
2116
|
+
role: Optional role name
|
2117
|
+
|
2118
|
+
Returns::
|
2119
|
+
|
2120
|
+
tuple: (final_user_string, parsed_info_dict)
|
2121
|
+
|
2122
|
+
Rules:
|
2123
|
+
1. If user contains '#', it's already in format "account#user" or "account#user#role"
|
2124
|
+
- If account or role is also provided, raise error (conflict)
|
2125
|
+
2. If user doesn't contain '#', combine with optional account/role:
|
2126
|
+
- No account/role: use user as-is
|
2127
|
+
- Only role: use "sys#user#role"
|
2128
|
+
- Only account: use "account#user"
|
2129
|
+
- Both: use "account#user#role"
|
2130
|
+
"""
|
2131
|
+
# Check if user already contains login format
|
2132
|
+
if "#" in user:
|
2133
|
+
# User is already in format "account#user" or "account#user#role"
|
2134
|
+
if account is not None or role is not None:
|
2135
|
+
raise ValueError(
|
2136
|
+
f"Conflict: user parameter '{user}' already contains account/role info, "
|
2137
|
+
f"but account='{account}' and role='{role}' are also provided. "
|
2138
|
+
f"Use either user format or separate account/role parameters, not both."
|
2139
|
+
)
|
2140
|
+
|
2141
|
+
# Parse the existing format
|
2142
|
+
parts = user.split("#")
|
2143
|
+
if len(parts) == 2:
|
2144
|
+
# "account#user" format
|
2145
|
+
final_account, final_user, final_role = parts[0], parts[1], None
|
2146
|
+
elif len(parts) == 3:
|
2147
|
+
# "account#user#role" format
|
2148
|
+
final_account, final_user, final_role = parts[0], parts[1], parts[2]
|
2149
|
+
else:
|
2150
|
+
raise ValueError(f"Invalid user format: '{user}'. Expected 'user', 'account#user', or 'account#user#role'")
|
2151
|
+
|
2152
|
+
final_user_string = user
|
2153
|
+
|
2154
|
+
else:
|
2155
|
+
# User is just a username, combine with optional account/role
|
2156
|
+
if account is None and role is None:
|
2157
|
+
# No account/role provided, use user as-is
|
2158
|
+
final_account, final_user, final_role = "sys", user, None
|
2159
|
+
final_user_string = user
|
2160
|
+
elif account is None and role is not None:
|
2161
|
+
# Only role provided, use sys account
|
2162
|
+
final_account, final_user, final_role = "sys", user, role
|
2163
|
+
final_user_string = f"sys#{user}#{role}"
|
2164
|
+
elif account is not None and role is None:
|
2165
|
+
# Only account provided, no role
|
2166
|
+
final_account, final_user, final_role = account, user, None
|
2167
|
+
final_user_string = f"{account}#{user}"
|
2168
|
+
else:
|
2169
|
+
# Both account and role provided
|
2170
|
+
final_account, final_user, final_role = account, user, role
|
2171
|
+
final_user_string = f"{account}#{user}#{role}"
|
2172
|
+
|
2173
|
+
parsed_info = {"account": final_account, "user": final_user, "role": final_role}
|
2174
|
+
|
2175
|
+
return final_user_string, parsed_info
|
2176
|
+
|
2177
|
+
def get_login_info(self) -> Optional[dict]:
    """Return the account/user/role dict parsed at connect time, or None if never connected."""
    return self._login_info
|
2180
|
+
|
2181
|
+
def _escape_identifier(self, identifier: str) -> str:
|
2182
|
+
"""Escapes an identifier to prevent SQL injection."""
|
2183
|
+
return f"`{identifier}`"
|
2184
|
+
|
2185
|
+
def _escape_string(self, value: str) -> str:
|
2186
|
+
"""Escapes a string value for SQL queries."""
|
2187
|
+
return f"'{value}'"
|
2188
|
+
|
2189
|
+
def query(self, *columns, snapshot: str = None):
    """Get async MatrixOne query builder - SQLAlchemy style

    Args::

        *columns: Can be:
            - Single model class: query(Article) - returns all columns from model
            - Single table name string: query("db.table")
            - Single CTE object (has ``name`` and ``as_sql``): selects ``*`` from the CTE
            - Multiple columns: query(Article.id, Article.title) - returns specific columns
            - Mixed: query(Article, Article.id, expr.label('alias')) - model + additional columns
        snapshot: Optional snapshot name for snapshot queries

    Examples

        # Traditional model query (all columns)
        await client.query(Article).filter(...).all()

        # Column-specific query
        await client.query(Article.id, Article.title).filter(...).all()

        # With fulltext score
        await client.query(Article.id, boolean_match("title", "content").must("python").label("score"))

        # Snapshot query
        await client.query(Article, snapshot="my_snapshot").filter(...).all()

    Returns::

        AsyncMatrixOneQuery instance configured for the specified columns
    """
    # Imported lazily to avoid a circular import with async_orm.
    from .async_orm import AsyncMatrixOneQuery

    if len(columns) == 1:
        # Traditional single-argument usage: table name, model, CTE, or column.
        column = columns[0]
        if isinstance(column, str):
            # String table name
            return AsyncMatrixOneQuery(column, self, None, None, snapshot)
        elif hasattr(column, '__tablename__'):
            # This is a model class
            return AsyncMatrixOneQuery(column, self, None, None, snapshot)
        elif hasattr(column, 'name') and hasattr(column, 'as_sql'):
            # Duck-typed CTE check; confirmed against the real CTE class below.
            from .orm import CTE

            if isinstance(column, CTE):
                query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
                query._table_name = column.name
                query._select_columns = ["*"]  # Default to select all from CTE
                query._ctes = [column]  # Add the CTE to the query
                return query
            # NOTE(review): an object with .name/.as_sql that is NOT a CTE
            # falls through and this method implicitly returns None — confirm
            # whether such inputs are possible for callers.
        else:
            # This is a single column/expression - need to handle specially
            # For now, we'll create a query that can handle column selections
            query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
            query._select_columns = [column]
            # Try to infer table name from column (SQLAlchemy columns carry .table)
            if hasattr(column, 'table') and hasattr(column.table, 'name'):
                query._table_name = column.table.name
            return query
    else:
        # Multiple columns/expressions: split model classes from plain columns.
        model_class = None
        select_columns = []

        for column in columns:
            if hasattr(column, '__tablename__'):
                # This is a model class - use its table
                model_class = column
            else:
                # This is a column or expression
                select_columns.append(column)

        if model_class:
            query = AsyncMatrixOneQuery(model_class, self, None, None, snapshot)
            if select_columns:
                # NOTE(review): this REPLACES the model's default columns with
                # the explicit list rather than appending — confirm intended.
                query._select_columns = select_columns
            return query
        else:
            # No model class provided, need to infer table from columns
            query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
            query._select_columns = select_columns

            # Try to infer table name from first column that has table info
            for col in select_columns:
                if hasattr(col, 'table') and hasattr(col.table, 'name'):
                    query._table_name = col.table.name
                    break
                elif isinstance(col, str) and '.' in col:
                    # String column like "table.column" - extract table name
                    parts = col.split('.')
                    if len(parts) >= 2:
                        # For "db.table.column" format, use "db.table"
                        # For "table.column" format, use "table"
                        table_name = '.'.join(parts[:-1])
                        query._table_name = table_name
                        break

            return query
|
2288
|
+
|
2289
|
+
@asynccontextmanager
async def snapshot(self, snapshot_name: str):
    """
    Snapshot context manager

    Yields a :class:`SnapshotClient` wrapper that routes queries against the
    named snapshot.

    Usage

        async with client.snapshot("daily_backup") as snapshot_client:
            result = await snapshot_client.execute("SELECT * FROM users")
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Deferred import: .client imports from this module as well.
    from .client import SnapshotClient

    yield SnapshotClient(self, snapshot_name)
|
2307
|
+
|
2308
|
+
async def insert(self, table_name_or_model, data: dict) -> "AsyncResultSet":
    """
    Insert data into a table asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        data: Data to insert (dict with column names as keys)

    Returns::

        AsyncResultSet object
    """
    # Resolve a model class to its table name; strings pass through unchanged.
    target = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )
    return await AsyncClientExecutor(self).insert(target, data)
|
2331
|
+
|
2332
|
+
async def batch_insert(self, table_name_or_model, data_list: list) -> "AsyncResultSet":
    """
    Batch insert data into a table asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        data_list: List of data dictionaries to insert

    Returns::

        AsyncResultSet object
    """
    # Resolve a model class to its table name; strings pass through unchanged.
    target = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )
    return await AsyncClientExecutor(self).batch_insert(target, data_list)
|
2355
|
+
|
2356
|
+
@asynccontextmanager
async def transaction(self):
    """
    Async transaction context manager

    Usage

        async with client.transaction() as tx:
            await tx.execute("INSERT INTO users ...")
            await tx.execute("UPDATE users ...")
            # Snapshot and clone operations within transaction
            await tx.snapshots.create("snap1", "table", database="db1", table="t1")
            await tx.clone.clone_database("target_db", "source_db")

    On any exception the transaction is rolled back by SQLAlchemy's
    ``engine.begin()`` block and the exception propagates unchanged.
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    tx_wrapper = None
    try:
        # engine.begin() commits on clean exit and rolls back on exception,
        # so no explicit except clause is needed.  The previous
        # `except Exception as e: raise e` added a spurious traceback frame.
        async with self._engine.begin() as conn:
            tx_wrapper = AsyncTransactionWrapper(conn, self)
            yield tx_wrapper
    finally:
        # Clean up any lazily-created SQLAlchemy session/engine on the wrapper.
        if tx_wrapper:
            await tx_wrapper.close_sqlalchemy()
|
2387
|
+
|
2388
|
+
@property
def snapshots(self) -> AsyncSnapshotManager:
    """Async snapshot manager for creating/listing/dropping snapshots."""
    return self._snapshots
|
2392
|
+
|
2393
|
+
@property
def clone(self) -> AsyncCloneManager:
    """Async clone manager for database/table clone operations."""
    return self._clone
|
2397
|
+
|
2398
|
+
@property
def moctl(self) -> AsyncMoCtlManager:
    """Async mo_ctl manager for MatrixOne control commands."""
    return self._moctl
|
2402
|
+
|
2403
|
+
@property
def restore(self) -> AsyncRestoreManager:
    """Async restore manager for snapshot/backup restore operations."""
    return self._restore
|
2407
|
+
|
2408
|
+
@property
def pitr(self) -> AsyncPitrManager:
    """Async point-in-time-recovery (PITR) manager."""
    return self._pitr
|
2412
|
+
|
2413
|
+
@property
def pubsub(self) -> AsyncPubSubManager:
    """Async publish-subscribe manager."""
    return self._pubsub
|
2417
|
+
|
2418
|
+
@property
def account(self) -> AsyncAccountManager:
    """Async account manager for account/user administration."""
    return self._account
|
2422
|
+
|
2423
|
+
def connected(self) -> bool:
    """Return True if an async engine exists (i.e. the client is connected)."""
    return self._engine is not None
|
2426
|
+
|
2427
|
+
@property
def vector_ops(self):
    """Unified vector operations manager (covers both index and data ops)."""
    return self._vector
|
2431
|
+
|
2432
|
+
def get_pinecone_index(self, table_name_or_model, vector_column: str):
    """
    Get a PineconeCompatibleIndex object for vector search operations.

    Creates a Pinecone-compatible vector search interface that parses the
    table schema and vector index configuration automatically.  The primary
    key column is auto-detected and every non-vector column is exposed as
    metadata.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        vector_column: Name of the vector column

    Returns::

        PineconeCompatibleIndex object with Pinecone-compatible API

    Example::

        index = await client.get_pinecone_index("documents", "embedding")
        results = await index.query_async([0.1, 0.2, 0.3], top_k=5)
        for match in results.matches:
            print(f"ID: {match.id}, Score: {match.score}")
    """
    from .search_vector_index import PineconeCompatibleIndex

    # Accept either a model class (resolved via __tablename__) or a raw name.
    resolved_table = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )

    return PineconeCompatibleIndex(
        client=self,
        table_name=resolved_table,
        vector_column=vector_column,
    )
|
2470
|
+
|
2471
|
+
@property
def fulltext_index(self):
    """Fulltext index manager for fulltext index operations."""
    return self._fulltext_index
|
2475
|
+
|
2476
|
+
async def version(self) -> str:
    """
    Get MatrixOne server version asynchronously

    Returns::

        str: MatrixOne server version string

    Raises::

        ConnectionError: If not connected to MatrixOne
        QueryError: If version query fails or returns no rows

    Example

        >>> client = AsyncClient()
        >>> await client.connect('localhost', 6001, 'root', '111', 'test')
        >>> version = await client.version()
        >>> print(f"MatrixOne version: {version}")
    """
    if not self.connected():
        raise ConnectionError("Not connected to MatrixOne")

    # Keep the try body minimal: previously the QueryError raised for an
    # empty result was itself caught by `except Exception` and re-wrapped
    # ("Failed to get version: Failed to get version information").
    try:
        result = await self.execute("SELECT VERSION()")
    except Exception as e:
        raise QueryError(f"Failed to get version: {e}") from e

    if not result.rows:
        raise QueryError("Failed to get version information")
    return result.rows[0][0]
|
2507
|
+
|
2508
|
+
async def git_version(self) -> str:
    """
    Get MatrixOne git version information asynchronously

    Returns::

        str: MatrixOne git version string

    Raises::

        ConnectionError: If not connected to MatrixOne
        QueryError: If git version query fails or returns no rows

    Example

        >>> client = AsyncClient()
        >>> await client.connect('localhost', 6001, 'root', '111', 'test')
        >>> git_version = await client.git_version()
        >>> print(f"MatrixOne git version: {git_version}")
    """
    if not self.connected():
        raise ConnectionError("Not connected to MatrixOne")

    # Keep the try body minimal: previously the QueryError raised for an
    # empty result was itself caught by `except Exception` and re-wrapped.
    try:
        # Use MatrixOne's built-in git_version() function
        result = await self.execute("SELECT git_version()")
    except Exception as e:
        raise QueryError(f"Failed to get git version: {e}") from e

    if not result.rows:
        raise QueryError("Failed to get git version information")
    return result.rows[0][0]
|
2540
|
+
|
2541
|
+
async def __aenter__(self):
    """Enter the async context manager; returns the client itself."""
    return self
|
2543
|
+
|
2544
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context manager, disconnecting if still connected."""
    # Only disconnect if we're actually connected
    if self.connected():
        await self.disconnect()
|
2548
|
+
|
2549
|
+
async def create_table(self, table_name_or_model, columns: dict = None, **kwargs) -> "AsyncClient":
    """
    Create a table asynchronously

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        columns: Dictionary mapping column names to their definitions (required if table_name_or_model is str)
        **kwargs: Additional table creation options

    Returns::

        AsyncClient: Self for chaining

    Example

        >>> await client.create_table("users", {
        ...     "id": "int primary key",
        ...     "name": "varchar(100)",
        ...     "email": "varchar(255)"
        ... })
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Model-class path: delegate DDL generation to SQLAlchemy metadata.
    if hasattr(table_name_or_model, '__tablename__'):
        model_class = table_name_or_model
        async with self._engine.begin() as conn:
            await conn.run_sync(model_class.__table__.create)
        return self

    # String path: build the CREATE TABLE statement by hand.
    table_name = table_name_or_model
    if columns is None:
        raise ValueError("columns parameter is required when table_name_or_model is a string")

    definitions = []
    for name, definition in columns.items():
        # Normalize vector column types (vecf32/vecf64) to lowercase.
        if definition.lower().startswith(("vecf32(", "vecf64(")):
            definition = definition.lower()
        definitions.append(f"{name} {definition}")

    await self.execute(f"CREATE TABLE {table_name} ({', '.join(definitions)})")
    return self
|
2602
|
+
|
2603
|
+
async def drop_table(self, table_name_or_model) -> "AsyncClient":
    """
    Drop a table asynchronously (uses DROP TABLE IF EXISTS, so a missing
    table is not an error).

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class

    Returns::

        AsyncClient: Self for chaining

    Example

        >>> await client.drop_table("users")
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Resolve a model class to its table name; strings pass through unchanged.
    target = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )

    await self.execute(f"DROP TABLE IF EXISTS {target}")
    return self
|
2633
|
+
|
2634
|
+
async def create_table_with_index(self, table_name: str, columns: dict, indexes: list = None, **kwargs) -> "AsyncClient":
    """
    Create a table with indexes asynchronously

    Args::

        table_name: Name of the table to create
        columns: Dictionary mapping column names to their definitions
        indexes: List of index definitions (dicts with "name", "columns",
            and optional "unique" keys)
        **kwargs: Additional table creation options

    Returns::

        AsyncClient: Self for chaining

    Example

        >>> await client.create_table_with_index("users", {
        ...     "id": "int primary key",
        ...     "name": "varchar(100)",
        ...     "email": "varchar(255)"
        ... }, [
        ...     {"name": "idx_name", "columns": ["name"]},
        ...     {"name": "idx_email", "columns": ["email"], "unique": True}
        ... ])
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Create the table first.
    column_sql = ', '.join(f"{name} {definition}" for name, definition in columns.items())
    await self.execute(f"CREATE TABLE {table_name} ({column_sql})")

    # Then create each requested index.
    for spec in indexes or []:
        uniqueness = "UNIQUE " if spec.get("unique", False) else ""
        indexed_cols = ", ".join(spec["columns"])
        await self.execute(f"CREATE {uniqueness}INDEX {spec['name']} ON {table_name} ({indexed_cols})")

    return self
|
2681
|
+
|
2682
|
+
async def create_table_orm(self, table_name: str, *columns, **kwargs) -> "AsyncClient":
    """
    Create a table using SQLAlchemy ORM asynchronously

    Args::

        table_name: Name of the table to create
        *columns: SQLAlchemy column definitions
        **kwargs: Additional table creation options

    Returns::

        AsyncClient: Self for chaining

    Example

        >>> from sqlalchemy import Column, Integer, String
        >>> await client.create_table_orm("users",
        ...     Column("id", Integer, primary_key=True),
        ...     Column("name", String(100)),
        ...     Column("email", String(255))
        ... )
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    from sqlalchemy import MetaData, Table

    # Register the table on a fresh MetaData, then emit its DDL.
    meta = MetaData()
    Table(table_name, meta, *columns)

    async with self._engine.begin() as conn:
        await conn.run_sync(meta.create_all)

    return self
|
2718
|
+
|
2719
|
+
@property
def metadata(self) -> Optional["AsyncMetadataManager"]:
    """Metadata manager for table metadata operations (may be None)."""
    return self._metadata
|
2723
|
+
|
2724
|
+
|
2725
|
+
class AsyncTransactionVectorIndexManager(AsyncVectorManager):
    """Async vector index manager that routes all SQL through an active
    transaction wrapper instead of the client's own engine."""

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # The AsyncTransactionWrapper whose connection carries the open transaction.
        self.transaction_wrapper = transaction_wrapper

    async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
        """Execute SQL within the wrapped transaction."""
        return await self.transaction_wrapper.execute(sql, params)

    async def get_ivf_stats(self, table_name_or_model, column_name: str = None) -> Dict[str, Any]:
        """
        Get IVF index statistics for a table within transaction.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            column_name: Name of the vector column (optional, will be inferred if not provided)

        Returns::

            Dict containing IVF index statistics including:
            - index_tables: Dictionary mapping table types to table names
            - distribution: Dictionary containing bucket distribution data
            - database: Database name
            - table_name: Table name
            - column_name: Vector column name

        Raises::

            Exception: If IVF index is not found or if there are errors retrieving stats

        Examples

            # Get stats for a table with vector column within transaction
            async with client.transaction() as tx:
                stats = await tx.vector_ops.get_ivf_stats("my_table", "embedding")
                print(f"Index tables: {stats['index_tables']}")
                print(f"Distribution: {stats['distribution']}")
        """
        from sqlalchemy import text

        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Get database name from connection params
        database = self.client._connection_params.get('database')
        if not database:
            raise Exception("No database connection found. Please connect to a database first.")

        # If column_name is not provided, try to infer it
        if not column_name:
            # Query the table schema to find vector columns using transaction connection.
            # NOTE(review): database/table_name are interpolated directly into SQL
            # (no parameter binding) — values come from the connection config and
            # caller, but confirm they are trusted.
            schema_sql = text(
                f"""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = '{database}'
            AND table_name = '{table_name}'
            AND (data_type LIKE '%VEC%' OR data_type LIKE '%vec%')
            """
            )
            result = await self.transaction_wrapper.execute(schema_sql)
            # NOTE(review): assumes the result object exposes fetchall() — confirm
            # against AsyncTransactionWrapper.execute's return type.
            vector_columns = result.fetchall()

            if not vector_columns:
                raise Exception(f"No vector columns found in table {table_name}")
            elif len(vector_columns) == 1:
                # Exactly one vector column: use it.
                column_name = vector_columns[0][0]
            else:
                # Multiple vector columns found, raise error asking user to specify
                column_names = [col[0] for col in vector_columns]
                raise Exception(
                    f"Multiple vector columns found in table {table_name}: {column_names}. "
                    f"Please specify the column_name parameter."
                )

        # Get connection from transaction wrapper
        connection = self.transaction_wrapper.connection

        # Get IVF index table names
        index_tables = await self._get_ivf_index_table_names(database, table_name, column_name, connection)

        if not index_tables:
            raise Exception(f"No IVF index found for table {table_name}, column {column_name}")

        # Get the entries table name for distribution analysis
        entries_table = index_tables.get('entries')
        if not entries_table:
            raise Exception("No entries table found in IVF index")

        # Get bucket distribution
        distribution = await self._get_ivf_buckets_distribution(database, entries_table, connection)

        return {
            'index_tables': index_tables,
            'distribution': distribution,
            'database': database,
            'table_name': table_name,
            'column_name': column_name,
        }
|
2830
|
+
|
2831
|
+
|
2832
|
+
class AsyncTransactionWrapper:
|
2833
|
+
"""Async transaction wrapper for executing queries within a transaction"""
|
2834
|
+
|
2835
|
+
def __init__(self, connection, client):
    """Bind a transaction-scoped facade around an open async connection.

    Args::

        connection: SQLAlchemy async connection carrying the open transaction
        client: The owning AsyncClient (used for SQL building and logging)
    """
    self.connection = connection
    self.client = client
    # Create snapshot, clone, restore, PITR, pubsub, account, and fulltext managers that use this transaction
    self.snapshots = AsyncTransactionSnapshotManager(client, self)
    self.clone = AsyncTransactionCloneManager(client, self)
    self.restore = AsyncTransactionRestoreManager(client, self)
    self.pitr = AsyncTransactionPitrManager(client, self)
    self.pubsub = AsyncTransactionPubSubManager(client, self)
    self.account = AsyncTransactionAccountManager(self)
    self.vector_ops = AsyncTransactionVectorIndexManager(client, self)
    self.fulltext_index = AsyncTransactionFulltextIndexManager(client, self)
    # SQLAlchemy integration (created lazily by get_sqlalchemy_session)
    self._sqlalchemy_session = None
    self._sqlalchemy_engine = None
|
2850
|
+
|
2851
|
+
async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
    """Execute SQL within transaction asynchronously.

    Substitutes params client-side, runs the statement on the transaction's
    connection, logs timing, and wraps the outcome in an AsyncResultSet.

    Raises::

        QueryError: wrapping any underlying execution failure.
    """
    import time

    start_time = time.time()

    try:
        # Handle parameter substitution for MatrixOne compatibility
        final_sql = self.client._substitute_parameters(sql, params)
        # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
        # This prevents JSON strings like {"a":1} from being parsed as :1 bind params
        if hasattr(self.connection, 'exec_driver_sql'):
            # Escape % to %% for pymysql's format string handling
            escaped_sql = final_sql.replace('%', '%%')
            result = await self.connection.exec_driver_sql(escaped_sql)
        else:
            # Fallback for testing or older SQLAlchemy versions
            from sqlalchemy import text

            result = await self.connection.execute(text(final_sql))
        execution_time = time.time() - start_time

        if result.returns_rows:
            # SELECT-like statement: materialize rows and column names.
            rows = result.fetchall()
            columns = list(result.keys()) if hasattr(result, "keys") else []
            self.client.logger.log_query(sql, execution_time, len(rows), success=True)
            return AsyncResultSet(columns, rows)
        else:
            # DML/DDL: report affected row count only.
            self.client.logger.log_query(sql, execution_time, result.rowcount, success=True)
            return AsyncResultSet([], [], affected_rows=result.rowcount)

    except Exception as e:
        execution_time = time.time() - start_time
        self.client.logger.log_query(sql, execution_time, success=False)
        self.client.logger.log_error(e, context="Async transaction query execution")
        raise QueryError(f"Transaction query execution failed: {e}")
|
2887
|
+
|
2888
|
+
def get_connection(self):
    """
    Get the underlying SQLAlchemy async connection for direct use

    Returns::

        SQLAlchemy AsyncConnection instance bound to this transaction
    """
    return self.connection
|
2897
|
+
|
2898
|
+
async def get_sqlalchemy_session(self):
    """
    Get async SQLAlchemy session that uses the same transaction asynchronously

    Returns::

        Async SQLAlchemy Session instance bound to this transaction

    NOTE(review): the session is created on a NEW engine built from the
    client's connection parameters, so it opens its own connection rather
    than reusing this transaction's connection — confirm whether the
    "same transaction" claim above actually holds.
    """
    if self._sqlalchemy_session is None:
        try:
            from sqlalchemy.ext.asyncio import (
                AsyncSession,
                async_sessionmaker,
                create_async_engine,
            )
        except ImportError:
            # Fallback for older SQLAlchemy versions without async_sessionmaker
            from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
            from sqlalchemy.orm import sessionmaker as sync_sessionmaker

            # Create a simple async sessionmaker equivalent
            def async_sessionmaker(bind=None, **kwargs):
                # Remove class_ from kwargs if present to avoid conflicts
                kwargs.pop('class_', None)
                return sync_sessionmaker(bind=bind, class_=AsyncSession, **kwargs)

        # Create engine using the same connection parameters
        if not self.client._connection_params:
            raise ConnectionError("Not connected to database")

        connection_string = (
            f"mysql+aiomysql://{self.client._connection_params['user']}:"
            f"{self.client._connection_params['password']}@"
            f"{self.client._connection_params['host']}:"
            f"{self.client._connection_params['port']}/"
            f"{self.client._connection_params['database']}"
        )

        # Create async engine that will use the same connection
        self._sqlalchemy_engine = create_async_engine(connection_string, pool_pre_ping=True, pool_recycle=300)

        # Create async session factory
        AsyncSessionLocal = async_sessionmaker(bind=self._sqlalchemy_engine, class_=AsyncSession, expire_on_commit=False)
        self._sqlalchemy_session = AsyncSessionLocal()

        # Begin async SQLAlchemy transaction
        await self._sqlalchemy_session.begin()

    return self._sqlalchemy_session
|
2947
|
+
|
2948
|
+
async def commit_sqlalchemy(self):
    """Commit the lazily-created SQLAlchemy session, if one exists."""
    session = self._sqlalchemy_session
    if session:
        await session.commit()
|
2952
|
+
|
2953
|
+
async def rollback_sqlalchemy(self):
    """Roll back the lazily-created SQLAlchemy session, if one exists."""
    session = self._sqlalchemy_session
    if session:
        await session.rollback()
|
2957
|
+
|
2958
|
+
async def close_sqlalchemy(self):
    """Close the lazily-created SQLAlchemy session and dispose of its
    engine, resetting both attributes to None."""
    session = self._sqlalchemy_session
    if session:
        await session.close()
        self._sqlalchemy_session = None

    engine = self._sqlalchemy_engine
    if engine:
        await engine.dispose()
        self._sqlalchemy_engine = None
|
2966
|
+
|
2967
|
+
async def insert(self, table_name_or_model, data: dict) -> AsyncResultSet:
    """
    Insert data into a table within transaction asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        data: Data to insert (dict with column names as keys)

    Returns::

        AsyncResultSet object
    """
    # Resolve a model class to its table name; strings pass through unchanged.
    target = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )
    return await self.execute(self.client._build_insert_sql(target, data))
|
2990
|
+
|
2991
|
+
async def batch_insert(self, table_name_or_model, data_list: list) -> AsyncResultSet:
    """
    Batch insert data into a table within transaction asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        data_list: List of data dictionaries to insert

    Returns::

        AsyncResultSet object (zero affected rows for an empty list, without
        touching the database)
    """
    if not data_list:
        return AsyncResultSet([], [], affected_rows=0)

    # Resolve a model class to its table name; strings pass through unchanged.
    target = (
        table_name_or_model.__tablename__
        if hasattr(table_name_or_model, '__tablename__')
        else table_name_or_model
    )
    return await self.execute(self.client._build_batch_insert_sql(target, data_list))
|
3017
|
+
|
3018
|
+
def query(self, *columns, snapshot: str = None):
    """Get async MatrixOne query builder within transaction - SQLAlchemy style

    Args::

        *columns: Can be:
            - Single model class: query(Article) - returns all columns from model
            - Multiple columns: query(Article.id, Article.title) - returns specific columns
            - Mixed: query(Article, Article.id, some_expression.label('alias')) - model + additional columns
        snapshot: Optional snapshot name for snapshot queries

    Returns::

        AsyncMatrixOneQuery instance configured for the specified columns within transaction
    """
    # Imported lazily so module import does not create a circular dependency.
    from .async_orm import AsyncMatrixOneQuery

    if len(columns) == 1:
        # Traditional single model class usage
        column = columns[0]
        if isinstance(column, str):
            # String table name
            return AsyncMatrixOneQuery(column, self.client, None, transaction_wrapper=self, snapshot=snapshot)
        elif hasattr(column, '__tablename__'):
            # This is a model class
            return AsyncMatrixOneQuery(column, self.client, None, transaction_wrapper=self, snapshot=snapshot)
        elif hasattr(column, 'name') and hasattr(column, 'as_sql'):
            # This is a CTE object (duck-typed on .name/.as_sql, then confirmed)
            from .orm import CTE

            if isinstance(column, CTE):
                query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
                query._table_name = column.name
                query._select_columns = ["*"]  # Default to select all from CTE
                query._ctes = [column]  # Add the CTE to the query
                return query
            # NOTE(review): an object exposing .name/.as_sql that is NOT a CTE
            # falls through every branch and this method implicitly returns
            # None - confirm whether that is intended.
        else:
            # This is a single column/expression - need to handle specially
            # For now, we'll create a query that can handle column selections
            query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
            query._select_columns = [column]
            # Try to infer table name from column
            if hasattr(column, 'table') and hasattr(column.table, 'name'):
                query._table_name = column.table.name
            return query
    else:
        # Multiple columns/expressions: at most one model class plus extra columns
        model_class = None
        select_columns = []

        for column in columns:
            if hasattr(column, '__tablename__'):
                # This is a model class - use its table
                model_class = column
            else:
                # This is a column or expression
                select_columns.append(column)

        if model_class:
            query = AsyncMatrixOneQuery(model_class, self.client, None, transaction_wrapper=self, snapshot=snapshot)
            if select_columns:
                # Add additional columns to the model's default columns
                query._select_columns = select_columns
            return query
        else:
            # No model class provided, need to infer table from columns
            query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
            query._select_columns = select_columns

            # Try to infer table name from first column that has table info
            for col in select_columns:
                if hasattr(col, 'table') and hasattr(col.table, 'name'):
                    query._table_name = col.table.name
                    break
                elif isinstance(col, str) and '.' in col:
                    # String column like "table.column" - extract table name
                    parts = col.split('.')
                    if len(parts) >= 2:
                        # For "db.table.column" format, use "db.table"
                        # For "table.column" format, use "table"
                        table_name = '.'.join(parts[:-1])
                        query._table_name = table_name
                        break

            return query
|
3103
|
+
|
3104
|
+
|
3105
|
+
class AsyncTransactionSnapshotManager(AsyncSnapshotManager):
    """Snapshot manager variant constructed for use inside a transaction.

    NOTE(review): ``transaction_wrapper`` is stored but every override simply
    delegates to the base manager - confirm the base class routes SQL through
    the transaction when one is attached.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Wrapper that exposes execute() bound to the active transaction.
        self.transaction_wrapper = transaction_wrapper

    async def create(
        self,
        name: str,
        level: Union[str, SnapshotLevel],
        database: Optional[str] = None,
        table: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Snapshot:
        """Delegate snapshot creation to the base manager."""
        created = await super().create(name, level, database, table, description)
        return created

    async def get(self, name: str) -> Snapshot:
        """Delegate snapshot lookup to the base manager."""
        found = await super().get(name)
        return found

    async def delete(self, name: str) -> None:
        """Delegate snapshot deletion to the base manager."""
        return await super().delete(name)
|
3130
|
+
|
3131
|
+
|
3132
|
+
class AsyncTransactionPubSubManager(AsyncPubSubManager):
    """Async publish-subscribe manager for use within transactions.

    Every statement is routed through ``transaction_wrapper.execute`` so
    publication/subscription changes commit or roll back together with the
    enclosing transaction.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Wrapper exposing execute() bound to the active transaction.
        self.transaction_wrapper = transaction_wrapper

    async def create_database_publication(self, name: str, database: str, account: str) -> Publication:
        """Create database publication within transaction asynchronously.

        Args::

            name: Publication name
            database: Database to publish
            account: Subscriber account granted access

        Returns::

            Publication: The created publication, read back via get_publication

        Raises::

            PubSubError: If creation or the follow-up lookup fails
        """
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create database publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            # Keep the original, precise domain error instead of re-wrapping it
            # (re-wrapping doubled the "Failed to ..." prefix).
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create database publication '{name}': {e}") from e

    async def create_table_publication(self, name: str, database: str, table: str, account: str) -> Publication:
        """Create table publication within transaction asynchronously."""
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"TABLE {self.client._escape_identifier(table)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create table publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create table publication '{name}': {e}") from e

    async def get_publication(self, name: str) -> Publication:
        """Get publication within transaction asynchronously.

        Raises::

            PubSubError: If the publication does not exist or the query fails
        """
        try:
            sql = f"SHOW PUBLICATIONS WHERE pub_name = {self.client._escape_string(name)}"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            return self._row_to_publication(result.rows[0])

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to get publication '{name}': {e}") from e

    async def list_publications(self, account: Optional[str] = None, database: Optional[str] = None) -> List[Publication]:
        """List publications within transaction asynchronously.

        Filtering happens client-side because SHOW PUBLICATIONS does not
        support a WHERE clause.
        """
        try:
            result = await self.transaction_wrapper.execute("SHOW PUBLICATIONS")

            if not result or not result.rows:
                return []

            publications = []
            for row in result.rows:
                pub = self._row_to_publication(row)

                # Apply optional client-side filters
                if account and account not in pub.sub_account:
                    continue
                if database and pub.database != database:
                    continue

                publications.append(pub)

            return publications

        except Exception as e:
            raise PubSubError(f"Failed to list publications: {e}") from e

    async def alter_publication(
        self,
        name: str,
        account: Optional[str] = None,
        database: Optional[str] = None,
        table: Optional[str] = None,
    ) -> Publication:
        """Alter publication within transaction asynchronously.

        Only the clauses whose arguments are provided are emitted.
        """
        try:
            parts = [f"ALTER PUBLICATION {self.client._escape_identifier(name)}"]

            if account:
                parts.append(f"ACCOUNT {self.client._escape_identifier(account)}")
            if database:
                parts.append(f"DATABASE {self.client._escape_identifier(database)}")
            if table:
                parts.append(f"TABLE {self.client._escape_identifier(table)}")

            result = await self.transaction_wrapper.execute(" ".join(parts))
            if result is None:
                raise PubSubError(f"Failed to alter publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to alter publication '{name}': {e}") from e

    async def drop_publication(self, name: str) -> bool:
        """Drop publication within transaction asynchronously.

        Returns::

            bool: True when the statement produced a result
        """
        try:
            sql = f"DROP PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.transaction_wrapper.execute(sql)
            return result is not None

        except Exception as e:
            raise PubSubError(f"Failed to drop publication '{name}': {e}") from e

    async def create_subscription(
        self, subscription_name: str, publication_name: str, publisher_account: str
    ) -> Subscription:
        """Create subscription (a subscribing database) within transaction asynchronously."""
        try:
            sql = (
                f"CREATE DATABASE {self.client._escape_identifier(subscription_name)} "
                f"FROM {self.client._escape_identifier(publisher_account)} "
                f"PUBLICATION {self.client._escape_identifier(publication_name)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create subscription '{subscription_name}'")

            return await self.get_subscription(subscription_name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create subscription '{subscription_name}': {e}") from e

    async def get_subscription(self, name: str) -> Subscription:
        """Get subscription within transaction asynchronously.

        SHOW SUBSCRIPTIONS output is scanned client-side; sub_name is expected
        in the 7th column (index 6).
        """
        try:
            result = await self.transaction_wrapper.execute("SHOW SUBSCRIPTIONS")

            if result and result.rows:
                for row in result.rows:
                    if row[6] == name:  # sub_name column
                        return self._row_to_subscription(row)

            raise PubSubError(f"Subscription '{name}' not found")

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to get subscription '{name}': {e}") from e

    async def list_subscriptions(
        self, pub_account: Optional[str] = None, pub_database: Optional[str] = None
    ) -> List[Subscription]:
        """List subscriptions within transaction asynchronously.

        NOTE(review): get_subscription states SHOW SUBSCRIPTIONS does not
        support WHERE, yet a WHERE clause is built here - confirm against the
        server; existing behavior is preserved.
        """
        try:
            conditions = []

            if pub_account:
                conditions.append(f"pub_account = {self.client._escape_string(pub_account)}")
            if pub_database:
                conditions.append(f"pub_database = {self.client._escape_string(pub_database)}")

            where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

            result = await self.transaction_wrapper.execute(f"SHOW SUBSCRIPTIONS{where_clause}")

            if not result or not result.rows:
                return []

            return [self._row_to_subscription(row) for row in result.rows]

        except Exception as e:
            raise PubSubError(f"Failed to list subscriptions: {e}") from e
|
3324
|
+
|
3325
|
+
|
3326
|
+
class AsyncTransactionPitrManager(AsyncPitrManager):
    """Async PITR manager for use within transactions.

    Every statement is routed through ``transaction_wrapper.execute`` so PITR
    changes commit or roll back together with the enclosing transaction.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Wrapper exposing execute() bound to the active transaction.
        self.transaction_wrapper = transaction_wrapper

    async def create_cluster_pitr(self, name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
        """Create cluster-level PITR within transaction asynchronously.

        Args::

            name: PITR name
            range_value: Retention range magnitude
            range_unit: Retention range unit (validated by _validate_range)

        Raises::

            PitrError: If validation, creation, or the follow-up lookup fails
        """
        try:
            self._validate_range(range_value, range_unit)

            sql = f"CREATE PITR {self.client._escape_identifier(name)} FOR CLUSTER RANGE {range_value} '{range_unit}'"

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PitrError(f"Failed to create cluster PITR '{name}'")

            return await self.get(name)

        except PitrError:
            # Keep the original, precise domain error instead of re-wrapping it
            # (re-wrapping doubled the "Failed to ..." prefix).
            raise
        except Exception as e:
            raise PitrError(f"Failed to create cluster PITR '{name}': {e}") from e

    async def create_account_pitr(
        self,
        name: str,
        account_name: Optional[str] = None,
        range_value: int = 1,
        range_unit: str = "d",
    ) -> Pitr:
        """Create account-level PITR within transaction asynchronously.

        When ``account_name`` is None the PITR targets the current account.
        """
        try:
            self._validate_range(range_value, range_unit)

            if account_name:
                sql = (
                    f"CREATE PITR {self.client._escape_identifier(name)} "
                    f"FOR ACCOUNT {self.client._escape_identifier(account_name)} "
                    f"RANGE {range_value} '{range_unit}'"
                )
            else:
                sql = f"CREATE PITR {self.client._escape_identifier(name)} FOR ACCOUNT RANGE {range_value} '{range_unit}'"

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PitrError(f"Failed to create account PITR '{name}'")

            return await self.get(name)

        except PitrError:
            raise
        except Exception as e:
            raise PitrError(f"Failed to create account PITR '{name}': {e}") from e

    async def create_database_pitr(self, name: str, database_name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
        """Create database-level PITR within transaction asynchronously."""
        try:
            self._validate_range(range_value, range_unit)

            sql = (
                f"CREATE PITR {self.client._escape_identifier(name)} "
                f"FOR DATABASE {self.client._escape_identifier(database_name)} "
                f"RANGE {range_value} '{range_unit}'"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PitrError(f"Failed to create database PITR '{name}'")

            return await self.get(name)

        except PitrError:
            raise
        except Exception as e:
            raise PitrError(f"Failed to create database PITR '{name}': {e}") from e

    async def create_table_pitr(
        self,
        name: str,
        database_name: str,
        table_name: str,
        range_value: int = 1,
        range_unit: str = "d",
    ) -> Pitr:
        """Create table-level PITR within transaction asynchronously."""
        try:
            self._validate_range(range_value, range_unit)

            sql = (
                f"CREATE PITR {self.client._escape_identifier(name)} "
                f"FOR TABLE {self.client._escape_identifier(database_name)} "
                f"{self.client._escape_identifier(table_name)} "
                f"RANGE {range_value} '{range_unit}'"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PitrError(f"Failed to create table PITR '{name}'")

            return await self.get(name)

        except PitrError:
            raise
        except Exception as e:
            raise PitrError(f"Failed to create table PITR '{name}': {e}") from e

    async def get(self, name: str) -> Pitr:
        """Get PITR within transaction asynchronously.

        Raises::

            PitrError: If the PITR does not exist or the query fails
        """
        try:
            sql = f"SHOW PITR WHERE pitr_name = {self.client._escape_string(name)}"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                raise PitrError(f"PITR '{name}' not found")

            return self._row_to_pitr(result.rows[0])

        except PitrError:
            raise
        except Exception as e:
            raise PitrError(f"Failed to get PITR '{name}': {e}") from e

    async def list(
        self,
        level: Optional[str] = None,
        account_name: Optional[str] = None,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
    ) -> List[Pitr]:
        """List PITRs within transaction asynchronously.

        Each provided argument becomes an equality condition in the WHERE clause.
        """
        try:
            conditions = []

            if level:
                conditions.append(f"pitr_level = {self.client._escape_string(level)}")
            if account_name:
                conditions.append(f"account_name = {self.client._escape_string(account_name)}")
            if database_name:
                conditions.append(f"database_name = {self.client._escape_string(database_name)}")
            if table_name:
                conditions.append(f"table_name = {self.client._escape_string(table_name)}")

            where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

            result = await self.transaction_wrapper.execute(f"SHOW PITR{where_clause}")

            if not result or not result.rows:
                return []

            return [self._row_to_pitr(row) for row in result.rows]

        except Exception as e:
            raise PitrError(f"Failed to list PITRs: {e}") from e

    async def alter(self, name: str, range_value: int, range_unit: str) -> Pitr:
        """Alter a PITR's retention range within transaction asynchronously."""
        try:
            self._validate_range(range_value, range_unit)

            sql = f"ALTER PITR {self.client._escape_identifier(name)} RANGE {range_value} '{range_unit}'"

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PitrError(f"Failed to alter PITR '{name}'")

            return await self.get(name)

        except PitrError:
            raise
        except Exception as e:
            raise PitrError(f"Failed to alter PITR '{name}': {e}") from e

    async def delete(self, name: str) -> bool:
        """Delete PITR within transaction asynchronously.

        Returns::

            bool: True when the statement produced a result
        """
        try:
            sql = f"DROP PITR {self.client._escape_identifier(name)}"
            result = await self.transaction_wrapper.execute(sql)
            return result is not None

        except Exception as e:
            raise PitrError(f"Failed to delete PITR '{name}': {e}") from e
|
3504
|
+
|
3505
|
+
|
3506
|
+
class AsyncTransactionRestoreManager(AsyncRestoreManager):
    """Async restore manager for use within transactions.

    Every statement is routed through ``transaction_wrapper.execute`` so the
    restore executes inside the enclosing transaction. The previously
    duplicated if/else SQL construction is centralized in _build_restore_sql.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Wrapper exposing execute() bound to the active transaction.
        self.transaction_wrapper = transaction_wrapper

    def _build_restore_sql(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
        to_account: Optional[str] = None,
    ) -> str:
        """Assemble a RESTORE ACCOUNT statement; optional scopes appended in order."""
        esc = self.client._escape_identifier
        parts = [f"RESTORE ACCOUNT {esc(account_name)}"]
        if database_name:
            parts.append(f"DATABASE {esc(database_name)}")
        if table_name:
            parts.append(f"TABLE {esc(table_name)}")
        parts.append(f"FROM SNAPSHOT {esc(snapshot_name)}")
        if to_account:
            parts.append(f"TO ACCOUNT {esc(to_account)}")
        return " ".join(parts)

    async def restore_cluster(self, snapshot_name: str) -> bool:
        """Restore the whole cluster from a snapshot within the transaction.

        Returns::

            bool: True when the statement produced a result

        Raises::

            RestoreError: If the restore statement fails
        """
        try:
            sql = f"RESTORE CLUSTER FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
            result = await self.transaction_wrapper.execute(sql)
            return result is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore cluster from snapshot '{snapshot_name}': {e}") from e

    async def restore_tenant(self, snapshot_name: str, account_name: str, to_account: Optional[str] = None) -> bool:
        """Restore an account (tenant) from a snapshot within the transaction.

        When ``to_account`` is given, the data is restored into that account.
        """
        try:
            sql = self._build_restore_sql(snapshot_name, account_name, to_account=to_account)
            result = await self.transaction_wrapper.execute(sql)
            return result is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore tenant '{account_name}' from snapshot '{snapshot_name}': {e}") from e

    async def restore_database(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore a single database from a snapshot within the transaction."""
        try:
            sql = self._build_restore_sql(
                snapshot_name,
                account_name,
                database_name=database_name,
                to_account=to_account,
            )
            result = await self.transaction_wrapper.execute(sql)
            return result is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore database '{database_name}' from snapshot '{snapshot_name}': {e}") from e

    async def restore_table(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        table_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore a single table from a snapshot within the transaction."""
        try:
            sql = self._build_restore_sql(
                snapshot_name,
                account_name,
                database_name=database_name,
                table_name=table_name,
                to_account=to_account,
            )
            result = await self.transaction_wrapper.execute(sql)
            return result is not None
        except Exception as e:
            raise RestoreError(f"Failed to restore table '{table_name}' from snapshot '{snapshot_name}': {e}") from e
|
3600
|
+
|
3601
|
+
|
3602
|
+
class AsyncTransactionCloneManager(AsyncCloneManager):
    """Clone manager variant constructed for use inside a transaction.

    NOTE(review): ``transaction_wrapper`` is stored but every override simply
    delegates to the base manager - confirm the base class routes SQL through
    the transaction when one is attached.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Wrapper that exposes execute() bound to the active transaction.
        self.transaction_wrapper = transaction_wrapper

    async def clone_database(
        self,
        target_db: str,
        source_db: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Delegate database cloning to the base manager."""
        outcome = await super().clone_database(target_db, source_db, snapshot_name, if_not_exists)
        return outcome

    async def clone_table(
        self,
        target_table: str,
        source_table: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Delegate table cloning to the base manager."""
        outcome = await super().clone_table(target_table, source_table, snapshot_name, if_not_exists)
        return outcome

    async def clone_table_with_snapshot(
        self, target_table: str, source_table: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Delegate snapshot-based table cloning to the base manager."""
        outcome = await super().clone_table_with_snapshot(target_table, source_table, snapshot_name, if_not_exists)
        return outcome
|
3634
|
+
|
3635
|
+
|
3636
|
+
class AsyncTransactionAccountManager:
|
3637
|
+
"""Async transaction-scoped account manager"""
|
3638
|
+
|
3639
|
+
def __init__(self, transaction):
    # Transaction wrapper whose execute() runs SQL inside the open transaction.
    self.transaction = transaction
    # Underlying client, used for identifier/string escaping helpers.
    self.client = transaction.client
|
3642
|
+
|
3643
|
+
async def create_account(
    self,
    account_name: str,
    admin_name: str,
    password: str,
    comment: Optional[str] = None,
    admin_comment: Optional[str] = None,
    admin_host: str = "%",
    admin_identified_by: Optional[str] = None,
) -> Account:
    """Create account within async transaction.

    Args::

        account_name: Name of the account to create
        admin_name: Administrator user name for the account
        password: Administrator password
        comment: Optional account comment
        admin_comment: Optional administrator comment
        admin_host: Administrator host pattern; only emitted when not "%"
        admin_identified_by: Optional extra ADMIN_IDENTIFIED BY clause

    Returns::

        Account: The created account, read back via get_account

    Raises::

        AccountError: If creation or the follow-up lookup fails
    """
    try:
        sql_parts = [f"CREATE ACCOUNT {self.client._escape_identifier(account_name)}"]
        sql_parts.append(f"ADMIN_NAME {self.client._escape_string(admin_name)}")
        sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

        # "%" is the server default, so the clause is omitted in that case.
        if admin_host != "%":
            sql_parts.append(f"ADMIN_HOST {self.client._escape_string(admin_host)}")

        if comment:
            sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

        if admin_comment:
            sql_parts.append(f"ADMIN_COMMENT {self.client._escape_string(admin_comment)}")

        if admin_identified_by:
            sql_parts.append(f"ADMIN_IDENTIFIED BY {self.client._escape_string(admin_identified_by)}")

        await self.transaction.execute(" ".join(sql_parts))

        return await self.get_account(account_name)

    except AccountError:
        # get_account already raises a precise AccountError; re-wrapping it
        # doubled the "Failed to ..." prefix.
        raise
    except Exception as e:
        raise AccountError(f"Failed to create account '{account_name}': {e}") from e
|
3678
|
+
|
3679
|
+
async def drop_account(self, account_name: str) -> None:
    """Drop an account inside the open async transaction.

    Raises::

        AccountError: If the DROP ACCOUNT statement fails
    """
    try:
        target = self.client._escape_identifier(account_name)
        await self.transaction.execute(f"DROP ACCOUNT {target}")
    except Exception as e:
        raise AccountError(f"Failed to drop account '{account_name}': {e}")
|
3686
|
+
|
3687
|
+
async def alter_account(
    self,
    account_name: str,
    comment: Optional[str] = None,
    suspend: Optional[bool] = None,
    suspend_reason: Optional[str] = None,
) -> Account:
    """Alter account within async transaction.

    Args::

        account_name: Account to alter
        comment: New comment, when given
        suspend: True suspends the account, False reopens it; None leaves it unchanged
        suspend_reason: Optional comment attached to the suspension

    Returns::

        Account: The account read back via get_account

    Raises::

        AccountError: If the alteration or the follow-up lookup fails
    """
    try:
        sql_parts = [f"ALTER ACCOUNT {self.client._escape_identifier(account_name)}"]

        if comment is not None:
            sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

        if suspend is not None:
            if suspend:
                if suspend_reason:
                    sql_parts.append(f"SUSPEND COMMENT {self.client._escape_string(suspend_reason)}")
                else:
                    sql_parts.append("SUSPEND")
            else:
                sql_parts.append("OPEN")

        await self.transaction.execute(" ".join(sql_parts))

        return await self.get_account(account_name)

    except AccountError:
        # get_account already raises a precise AccountError; do not re-wrap.
        raise
    except Exception as e:
        raise AccountError(f"Failed to alter account '{account_name}': {e}") from e
|
3717
|
+
|
3718
|
+
async def get_account(self, account_name: str) -> Account:
    """Get account within async transaction by scanning SHOW ACCOUNTS output.

    Raises::

        AccountError: If the account does not exist or the query fails
    """
    try:
        result = await self.transaction.execute("SHOW ACCOUNTS")

        if result and result.rows:
            for row in result.rows:
                if row[0] == account_name:  # account name is in column 0
                    return self._row_to_account(row)

        raise AccountError(f"Account '{account_name}' not found")

    except AccountError:
        # Preserve the precise "not found" error (re-wrapping doubled the
        # "Failed to ..." prefix).
        raise
    except Exception as e:
        raise AccountError(f"Failed to get account '{account_name}': {e}") from e
|
3735
|
+
|
3736
|
+
async def list_accounts(self) -> List[Account]:
    """Return every account reported by SHOW ACCOUNTS inside the transaction.

    Raises::

        AccountError: If the query or row conversion fails
    """
    try:
        shown = await self.transaction.execute("SHOW ACCOUNTS")
        if not shown or not shown.rows:
            return []
        accounts = [self._row_to_account(entry) for entry in shown.rows]
        return accounts
    except Exception as e:
        raise AccountError(f"Failed to list accounts: {e}")
|
3749
|
+
|
3750
|
+
async def create_user(self, user_name: str, password: str, comment: Optional[str] = None) -> User:
    """
    Create user within async transaction according to MatrixOne CREATE USER syntax:
    CREATE USER [IF NOT EXISTS] user auth_option [, user auth_option] ...

    Args::

        user_name: Name of the user to create
        password: Password for the user
        comment: Comment for the user (not supported by MatrixOne CREATE USER;
            only recorded on the returned User object)

    Returns::

        User: Created user object
    """
    try:
        # MatrixOne syntax: CREATE USER user_name IDENTIFIED BY 'password'.
        # COMMENT/ATTRIBUTE clauses are not supported, so `comment` is not sent.
        sql = (
            f"CREATE USER {self.client._escape_identifier(user_name)} "
            f"IDENTIFIED BY {self.client._escape_string(password)}"
        )
        await self.transaction.execute(sql)

        # Synthesize the User object from the values just used; host/account
        # are the defaults assumed by this manager.
        return User(
            name=user_name,
            host="%",  # Default host
            account="sys",  # Default account
            created_time=datetime.now(),
            status="ACTIVE",
            comment=comment,
        )

    except Exception as e:
        raise AccountError(f"Failed to create user '{user_name}': {e}") from e
|
3795
|
+
|
3796
|
+
async def drop_user(self, user_name: str, if_exists: bool = False) -> None:
    """
    Drop user within async transaction (MatrixOne DROP USER syntax:
    DROP USER [IF EXISTS] user [, user] ...).

    Args::

        user_name: Name of the user to drop
        if_exists: When True, emit IF EXISTS so a missing user is not an error
    """
    try:
        prefix = "DROP USER IF EXISTS" if if_exists else "DROP USER"
        statement = f"{prefix} {self.client._escape_identifier(user_name)}"
        await self.transaction.execute(statement)

    except Exception as e:
        raise AccountError(f"Failed to drop user '{user_name}': {e}")
|
3817
|
+
|
3818
|
+
async def alter_user(
    self,
    user_name: str,
    password: Optional[str] = None,
    comment: Optional[str] = None,
    lock: Optional[bool] = None,
    lock_reason: Optional[str] = None,
) -> User:
    """
    Alter a user within the async transaction.

    Only the clauses whose arguments are not None are emitted, so callers can
    change a single attribute at a time.

    Args:
        user_name: Name of the user to alter.
        password: New password; emits IDENTIFIED BY when provided.
        comment: New comment; emits COMMENT when provided.
        lock: True locks the account, False unlocks it, None leaves it unchanged.
        lock_reason: Optional reason attached when locking (ignored on unlock).

    Returns:
        The User object re-fetched after the ALTER statement.

    Raises:
        AccountError: If the ALTER USER statement or the follow-up lookup fails.
    """
    try:
        sql_parts = [f"ALTER USER {self.client._escape_identifier(user_name)}"]

        if password is not None:
            sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

        if comment is not None:
            sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

        if lock is not None:
            if lock:
                if lock_reason:
                    sql_parts.append(f"ACCOUNT LOCK COMMENT {self.client._escape_string(lock_reason)}")
                else:
                    sql_parts.append("ACCOUNT LOCK")
            else:
                sql_parts.append("ACCOUNT UNLOCK")

        sql = " ".join(sql_parts)
        await self.transaction.execute(sql)

        return await self.get_user(user_name)

    except AccountError:
        # get_user (and any other helper here) already raises a descriptive
        # AccountError; re-raise as-is instead of double-wrapping the message.
        raise
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise AccountError(f"Failed to alter user '{user_name}': {e}") from e
|
3852
|
+
|
3853
|
+
async def get_user(self, user_name: str) -> User:
    """
    Get a user within the async transaction.

    Scans SHOW GRANTS output for a row whose first column matches user_name.

    Args:
        user_name: Name of the user to look up.

    Returns:
        The matching User object.

    Raises:
        AccountError: If the user does not exist or the query fails.
    """
    try:
        result = await self.transaction.execute("SHOW GRANTS")

        if result and result.rows:
            for row in result.rows:
                if row[0] == user_name:
                    return self._row_to_user(row)

        raise AccountError(f"User '{user_name}' not found")

    except AccountError:
        # The "not found" AccountError raised above must not be caught by the
        # generic handler below, which would re-wrap it into
        # "Failed to get user 'x': User 'x' not found".
        raise
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise AccountError(f"Failed to get user '{user_name}': {e}") from e
|
3870
|
+
|
3871
|
+
async def list_users(self, account_name: Optional[str] = None) -> List[User]:
    """
    List users within the async transaction.

    Args:
        account_name: When provided, only users belonging to this account are
            returned; otherwise all users from SHOW GRANTS are returned.

    Returns:
        A list of User objects (empty when SHOW GRANTS yields no rows).

    Raises:
        AccountError: If the underlying query fails.
    """
    try:
        result = await self.transaction.execute("SHOW GRANTS")

        if not result or not result.rows:
            return []

        users = [self._row_to_user(row) for row in result.rows]

        if account_name:
            users = [user for user in users if user.account == account_name]

        return users

    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise AccountError(f"Failed to list users: {e}") from e
|
3889
|
+
|
3890
|
+
def _row_to_account(self, row: tuple) -> Account:
    """Convert a database row to an Account object.

    The first two columns (name, admin_name) are required; trailing columns
    are optional and default to None when the row is short.
    """
    def col(index: int):
        # Optional trailing column: None when the row doesn't reach it.
        return row[index] if len(row) > index else None

    return Account(
        name=row[0],
        admin_name=row[1],
        created_time=col(2),
        status=col(3),
        comment=col(4),
        suspended_time=col(5),
        suspended_reason=col(6),
    )
|
3901
|
+
|
3902
|
+
def _row_to_user(self, row: tuple) -> User:
    """Convert a database row to a User object.

    The first three columns (name, host, account) are required; trailing
    columns are optional and default to None when the row is short.
    """
    # Pad the optional tail out to 8 columns with None; the required first
    # three columns are read directly from `row` so short rows still fail.
    padded = tuple(row[:8]) + (None,) * (8 - min(len(row), 8))

    return User(
        name=row[0],
        host=row[1],
        account=row[2],
        created_time=padded[3],
        status=padded[4],
        comment=padded[5],
        locked_time=padded[6],
        locked_reason=padded[7],
    )