matrixone-python-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. matrixone/__init__.py +155 -0
  2. matrixone/account.py +723 -0
  3. matrixone/async_client.py +3913 -0
  4. matrixone/async_metadata_manager.py +311 -0
  5. matrixone/async_orm.py +123 -0
  6. matrixone/async_vector_index_manager.py +633 -0
  7. matrixone/base_client.py +208 -0
  8. matrixone/client.py +4672 -0
  9. matrixone/config.py +452 -0
  10. matrixone/connection_hooks.py +286 -0
  11. matrixone/exceptions.py +89 -0
  12. matrixone/logger.py +782 -0
  13. matrixone/metadata.py +820 -0
  14. matrixone/moctl.py +219 -0
  15. matrixone/orm.py +2277 -0
  16. matrixone/pitr.py +646 -0
  17. matrixone/pubsub.py +771 -0
  18. matrixone/restore.py +411 -0
  19. matrixone/search_vector_index.py +1176 -0
  20. matrixone/snapshot.py +550 -0
  21. matrixone/sql_builder.py +844 -0
  22. matrixone/sqlalchemy_ext/__init__.py +161 -0
  23. matrixone/sqlalchemy_ext/adapters.py +163 -0
  24. matrixone/sqlalchemy_ext/dialect.py +534 -0
  25. matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
  26. matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
  27. matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
  28. matrixone/sqlalchemy_ext/ivf_config.py +252 -0
  29. matrixone/sqlalchemy_ext/table_builder.py +351 -0
  30. matrixone/sqlalchemy_ext/vector_index.py +1721 -0
  31. matrixone/sqlalchemy_ext/vector_type.py +948 -0
  32. matrixone/version.py +580 -0
  33. matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
  34. matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
  35. matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
  36. matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
  37. matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
  38. matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
  39. tests/__init__.py +19 -0
  40. tests/offline/__init__.py +20 -0
  41. tests/offline/conftest.py +77 -0
  42. tests/offline/test_account.py +703 -0
  43. tests/offline/test_async_client_query_comprehensive.py +1218 -0
  44. tests/offline/test_basic.py +54 -0
  45. tests/offline/test_case_sensitivity.py +227 -0
  46. tests/offline/test_connection_hooks_offline.py +287 -0
  47. tests/offline/test_dialect_schema_handling.py +609 -0
  48. tests/offline/test_explain_methods.py +346 -0
  49. tests/offline/test_filter_logical_in.py +237 -0
  50. tests/offline/test_fulltext_search_comprehensive.py +795 -0
  51. tests/offline/test_ivf_config.py +249 -0
  52. tests/offline/test_join_methods.py +281 -0
  53. tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
  54. tests/offline/test_logical_in_method.py +237 -0
  55. tests/offline/test_matrixone_version_parsing.py +264 -0
  56. tests/offline/test_metadata_offline.py +557 -0
  57. tests/offline/test_moctl.py +300 -0
  58. tests/offline/test_moctl_simple.py +251 -0
  59. tests/offline/test_model_support_offline.py +359 -0
  60. tests/offline/test_model_support_simple.py +225 -0
  61. tests/offline/test_pinecone_filter_offline.py +377 -0
  62. tests/offline/test_pitr.py +585 -0
  63. tests/offline/test_pubsub.py +712 -0
  64. tests/offline/test_query_update.py +283 -0
  65. tests/offline/test_restore.py +445 -0
  66. tests/offline/test_snapshot_comprehensive.py +384 -0
  67. tests/offline/test_sql_escaping_edge_cases.py +551 -0
  68. tests/offline/test_sqlalchemy_integration.py +382 -0
  69. tests/offline/test_sqlalchemy_vector_integration.py +434 -0
  70. tests/offline/test_table_builder.py +198 -0
  71. tests/offline/test_unified_filter.py +398 -0
  72. tests/offline/test_unified_transaction.py +495 -0
  73. tests/offline/test_vector_index.py +238 -0
  74. tests/offline/test_vector_operations.py +688 -0
  75. tests/offline/test_vector_type.py +174 -0
  76. tests/offline/test_version_core.py +328 -0
  77. tests/offline/test_version_management.py +372 -0
  78. tests/offline/test_version_standalone.py +652 -0
  79. tests/online/__init__.py +20 -0
  80. tests/online/conftest.py +216 -0
  81. tests/online/test_account_management.py +194 -0
  82. tests/online/test_advanced_features.py +344 -0
  83. tests/online/test_async_client_interfaces.py +330 -0
  84. tests/online/test_async_client_online.py +285 -0
  85. tests/online/test_async_model_insert_online.py +293 -0
  86. tests/online/test_async_orm_online.py +300 -0
  87. tests/online/test_async_simple_query_online.py +802 -0
  88. tests/online/test_async_transaction_simple_query.py +300 -0
  89. tests/online/test_basic_connection.py +130 -0
  90. tests/online/test_client_online.py +238 -0
  91. tests/online/test_config.py +90 -0
  92. tests/online/test_config_validation.py +123 -0
  93. tests/online/test_connection_hooks_new_online.py +217 -0
  94. tests/online/test_dialect_schema_handling_online.py +331 -0
  95. tests/online/test_filter_logical_in_online.py +374 -0
  96. tests/online/test_fulltext_comprehensive.py +1773 -0
  97. tests/online/test_fulltext_label_online.py +433 -0
  98. tests/online/test_fulltext_search_online.py +842 -0
  99. tests/online/test_ivf_stats_online.py +506 -0
  100. tests/online/test_logger_integration.py +311 -0
  101. tests/online/test_matrixone_query_orm.py +540 -0
  102. tests/online/test_metadata_online.py +579 -0
  103. tests/online/test_model_insert_online.py +255 -0
  104. tests/online/test_mysql_driver_validation.py +213 -0
  105. tests/online/test_orm_advanced_features.py +2022 -0
  106. tests/online/test_orm_cte_integration.py +269 -0
  107. tests/online/test_orm_online.py +270 -0
  108. tests/online/test_pinecone_filter.py +708 -0
  109. tests/online/test_pubsub_operations.py +352 -0
  110. tests/online/test_query_methods.py +225 -0
  111. tests/online/test_query_update_online.py +433 -0
  112. tests/online/test_search_vector_index.py +557 -0
  113. tests/online/test_simple_fulltext_online.py +915 -0
  114. tests/online/test_snapshot_comprehensive.py +998 -0
  115. tests/online/test_sqlalchemy_engine_integration.py +336 -0
  116. tests/online/test_sqlalchemy_integration.py +425 -0
  117. tests/online/test_transaction_contexts.py +1219 -0
  118. tests/online/test_transaction_insert_methods.py +356 -0
  119. tests/online/test_transaction_query_methods.py +288 -0
  120. tests/online/test_unified_filter_online.py +529 -0
  121. tests/online/test_vector_comprehensive.py +706 -0
  122. tests/online/test_version_management.py +291 -0
@@ -0,0 +1,3913 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ MatrixOne Async Client - Asynchronous implementation
17
+ """
18
+
19
+ try:
20
+ import aiomysql
21
+ except ImportError:
22
+ aiomysql = None
23
+
24
+ try:
25
+ from sqlalchemy import text
26
+ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
27
+ except ImportError:
28
+ create_async_engine = None
29
+ AsyncEngine = None
30
+ text = None
31
+
32
+ from contextlib import asynccontextmanager
33
+ from datetime import datetime
34
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
35
+
36
+ from .account import Account, User
37
+ from .async_vector_index_manager import AsyncVectorManager
38
+ from .base_client import BaseMatrixOneClient, BaseMatrixOneExecutor
39
+ from .connection_hooks import ConnectionHook, ConnectionAction, create_connection_hook
40
+ from .exceptions import (
41
+ AccountError,
42
+ ConnectionError,
43
+ MoCtlError,
44
+ PitrError,
45
+ PubSubError,
46
+ QueryError,
47
+ RestoreError,
48
+ SnapshotError,
49
+ )
50
+ from .logger import MatrixOneLogger, create_default_logger
51
+ from .async_metadata_manager import AsyncMetadataManager
52
+ from .pitr import Pitr
53
+ from .pubsub import Publication, Subscription
54
+ from .snapshot import Snapshot, SnapshotLevel
55
+
56
+
57
class AsyncResultSet:
    """In-memory, cursor-style wrapper around a fully materialized query result.

    ``fetchone`` / ``fetchmany`` / ``fetchall`` consume rows starting from an
    internal position. ``scalar``, ``__iter__`` and ``__len__`` always look at
    the complete row list, regardless of how many rows have been fetched.
    """

    def __init__(self, columns: List[str], rows: List[Tuple], affected_rows: int = 0):
        self.columns = columns
        self.rows = rows
        self.affected_rows = affected_rows
        self._cursor = 0  # Index of the next row handed out by the fetch* methods

    def fetchall(self) -> List[Tuple]:
        """Return every row not yet fetched and move the cursor to the end."""
        remaining_rows = self.rows[self._cursor :]
        self._cursor = len(self.rows)
        return remaining_rows

    def fetchone(self) -> Optional[Tuple]:
        """Return the next row, or ``None`` once all rows have been fetched."""
        if self._cursor < len(self.rows):
            row = self.rows[self._cursor]
            self._cursor += 1
            return row
        return None

    def fetchmany(self, size: int = 1) -> List[Tuple]:
        """Return up to *size* of the next rows and advance the cursor past them.

        A zero or negative *size* returns an empty list and leaves the cursor
        where it is; it must never move the cursor backwards.
        """
        start = self._cursor
        # Clamp so that a negative size cannot make ``end`` fall before ``start``.
        end = min(start + max(size, 0), len(self.rows))
        rows = self.rows[start:end]
        self._cursor = end
        return rows

    def scalar(self) -> Any:
        """Return the first column of the first row, or ``None`` if the result is empty.

        This ignores the fetch cursor.
        """
        if self.rows and self.columns:
            return self.rows[0][0]
        return None

    def keys(self):
        """Return an iterator over the column names."""
        return iter(self.columns)

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)
103
+
104
+
105
class AsyncSnapshotManager:
    """Async snapshot manager.

    Issues ``CREATE``/``DROP SNAPSHOT`` statements and reads snapshot metadata
    from ``mo_catalog.mo_snapshots``.
    """

    def __init__(self, client):
        # client must provide an awaitable ``execute``; results expose ``.rows``.
        self.client = client

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Render *value* as a single-quoted SQL string literal.

        Backslashes and single quotes are escaped so the value cannot terminate
        the literal early or inject extra SQL.
        """
        escaped = value.replace("\\", "\\\\").replace("'", "''")
        return f"'{escaped}'"

    @staticmethod
    def _row_to_snapshot(row) -> Snapshot:
        """Build a Snapshot from a row of the ``mo_snapshots`` query.

        The row must be ``(sname, ts, level, account_name, database_name,
        table_name)``. ``ts`` is divided by 1e9, i.e. treated as nanoseconds. A
        level string that is not a known ``SnapshotLevel`` is kept as the raw
        string.
        """
        timestamp = datetime.fromtimestamp(row[1] / 1000000000)  # Convert nanoseconds to seconds

        level_str = row[2]
        try:
            level = SnapshotLevel(level_str.lower())
        except ValueError:
            level = level_str  # Fallback to string for backward compatibility

        return Snapshot(row[0], level, timestamp, None, row[4], row[5])

    async def create(
        self,
        name: str,
        level: Union[str, SnapshotLevel],
        database: Optional[str] = None,
        table: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Snapshot:
        """Create a snapshot at the given level and return a Snapshot describing it.

        Args:
            name: Snapshot name (interpolated into the statement as-is).
            level: A ``SnapshotLevel`` or its string value (case-insensitive).
            database: Required for database- and table-level snapshots.
            table: Required for table-level snapshots.
            description: Optional comment, emitted as an escaped string literal.

        Raises:
            SnapshotError: If required names are missing or the level is invalid.
        """
        # Convert string level to enum if needed
        if isinstance(level, str):
            level = SnapshotLevel(level.lower())

        # Build SQL based on level
        if level == SnapshotLevel.CLUSTER:
            sql = f"CREATE SNAPSHOT {name} FOR CLUSTER"
        elif level == SnapshotLevel.ACCOUNT:
            sql = f"CREATE SNAPSHOT {name} FOR ACCOUNT"
        elif level == SnapshotLevel.DATABASE:
            if not database:
                raise SnapshotError("Database name is required for database level snapshot")
            sql = f"CREATE SNAPSHOT {name} FOR DATABASE {database}"
        elif level == SnapshotLevel.TABLE:
            if not database or not table:
                raise SnapshotError("Database and table names are required for table level snapshot")
            sql = f"CREATE SNAPSHOT {name} FOR TABLE {database} {table}"
        else:
            raise SnapshotError(f"Invalid snapshot level: {level}")

        if description:
            # Escape the comment so quotes in user text cannot break out of the literal.
            sql += f" COMMENT {self._quote_literal(description)}"

        await self.client.execute(sql)

        # The creation time is taken from the local clock, not read back from the server.
        return Snapshot(name, level, datetime.now(), description, database, table)

    async def get(self, name: str) -> Snapshot:
        """Look up a snapshot by name in ``mo_catalog.mo_snapshots``.

        Raises:
            SnapshotError: If no snapshot with that name exists.
        """
        sql = """
            SELECT sname, ts, level, account_name, database_name, table_name
            FROM mo_catalog.mo_snapshots
            WHERE sname = :name
        """
        result = await self.client.execute(sql, {"name": name})

        if not result.rows:
            raise SnapshotError(f"Snapshot '{name}' not found")

        return self._row_to_snapshot(result.rows[0])

    async def list(self) -> List[Snapshot]:
        """Return all snapshots from ``mo_catalog.mo_snapshots``, newest first."""
        sql = """
            SELECT sname, ts, level, account_name, database_name, table_name
            FROM mo_catalog.mo_snapshots
            ORDER BY ts DESC
        """
        result = await self.client.execute(sql)
        return [self._row_to_snapshot(row) for row in result.rows]

    async def delete(self, name: str) -> None:
        """Drop the named snapshot."""
        sql = f"DROP SNAPSHOT {name}"
        await self.client.execute(sql)

    async def exists(self, name: str) -> bool:
        """Return True if ``get(name)`` succeeds, False if it raises SnapshotError."""
        try:
            await self.get(name)
            return True
        except SnapshotError:
            return False
216
+
217
+
218
class AsyncCloneManager:
    """Async helper that issues ``CREATE DATABASE/TABLE ... CLONE`` statements."""

    def __init__(self, client):
        # client must provide an awaitable ``execute(sql)``.
        self.client = client

    @staticmethod
    def _build_clone_sql(
        object_kind: str,
        target: str,
        source: str,
        snapshot_name: Optional[str],
        if_not_exists: bool,
    ) -> str:
        """Assemble the clone statement.

        The shape is ``CREATE <kind> [IF NOT EXISTS] <target> CLONE <source>
        [FOR SNAPSHOT '<snapshot>']``.
        """
        parts = [f"CREATE {object_kind}"]
        if if_not_exists:
            # IF NOT EXISTS goes between the object keyword and the name,
            # as in every other CREATE DATABASE / CREATE TABLE statement.
            parts.append("IF NOT EXISTS")
        parts.append(f"{target} CLONE {source}")
        if snapshot_name:
            parts.append(f"FOR SNAPSHOT '{snapshot_name}'")
        return " ".join(parts)

    async def clone_database(
        self,
        target_db: str,
        source_db: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone *source_db* into *target_db*, optionally as of *snapshot_name*."""
        sql = self._build_clone_sql("DATABASE", target_db, source_db, snapshot_name, if_not_exists)
        await self.client.execute(sql)

    async def clone_table(
        self,
        target_table: str,
        source_table: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Clone *source_table* into *target_table*, optionally as of *snapshot_name*."""
        sql = self._build_clone_sql("TABLE", target_table, source_table, snapshot_name, if_not_exists)
        await self.client.execute(sql)

    async def clone_database_with_snapshot(
        self, target_db: str, source_db: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Clone a database from a snapshot (same as ``clone_database`` with a snapshot)."""
        await self.clone_database(target_db, source_db, snapshot_name, if_not_exists)

    async def clone_table_with_snapshot(
        self, target_table: str, source_table: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Clone a table from a snapshot (same as ``clone_table`` with a snapshot)."""
        await self.clone_table(target_table, source_table, snapshot_name, if_not_exists)
271
+
272
+
273
class AsyncPubSubManager:
    """Async manager for publish-subscribe (publication / subscription) operations."""

    def __init__(self, client):
        # client must provide an awaitable ``execute(sql)`` and ``_escape_identifier``.
        self.client = client

    async def create_database_publication(self, name: str, database: str, account: str) -> Publication:
        """Publish *database* to *account* and return the publication read back from the server.

        Raises:
            PubSubError: If the statement fails or the publication cannot be read back.
        """
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create database publication '{name}'")

            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to create database publication '{name}': {e}") from e

    async def create_table_publication(self, name: str, database: str, table: str, account: str) -> Publication:
        """Publish one table of *database* to *account* and return the publication.

        Raises:
            PubSubError: If the statement fails or the publication cannot be read back.
        """
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"TABLE {self.client._escape_identifier(table)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create table publication '{name}'")

            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to create table publication '{name}': {e}") from e

    async def get_publication(self, name: str) -> Publication:
        """Return the publication whose name (first column) equals *name*.

        Raises:
            PubSubError: If it is not found or the query fails.
        """
        try:
            # SHOW PUBLICATIONS doesn't support WHERE clause, so we need to list all and filter
            result = await self.client.execute("SHOW PUBLICATIONS")

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            for row in result.rows:
                if row[0] == name:  # publication name is in first column
                    return self._row_to_publication(row)

            raise PubSubError(f"Publication '{name}' not found")

        except Exception as e:
            raise PubSubError(f"Failed to get publication '{name}': {e}") from e

    async def list_publications(self, account: Optional[str] = None, database: Optional[str] = None) -> List[Publication]:
        """List publications, optionally filtered client-side.

        *account* keeps publications whose ``sub_account`` contains it (substring
        match). *database* keeps those whose ``database`` equals it exactly.
        """
        try:
            # SHOW PUBLICATIONS doesn't support WHERE clause, so we need to list all and filter
            result = await self.client.execute("SHOW PUBLICATIONS")

            if not result or not result.rows:
                return []

            publications = []
            for row in result.rows:
                pub = self._row_to_publication(row)

                # A NULL sub_account must not crash the membership test.
                if account and account not in (pub.sub_account or ""):
                    continue
                if database and pub.database != database:
                    continue

                publications.append(pub)

            return publications

        except Exception as e:
            raise PubSubError(f"Failed to list publications: {e}") from e

    async def alter_publication(
        self,
        name: str,
        account: Optional[str] = None,
        database: Optional[str] = None,
        table: Optional[str] = None,
    ) -> Publication:
        """Alter the ACCOUNT / DATABASE / TABLE clauses of a publication and return it."""
        try:
            parts = [f"ALTER PUBLICATION {self.client._escape_identifier(name)}"]

            if account:
                parts.append(f"ACCOUNT {self.client._escape_identifier(account)}")
            if database:
                parts.append(f"DATABASE {self.client._escape_identifier(database)}")
            if table:
                parts.append(f"TABLE {self.client._escape_identifier(table)}")

            result = await self.client.execute(" ".join(parts))
            if result is None:
                raise PubSubError(f"Failed to alter publication '{name}'")

            return await self.get_publication(name)

        except Exception as e:
            raise PubSubError(f"Failed to alter publication '{name}': {e}") from e

    async def drop_publication(self, name: str) -> bool:
        """Drop a publication; return True if ``execute`` returned a non-None result."""
        try:
            sql = f"DROP PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.client.execute(sql)
            return result is not None

        except Exception as e:
            raise PubSubError(f"Failed to drop publication '{name}': {e}") from e

    async def show_create_publication(self, name: str) -> str:
        """Return the CREATE PUBLICATION statement for *name*.

        ``SHOW CREATE PUBLICATION`` returns ``(publication_name, create_statement)``,
        like ``SHOW CREATE TABLE``. The statement is therefore taken from the
        second column, falling back to the first if only one column is returned.
        """
        try:
            sql = f"SHOW CREATE PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            row = result.rows[0]
            return row[1] if len(row) > 1 else row[0]

        except Exception as e:
            raise PubSubError(f"Failed to show create publication '{name}': {e}") from e

    async def create_subscription(
        self, subscription_name: str, publication_name: str, publisher_account: str
    ) -> Subscription:
        """Subscribe to a publication by creating database *subscription_name* from it."""
        try:
            sql = (
                f"CREATE DATABASE {self.client._escape_identifier(subscription_name)} "
                f"FROM {self.client._escape_identifier(publisher_account)} "
                f"PUBLICATION {self.client._escape_identifier(publication_name)}"
            )

            result = await self.client.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create subscription '{subscription_name}'")

            return await self.get_subscription(subscription_name)

        except Exception as e:
            raise PubSubError(f"Failed to create subscription '{subscription_name}': {e}") from e

    async def get_subscription(self, name: str) -> Subscription:
        """Return the subscription whose ``sub_name`` (7th column) equals *name*.

        Raises:
            PubSubError: If it is not found or the query fails.
        """
        try:
            # SHOW SUBSCRIPTIONS doesn't support WHERE clause, so we need to list all and filter
            result = await self.client.execute("SHOW SUBSCRIPTIONS")

            if not result or not result.rows:
                raise PubSubError(f"Subscription '{name}' not found")

            for row in result.rows:
                if len(row) > 6 and row[6] == name:  # sub_name is in 7th column (index 6)
                    return self._row_to_subscription(row)

            raise PubSubError(f"Subscription '{name}' not found")

        except Exception as e:
            raise PubSubError(f"Failed to get subscription '{name}': {e}") from e

    async def list_subscriptions(
        self, pub_account: Optional[str] = None, pub_database: Optional[str] = None
    ) -> List[Subscription]:
        """List subscriptions, optionally filtered client-side by exact publisher account / database.

        ``SHOW SUBSCRIPTIONS`` does not accept a WHERE clause, so filtering is
        done on the returned rows instead of in SQL.
        """
        try:
            result = await self.client.execute("SHOW SUBSCRIPTIONS")

            if not result or not result.rows:
                return []

            subscriptions = []
            for row in result.rows:
                sub = self._row_to_subscription(row)

                if pub_account and sub.pub_account != pub_account:
                    continue
                if pub_database and sub.pub_database != pub_database:
                    continue

                subscriptions.append(sub)

            return subscriptions

        except Exception as e:
            raise PubSubError(f"Failed to list subscriptions: {e}") from e

    def _row_to_publication(self, row: tuple) -> Publication:
        """Convert a SHOW PUBLICATIONS row to a Publication.

        Expected columns: publication, database, tables, sub_account,
        subscribed_accounts, create_time, update_time, comments. Trailing
        columns that are missing map to None.
        """
        # Column order per MatrixOne docs:
        # https://docs.matrixorigin.cn/en/v25.2.2.2/MatrixOne/Reference/SQL-Reference/Other/SHOW-Statements/show-publications/
        return Publication(
            name=row[0],
            database=row[1],
            tables=row[2],
            sub_account=row[3],
            subscribed_accounts=row[4],
            created_time=row[5] if len(row) > 5 else None,
            update_time=row[6] if len(row) > 6 else None,
            comments=row[7] if len(row) > 7 else None,
        )

    def _row_to_subscription(self, row: tuple) -> Subscription:
        """Convert a SHOW SUBSCRIPTIONS row to a Subscription.

        Expected columns: pub_name, pub_account, pub_database, pub_tables,
        pub_comment, pub_time, sub_name, sub_time, status. Missing trailing
        columns map to None, except status, which defaults to 0.
        """
        return Subscription(
            pub_name=row[0],
            pub_account=row[1],
            pub_database=row[2],
            pub_tables=row[3],
            pub_comment=row[4] if len(row) > 4 else None,
            pub_time=row[5] if len(row) > 5 else None,
            sub_name=row[6] if len(row) > 6 else None,
            sub_time=row[7] if len(row) > 7 else None,
            status=row[8] if len(row) > 8 else 0,
        )
518
+
519
+
520
class AsyncPitrManager:
    """Async manager for point-in-time-recovery (PITR) objects."""

    # Unit codes accepted in ``RANGE <value> '<unit>'``.
    _VALID_RANGE_UNITS = ("h", "d", "mo", "y")

    def __init__(self, client):
        # client must provide an awaitable ``execute(sql)`` plus
        # ``_escape_identifier`` / ``_escape_string`` helpers.
        self.client = client

    async def _create_pitr(
        self,
        name: str,
        scope_keyword: str,
        scope_identifiers: tuple,
        scope_label: str,
        range_value: int,
        range_unit: str,
    ) -> Pitr:
        """Shared implementation of the ``create_*_pitr`` methods.

        Runs ``CREATE PITR <name> <scope_keyword> [<ids>...] RANGE <v> '<unit>'``
        and then reads the new PITR back. Each entry of *scope_identifiers* is
        escaped and appended after *scope_keyword*. *scope_label* is used only
        in error messages. Every failure, including validation, is raised as
        PitrError.
        """
        try:
            self._validate_range(range_value, range_unit)

            scope = " ".join([scope_keyword, *(self.client._escape_identifier(i) for i in scope_identifiers)])
            sql = f"CREATE PITR {self.client._escape_identifier(name)} " f"{scope} RANGE {range_value} '{range_unit}'"

            result = await self.client.execute(sql)
            if result is None:
                raise PitrError(f"Failed to create {scope_label} PITR '{name}'")

            return await self.get(name)

        except Exception as e:
            raise PitrError(f"Failed to create {scope_label} PITR '{name}': {e}") from e

    async def create_cluster_pitr(self, name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
        """Create a cluster-level PITR and return it."""
        return await self._create_pitr(name, "FOR CLUSTER", (), "cluster", range_value, range_unit)

    async def create_account_pitr(
        self,
        name: str,
        account_name: Optional[str] = None,
        range_value: int = 1,
        range_unit: str = "d",
    ) -> Pitr:
        """Create an account-level PITR and return it.

        Without *account_name*, the statement omits the account name
        (``FOR ACCOUNT RANGE ...``).
        """
        identifiers = (account_name,) if account_name else ()
        return await self._create_pitr(name, "FOR ACCOUNT", identifiers, "account", range_value, range_unit)

    async def create_database_pitr(self, name: str, database_name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
        """Create a database-level PITR and return it."""
        return await self._create_pitr(name, "FOR DATABASE", (database_name,), "database", range_value, range_unit)

    async def create_table_pitr(
        self,
        name: str,
        database_name: str,
        table_name: str,
        range_value: int = 1,
        range_unit: str = "d",
    ) -> Pitr:
        """Create a table-level PITR (``FOR TABLE <db> <table>``) and return it."""
        return await self._create_pitr(
            name, "FOR TABLE", (database_name, table_name), "table", range_value, range_unit
        )

    async def get(self, name: str) -> Pitr:
        """Return the PITR named *name* (first row of ``SHOW PITR WHERE pitr_name = ...``).

        Raises:
            PitrError: If it is not found or the query fails.
        """
        try:
            sql = f"SHOW PITR WHERE pitr_name = {self.client._escape_string(name)}"
            result = await self.client.execute(sql)

            if not result or not result.rows:
                raise PitrError(f"PITR '{name}' not found")

            return self._row_to_pitr(result.rows[0])

        except Exception as e:
            raise PitrError(f"Failed to get PITR '{name}': {e}") from e

    async def list(
        self,
        level: Optional[str] = None,
        account_name: Optional[str] = None,
        database_name: Optional[str] = None,
        table_name: Optional[str] = None,
    ) -> List[Pitr]:
        """List PITRs; each given filter is added to the ``SHOW PITR`` WHERE clause (ANDed)."""
        try:
            filters = (
                ("pitr_level", level),
                ("account_name", account_name),
                ("database_name", database_name),
                ("table_name", table_name),
            )
            conditions = [f"{column} = {self.client._escape_string(value)}" for column, value in filters if value]

            where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

            result = await self.client.execute(f"SHOW PITR{where_clause}")

            if not result or not result.rows:
                return []

            return [self._row_to_pitr(row) for row in result.rows]

        except Exception as e:
            raise PitrError(f"Failed to list PITRs: {e}") from e

    async def alter(self, name: str, range_value: int, range_unit: str) -> Pitr:
        """Change the retention range of a PITR and return the updated object."""
        try:
            self._validate_range(range_value, range_unit)

            sql = f"ALTER PITR {self.client._escape_identifier(name)} " f"RANGE {range_value} '{range_unit}'"

            result = await self.client.execute(sql)
            if result is None:
                raise PitrError(f"Failed to alter PITR '{name}'")

            return await self.get(name)

        except Exception as e:
            raise PitrError(f"Failed to alter PITR '{name}': {e}") from e

    async def delete(self, name: str) -> bool:
        """Drop a PITR; return True if ``execute`` returned a non-None result."""
        try:
            sql = f"DROP PITR {self.client._escape_identifier(name)}"
            result = await self.client.execute(sql)
            return result is not None

        except Exception as e:
            raise PitrError(f"Failed to delete PITR '{name}': {e}") from e

    def _validate_range(self, range_value: int, range_unit: str) -> None:
        """Raise PitrError unless 1 <= range_value <= 100 and range_unit is a known unit code.

        Because *range_unit* is interpolated into SQL, it must be checked here.
        """
        if not (1 <= range_value <= 100):
            raise PitrError("Range value must be between 1 and 100")

        if range_unit not in self._VALID_RANGE_UNITS:
            raise PitrError(f"Range unit must be one of: {', '.join(self._VALID_RANGE_UNITS)}")

    def _row_to_pitr(self, row: tuple) -> Pitr:
        """Convert a SHOW PITR row to a Pitr.

        Expected columns: pitr_name, created_time, modified_time, pitr_level,
        account_name, database_name, table_name, pitr_length, pitr_unit. The
        wildcard ``"*"`` in the account/database/table columns maps to None.
        """
        return Pitr(
            name=row[0],
            created_time=row[1],
            modified_time=row[2],
            level=row[3],
            account_name=row[4] if row[4] != "*" else None,
            database_name=row[5] if row[5] != "*" else None,
            table_name=row[6] if row[6] != "*" else None,
            range_value=row[7],
            range_unit=row[8],
        )
722
+
723
+
724
+ class AsyncRestoreManager:
725
+ """Async manager for restore operations"""
726
+
727
+ def __init__(self, client):
728
+ self.client = client
729
+
730
+ async def restore_cluster(self, snapshot_name: str) -> bool:
731
+ """Restore entire cluster from snapshot asynchronously"""
732
+ try:
733
+ sql = f"RESTORE CLUSTER FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
734
+ result = await self.client.execute(sql)
735
+ return result is not None
736
+ except Exception as e:
737
+ raise RestoreError(f"Failed to restore cluster from snapshot '{snapshot_name}': {e}")
738
+
739
+ async def restore_tenant(self, snapshot_name: str, account_name: str, to_account: Optional[str] = None) -> bool:
740
+ """Restore tenant from snapshot asynchronously"""
741
+ try:
742
+ if to_account:
743
+ sql = (
744
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
745
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)} "
746
+ f"TO ACCOUNT {self.client._escape_identifier(to_account)}"
747
+ )
748
+ else:
749
+ sql = (
750
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
751
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
752
+ )
753
+
754
+ result = await self.client.execute(sql)
755
+ return result is not None
756
+ except Exception as e:
757
+ raise RestoreError(f"Failed to restore tenant '{account_name}' from snapshot '{snapshot_name}': {e}")
758
+
759
+ async def restore_database(
760
+ self,
761
+ snapshot_name: str,
762
+ account_name: str,
763
+ database_name: str,
764
+ to_account: Optional[str] = None,
765
+ ) -> bool:
766
+ """Restore database from snapshot asynchronously"""
767
+ try:
768
+ if to_account:
769
+ sql = (
770
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
771
+ f"DATABASE {self.client._escape_identifier(database_name)} "
772
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)} "
773
+ f"TO ACCOUNT {self.client._escape_identifier(to_account)}"
774
+ )
775
+ else:
776
+ sql = (
777
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
778
+ f"DATABASE {self.client._escape_identifier(database_name)} "
779
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
780
+ )
781
+
782
+ result = await self.client.execute(sql)
783
+ return result is not None
784
+ except Exception as e:
785
+ raise RestoreError(f"Failed to restore database '{database_name}' from snapshot '{snapshot_name}': {e}")
786
+
787
+ async def restore_table(
788
+ self,
789
+ snapshot_name: str,
790
+ account_name: str,
791
+ database_name: str,
792
+ table_name: str,
793
+ to_account: Optional[str] = None,
794
+ ) -> bool:
795
+ """Restore table from snapshot asynchronously"""
796
+ try:
797
+ if to_account:
798
+ sql = (
799
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
800
+ f"DATABASE {self.client._escape_identifier(database_name)} "
801
+ f"TABLE {self.client._escape_identifier(table_name)} "
802
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)} "
803
+ f"TO ACCOUNT {self.client._escape_identifier(to_account)}"
804
+ )
805
+ else:
806
+ sql = (
807
+ f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
808
+ f"DATABASE {self.client._escape_identifier(database_name)} "
809
+ f"TABLE {self.client._escape_identifier(table_name)} "
810
+ f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
811
+ )
812
+
813
+ result = await self.client.execute(sql)
814
+ return result is not None
815
+ except Exception as e:
816
+ raise RestoreError(f"Failed to restore table '{table_name}' from snapshot '{snapshot_name}': {e}")
817
+
818
+
819
+ class AsyncMoCtlManager:
820
+ """Async mo_ctl manager"""
821
+
822
+ def __init__(self, client):
823
+ self.client = client
824
+
825
+ async def _execute_moctl(self, method: str, target: str, params: str = "") -> Dict[str, Any]:
826
+ """Execute mo_ctl command asynchronously"""
827
+ import json
828
+
829
+ try:
830
+ # Build mo_ctl SQL command
831
+ if params:
832
+ sql = f"SELECT mo_ctl('{method}', '{target}', '{params}')"
833
+ else:
834
+ sql = f"SELECT mo_ctl('{method}', '{target}', '')"
835
+
836
+ # Execute the command
837
+ result = await self.client.execute(sql)
838
+
839
+ if not result.rows:
840
+ raise MoCtlError(f"mo_ctl command returned no results: {sql}")
841
+
842
+ # Parse the JSON result
843
+ result_str = result.rows[0][0]
844
+ parsed_result = json.loads(result_str)
845
+
846
+ # Check for errors in the result
847
+ if "result" in parsed_result and parsed_result["result"]:
848
+ first_result = parsed_result["result"][0]
849
+ if "returnStr" in first_result and first_result["returnStr"] != "OK":
850
+ raise MoCtlError(f"mo_ctl operation failed: {first_result['returnStr']}")
851
+
852
+ return parsed_result
853
+
854
+ except json.JSONDecodeError as e:
855
+ raise MoCtlError(f"Failed to parse mo_ctl result: {e}")
856
+ except Exception as e:
857
+ raise MoCtlError(f"mo_ctl operation failed: {e}")
858
+
859
+ async def flush_table(self, database: str, table: str) -> Dict[str, Any]:
860
+ """Force flush table asynchronously"""
861
+ table_ref = f"{database}.{table}"
862
+ return await self._execute_moctl("dn", "flush", table_ref)
863
+
864
+ async def increment_checkpoint(self) -> Dict[str, Any]:
865
+ """Force incremental checkpoint asynchronously"""
866
+ return await self._execute_moctl("dn", "checkpoint", "")
867
+
868
+ async def global_checkpoint(self) -> Dict[str, Any]:
869
+ """Force global checkpoint asynchronously"""
870
+ return await self._execute_moctl("dn", "globalcheckpoint", "")
871
+
872
+
873
+ class AsyncAccountManager:
874
+ """Async manager for MatrixOne account operations"""
875
+
876
+ def __init__(self, client):
877
+ self.client = client
878
+
879
+ async def create_account(
880
+ self,
881
+ account_name: str,
882
+ admin_name: str,
883
+ password: str,
884
+ comment: Optional[str] = None,
885
+ admin_comment: Optional[str] = None,
886
+ admin_host: str = "%",
887
+ admin_identified_by: Optional[str] = None,
888
+ ) -> Account:
889
+ """Create a new account asynchronously"""
890
+ try:
891
+ sql_parts = [f"CREATE ACCOUNT {self.client._escape_identifier(account_name)}"]
892
+ sql_parts.append(f"ADMIN_NAME {self.client._escape_string(admin_name)}")
893
+ sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")
894
+
895
+ if admin_host != "%":
896
+ sql_parts.append(f"ADMIN_HOST {self.client._escape_string(admin_host)}")
897
+
898
+ if comment:
899
+ sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")
900
+
901
+ if admin_comment:
902
+ sql_parts.append(f"ADMIN_COMMENT {self.client._escape_string(admin_comment)}")
903
+
904
+ if admin_identified_by:
905
+ sql_parts.append(f"ADMIN_IDENTIFIED BY {self.client._escape_string(admin_identified_by)}")
906
+
907
+ sql = " ".join(sql_parts)
908
+ await self.client.execute(sql)
909
+
910
+ return await self.get_account(account_name)
911
+
912
+ except Exception as e:
913
+ raise AccountError(f"Failed to create account '{account_name}': {e}")
914
+
915
+ async def drop_account(self, account_name: str) -> None:
916
+ """Drop an account asynchronously"""
917
+ try:
918
+ sql = f"DROP ACCOUNT {self.client._escape_identifier(account_name)}"
919
+ await self.client.execute(sql)
920
+ except Exception as e:
921
+ raise AccountError(f"Failed to drop account '{account_name}': {e}")
922
+
923
+ async def alter_account(
924
+ self,
925
+ account_name: str,
926
+ comment: Optional[str] = None,
927
+ suspend: Optional[bool] = None,
928
+ suspend_reason: Optional[str] = None,
929
+ ) -> Account:
930
+ """Alter an account asynchronously"""
931
+ try:
932
+ sql_parts = [f"ALTER ACCOUNT {self.client._escape_identifier(account_name)}"]
933
+
934
+ if comment is not None:
935
+ sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")
936
+
937
+ if suspend is not None:
938
+ if suspend:
939
+ if suspend_reason:
940
+ sql_parts.append(f"SUSPEND COMMENT {self.client._escape_string(suspend_reason)}")
941
+ else:
942
+ sql_parts.append("SUSPEND")
943
+ else:
944
+ sql_parts.append("OPEN")
945
+
946
+ sql = " ".join(sql_parts)
947
+ await self.client.execute(sql)
948
+
949
+ return await self.get_account(account_name)
950
+
951
+ except Exception as e:
952
+ raise AccountError(f"Failed to alter account '{account_name}': {e}")
953
+
954
+ async def get_account(self, account_name: str) -> Account:
955
+ """Get account by name asynchronously"""
956
+ try:
957
+ sql = "SHOW ACCOUNTS"
958
+ result = await self.client.execute(sql)
959
+
960
+ if not result or not result.rows:
961
+ raise AccountError(f"Account '{account_name}' not found")
962
+
963
+ for row in result.rows:
964
+ if row[0] == account_name:
965
+ return self._row_to_account(row)
966
+
967
+ raise AccountError(f"Account '{account_name}' not found")
968
+
969
+ except Exception as e:
970
+ raise AccountError(f"Failed to get account '{account_name}': {e}")
971
+
972
+ async def list_accounts(self) -> List[Account]:
973
+ """List all accounts asynchronously"""
974
+ try:
975
+ sql = "SHOW ACCOUNTS"
976
+ result = await self.client.execute(sql)
977
+
978
+ if not result or not result.rows:
979
+ return []
980
+
981
+ return [self._row_to_account(row) for row in result.rows]
982
+
983
+ except Exception as e:
984
+ raise AccountError(f"Failed to list accounts: {e}")
985
+
986
+ async def create_user(self, user_name: str, password: str, comment: Optional[str] = None) -> User:
987
+ """
988
+ Create a new user asynchronously according to MatrixOne CREATE USER syntax:
989
+ CREATE USER [IF NOT EXISTS] user auth_option [, user auth_option] ...
990
+ [DEFAULT ROLE rolename] [COMMENT 'comment_string' | ATTRIBUTE 'json_object']
991
+
992
+ Args::
993
+
994
+ user_name: Name of the user to create
995
+ password: Password for the user
996
+ comment: Comment for the user (not supported in MatrixOne)
997
+
998
+ Returns::
999
+
1000
+ User: Created user object
1001
+ """
1002
+ try:
1003
+ # Build CREATE USER statement according to MatrixOne syntax
1004
+ # MatrixOne syntax: CREATE USER user_name IDENTIFIED BY 'password'
1005
+ sql_parts = [f"CREATE USER {self.client._escape_identifier(user_name)}"]
1006
+
1007
+ sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")
1008
+
1009
+ # Note: MatrixOne doesn't support COMMENT or ATTRIBUTE clauses in CREATE USER
1010
+ # if comment:
1011
+ # sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")
1012
+ # if identified_by:
1013
+ # sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(identified_by)}")
1014
+
1015
+ sql = " ".join(sql_parts)
1016
+ await self.client.execute(sql)
1017
+
1018
+ # Return a User object with current account context
1019
+ return User(
1020
+ name=user_name,
1021
+ host="%", # Default host
1022
+ account="sys", # Default account
1023
+ created_time=datetime.now(),
1024
+ status="ACTIVE",
1025
+ comment=comment,
1026
+ )
1027
+
1028
+ except Exception as e:
1029
+ raise AccountError(f"Failed to create user '{user_name}': {e}")
1030
+
1031
+ async def drop_user(self, user_name: str, if_exists: bool = False) -> None:
1032
+ """
1033
+ Drop a user asynchronously according to MatrixOne DROP USER syntax:
1034
+ DROP USER [IF EXISTS] user [, user] ...
1035
+
1036
+ Args::
1037
+
1038
+ user_name: Name of the user to drop
1039
+ if_exists: If True, add IF EXISTS clause to avoid errors when user doesn't exist
1040
+ """
1041
+ try:
1042
+ sql_parts = ["DROP USER"]
1043
+ if if_exists:
1044
+ sql_parts.append("IF EXISTS")
1045
+
1046
+ sql_parts.append(self.client._escape_identifier(user_name))
1047
+ sql = " ".join(sql_parts)
1048
+ await self.client.execute(sql)
1049
+
1050
+ except Exception as e:
1051
+ raise AccountError(f"Failed to drop user '{user_name}': {e}")
1052
+
1053
+ async def alter_user(
1054
+ self,
1055
+ user_name: str,
1056
+ password: Optional[str] = None,
1057
+ comment: Optional[str] = None,
1058
+ lock: Optional[bool] = None,
1059
+ lock_reason: Optional[str] = None,
1060
+ ) -> User:
1061
+ """Alter a user asynchronously"""
1062
+ try:
1063
+ sql_parts = [f"ALTER USER {self.client._escape_identifier(user_name)}"]
1064
+
1065
+ if password is not None:
1066
+ sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")
1067
+
1068
+ if comment is not None:
1069
+ sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")
1070
+
1071
+ if lock is not None:
1072
+ if lock:
1073
+ if lock_reason:
1074
+ sql_parts.append(f"ACCOUNT LOCK COMMENT {self.client._escape_string(lock_reason)}")
1075
+ else:
1076
+ sql_parts.append("ACCOUNT LOCK")
1077
+ else:
1078
+ sql_parts.append("ACCOUNT UNLOCK")
1079
+
1080
+ sql = " ".join(sql_parts)
1081
+ await self.client.execute(sql)
1082
+
1083
+ return await self.get_user(user_name)
1084
+
1085
+ except Exception as e:
1086
+ raise AccountError(f"Failed to alter user '{user_name}': {e}")
1087
+
1088
+ async def get_user(self, user_name: str) -> User:
1089
+ """Get user by name asynchronously"""
1090
+ try:
1091
+ sql = "SHOW GRANTS"
1092
+ result = await self.client.execute(sql)
1093
+
1094
+ if not result or not result.rows:
1095
+ raise AccountError(f"User '{user_name}' not found")
1096
+
1097
+ for row in result.rows:
1098
+ if row[0] == user_name:
1099
+ return self._row_to_user(row)
1100
+
1101
+ raise AccountError(f"User '{user_name}' not found")
1102
+
1103
+ except Exception as e:
1104
+ raise AccountError(f"Failed to get user '{user_name}': {e}")
1105
+
1106
+ async def list_users(self, account_name: Optional[str] = None) -> List[User]:
1107
+ """List users with optional account filter asynchronously"""
1108
+ try:
1109
+ sql = "SHOW GRANTS"
1110
+ result = await self.client.execute(sql)
1111
+
1112
+ if not result or not result.rows:
1113
+ return []
1114
+
1115
+ users = [self._row_to_user(row) for row in result.rows]
1116
+
1117
+ if account_name:
1118
+ users = [user for user in users if user.account == account_name]
1119
+
1120
+ return users
1121
+
1122
+ except Exception as e:
1123
+ raise AccountError(f"Failed to list users: {e}")
1124
+
1125
+ def _row_to_account(self, row: tuple) -> Account:
1126
+ """Convert database row to Account object"""
1127
+ return Account(
1128
+ name=row[0],
1129
+ admin_name=row[1],
1130
+ created_time=row[2] if len(row) > 2 else None,
1131
+ status=row[3] if len(row) > 3 else None,
1132
+ comment=row[4] if len(row) > 4 else None,
1133
+ suspended_time=row[5] if len(row) > 5 else None,
1134
+ suspended_reason=row[6] if len(row) > 6 else None,
1135
+ )
1136
+
1137
+ def _row_to_user(self, row: tuple) -> User:
1138
+ """Convert database row to User object"""
1139
+ return User(
1140
+ name=row[0],
1141
+ host=row[1],
1142
+ account=row[2],
1143
+ created_time=row[3] if len(row) > 3 else None,
1144
+ status=row[4] if len(row) > 4 else None,
1145
+ comment=row[5] if len(row) > 5 else None,
1146
+ locked_time=row[6] if len(row) > 6 else None,
1147
+ locked_reason=row[7] if len(row) > 7 else None,
1148
+ )
1149
+
1150
+
1151
+ class AsyncFulltextIndexManager:
1152
+ """Async fulltext index manager for client chain operations"""
1153
+
1154
+ def __init__(self, client: "AsyncClient"):
1155
+ """Initialize async fulltext index manager"""
1156
+ self.client = client
1157
+
1158
+ async def create(
1159
+ self, table_name_or_model, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
1160
+ ) -> "AsyncFulltextIndexManager":
1161
+ """
1162
+ Create a fulltext index using chain operations.
1163
+
1164
+ Args::
1165
+
1166
+ table_name_or_model: Either a table name (str) or a SQLAlchemy model class
1167
+ name: Index name
1168
+ columns: Column(s) to index
1169
+ algorithm: Fulltext algorithm type (TF-IDF or BM25)
1170
+
1171
+ Returns::
1172
+
1173
+ AsyncFulltextIndexManager: Self for chaining
1174
+
1175
+ Example
1176
+
1177
+ # Create fulltext index by table name
1178
+ await client.fulltext_index.create("articles", name="idx_content", columns=["title", "content"])
1179
+
1180
+ # Create fulltext index by model class
1181
+ await client.fulltext_index.create(ArticleModel, name="idx_content", columns=["title", "content"])
1182
+ """
1183
+ from sqlalchemy import text
1184
+
1185
+ # Handle model class input
1186
+ if hasattr(table_name_or_model, '__tablename__'):
1187
+ table_name = table_name_or_model.__tablename__
1188
+ else:
1189
+ table_name = table_name_or_model
1190
+
1191
+ try:
1192
+ if isinstance(columns, str):
1193
+ columns = [columns]
1194
+
1195
+ columns_str = ", ".join(columns)
1196
+ sql = f"CREATE FULLTEXT INDEX {name} ON {table_name} ({columns_str})"
1197
+
1198
+ async with self.client.get_sqlalchemy_engine().begin() as conn:
1199
+ await conn.execute(text(sql))
1200
+
1201
+ return self
1202
+ except Exception as e:
1203
+ raise Exception(f"Failed to create fulltext index {name} on table {table_name}: {e}")
1204
+
1205
+ async def drop(self, table_name_or_model, name: str) -> "AsyncFulltextIndexManager":
1206
+ """
1207
+ Drop a fulltext index using chain operations.
1208
+
1209
+ Args::
1210
+
1211
+ table_name_or_model: Either a table name (str) or a SQLAlchemy model class
1212
+ name: Index name
1213
+
1214
+ Returns::
1215
+
1216
+ AsyncFulltextIndexManager: Self for chaining
1217
+
1218
+ Example
1219
+
1220
+ # Drop fulltext index by table name
1221
+ await client.fulltext_index.drop("articles", "idx_content")
1222
+
1223
+ # Drop fulltext index by model class
1224
+ await client.fulltext_index.drop(ArticleModel, "idx_content")
1225
+ """
1226
+ from sqlalchemy import text
1227
+
1228
+ # Handle model class input
1229
+ if hasattr(table_name_or_model, '__tablename__'):
1230
+ table_name = table_name_or_model.__tablename__
1231
+ else:
1232
+ table_name = table_name_or_model
1233
+
1234
+ try:
1235
+ sql = f"DROP INDEX {name} ON {table_name}"
1236
+
1237
+ async with self.client.get_sqlalchemy_engine().begin() as conn:
1238
+ await conn.execute(text(sql))
1239
+
1240
+ return self
1241
+ except Exception as e:
1242
+ raise Exception(f"Failed to drop fulltext index {name} from table {table_name}: {e}")
1243
+
1244
+ async def enable_fulltext(self) -> "AsyncFulltextIndexManager":
1245
+ """
1246
+ Enable fulltext indexing with chain operations.
1247
+
1248
+ Returns::
1249
+
1250
+ AsyncFulltextIndexManager: Self for chaining
1251
+ """
1252
+ try:
1253
+ await self.client.execute("SET experimental_fulltext_index = 1")
1254
+ return self
1255
+ except Exception as e:
1256
+ raise Exception(f"Failed to enable fulltext indexing: {e}")
1257
+
1258
+ async def disable_fulltext(self) -> "AsyncFulltextIndexManager":
1259
+ """
1260
+ Disable fulltext indexing with chain operations.
1261
+
1262
+ Returns::
1263
+
1264
+ AsyncFulltextIndexManager: Self for chaining
1265
+ """
1266
+ try:
1267
+ await self.client.execute("SET experimental_fulltext_index = 0")
1268
+ return self
1269
+ except Exception as e:
1270
+ raise Exception(f"Failed to disable fulltext indexing: {e}")
1271
+
1272
+
1273
+ class AsyncTransactionFulltextIndexManager(AsyncFulltextIndexManager):
1274
+ """Async fulltext index manager that executes operations within a transaction"""
1275
+
1276
+ def __init__(self, client: "AsyncClient", transaction_wrapper):
1277
+ """Initialize async transaction fulltext index manager"""
1278
+ super().__init__(client)
1279
+ self.transaction_wrapper = transaction_wrapper
1280
+
1281
+ async def create(
1282
+ self, table_name: str, name: str, columns: Union[str, List[str]], algorithm: str = "TF-IDF"
1283
+ ) -> "AsyncTransactionFulltextIndexManager":
1284
+ """Create a fulltext index within transaction"""
1285
+ try:
1286
+ if isinstance(columns, str):
1287
+ columns = [columns]
1288
+
1289
+ columns_str = ", ".join(columns)
1290
+ sql = f"CREATE FULLTEXT INDEX {name} ON {table_name} ({columns_str})"
1291
+
1292
+ await self.transaction_wrapper.execute(sql)
1293
+ return self
1294
+ except Exception as e:
1295
+ raise Exception(f"Failed to create fulltext index {name} on table {table_name} in transaction: {e}")
1296
+
1297
+ async def drop(self, table_name: str, name: str) -> "AsyncTransactionFulltextIndexManager":
1298
+ """Drop a fulltext index within transaction"""
1299
+ try:
1300
+ sql = f"DROP INDEX {name} ON {table_name}"
1301
+ await self.transaction_wrapper.execute(sql)
1302
+ return self
1303
+ except Exception as e:
1304
+ raise Exception(f"Failed to drop fulltext index {name} from table {table_name} in transaction: {e}")
1305
+
1306
+
1307
+ class AsyncClientExecutor(BaseMatrixOneExecutor):
1308
+ """Async client executor that uses AsyncClient's execute method"""
1309
+
1310
+ def __init__(self, client):
1311
+ super().__init__(client)
1312
+ self.client = client
1313
+
1314
+ async def _execute(self, sql: str):
1315
+ return await self.client.execute(sql)
1316
+
1317
+ def _get_empty_result(self):
1318
+ return AsyncResultSet([], [], affected_rows=0)
1319
+
1320
+ async def insert(self, table_name: str, data: dict):
1321
+ """Async insert method"""
1322
+ sql = self.base_client._build_insert_sql(table_name, data)
1323
+ return await self._execute(sql)
1324
+
1325
+ async def batch_insert(self, table_name: str, data_list: list):
1326
+ """Async batch insert method"""
1327
+ if not data_list:
1328
+ return self._get_empty_result()
1329
+
1330
+ sql = self.base_client._build_batch_insert_sql(table_name, data_list)
1331
+ return await self._execute(sql)
1332
+
1333
+
1334
+ class AsyncClient(BaseMatrixOneClient):
1335
+ """
1336
+ MatrixOne Async Client - Asynchronous interface for MatrixOne database operations.
1337
+
1338
+ This class provides a comprehensive asynchronous interface for connecting to and
1339
+ interacting with MatrixOne databases. It supports modern async/await patterns
1340
+ including table creation, data insertion, querying, vector operations, and
1341
+ transaction management.
1342
+
1343
+ Key Features:
1344
+
1345
+ - Asynchronous connection management with connection pooling
1346
+ - High-level table operations (create_table, drop_table, insert, batch_insert)
1347
+ - Query builder interface for complex async queries
1348
+ - Vector operations (similarity search, range search, indexing)
1349
+ - Async transaction management with context managers
1350
+ - Snapshot and restore operations
1351
+ - Account and user management
1352
+ - Fulltext search capabilities
1353
+ - Non-blocking I/O operations
1354
+
1355
+ Supported Operations:
1356
+
1357
+ - Async connection and disconnection
1358
+ - Async query execution (SELECT, INSERT, UPDATE, DELETE)
1359
+ - Async batch operations
1360
+ - Async transaction management
1361
+ - Async table creation and management
1362
+ - Async vector and fulltext operations
1363
+ - Async snapshot and restore operations
1364
+
1365
+ Usage Examples::
1366
+
1367
+ Basic async usage::
1368
+
1369
+ async def main():
1370
+ client = AsyncClient()
1371
+ await client.connect('localhost', 6001, 'root', '111', 'test')
1372
+
1373
+ # Create table using high-level API
1374
+ await client.create_table("users", {
1375
+ "id": "int primary key",
1376
+ "name": "varchar(100)",
1377
+ "email": "varchar(255)"
1378
+ })
1379
+
1380
+ # Insert data
1381
+ await client.insert("users", {"id": 1, "name": "John", "email": "john@example.com"})
1382
+
1383
+ # Query data
1384
+ result = await client.query("users").where("id = ?", 1).all()
1385
+ print(result.rows)
1386
+
1387
+ await client.disconnect()
1388
+
1389
+ Vector operations::
1390
+
1391
+ async def vector_example():
1392
+ client = AsyncClient()
1393
+ await client.connect('localhost', 6001, 'root', '111', 'test')
1394
+
1395
+ # Create vector table
1396
+ await client.create_table("documents", {
1397
+ "id": "int primary key",
1398
+ "content": "text",
1399
+ "embedding": "vecf32(384)"
1400
+ })
1401
+
1402
+ # Vector similarity search
1403
+ results = await client.vector_ops.similarity_search(
1404
+ "documents",
1405
+ vector_column="embedding",
1406
+ query_vector=[0.1, 0.2, 0.3, ...], # 384-dimensional vector
1407
+ limit=10,
1408
+ distance_type="l2"
1409
+ )
1410
+
1411
+ await client.disconnect()
1412
+
1413
+ Async transaction usage::
1414
+
1415
+ async def transaction_example():
1416
+ client = AsyncClient()
1417
+ await client.connect('localhost', 6001, 'root', '111', 'test')
1418
+
1419
+ async with client.transaction() as tx:
1420
+ await tx.execute("INSERT INTO users (name) VALUES (?)", ("John",))
1421
+ await tx.execute("INSERT INTO orders (user_id, amount) VALUES (?, ?)", (1, 100.0))
1422
+ # Transaction commits automatically on success
1423
+
1424
+ Note: This class requires asyncio and async database drivers. Use the synchronous Client class
1425
+ for blocking operations or when async support is not needed.
1426
+ """
1427
+
1428
+ def __init__(
1429
+ self,
1430
+ connection_timeout: int = 30,
1431
+ query_timeout: int = 300,
1432
+ auto_commit: bool = True,
1433
+ charset: str = "utf8mb4",
1434
+ logger: Optional[MatrixOneLogger] = None,
1435
+ sql_log_mode: str = "auto",
1436
+ slow_query_threshold: float = 1.0,
1437
+ max_sql_display_length: int = 500,
1438
+ ):
1439
+ """
1440
+ Initialize MatrixOne async client
1441
+
1442
+ Args::
1443
+
1444
+ connection_timeout: Connection timeout in seconds
1445
+ query_timeout: Query timeout in seconds
1446
+ auto_commit: Enable auto-commit mode
1447
+ charset: Character set for connection
1448
+ logger: Custom logger instance. If None, creates a default logger
1449
+ sql_log_mode: SQL logging mode ('off', 'auto', 'simple', 'full')
1450
+ - 'off': No SQL logging
1451
+ - 'auto': Smart logging - short SQL shown fully, long SQL summarized (default)
1452
+ - 'simple': Show operation summary only
1453
+ - 'full': Show complete SQL regardless of length
1454
+ slow_query_threshold: Threshold in seconds for slow query warnings (default: 1.0)
1455
+ max_sql_display_length: Maximum SQL length in auto mode before summarizing (default: 500)
1456
+ """
1457
+ self.connection_timeout = connection_timeout
1458
+ self.query_timeout = query_timeout
1459
+ self.auto_commit = auto_commit
1460
+ self.charset = charset
1461
+
1462
+ # Initialize logger
1463
+ if logger is not None:
1464
+ self.logger = logger
1465
+ else:
1466
+ self.logger = create_default_logger(
1467
+ sql_log_mode=sql_log_mode,
1468
+ slow_query_threshold=slow_query_threshold,
1469
+ max_sql_display_length=max_sql_display_length,
1470
+ )
1471
+
1472
+ # Connection management - using SQLAlchemy async engine instead of direct aiomysql connection
1473
+ self._engine = None
1474
+ self._connection = None # Keep for backward compatibility, but will be managed by engine
1475
+ self._connection_params = {}
1476
+ self._login_info = None
1477
+
1478
+ # Initialize managers
1479
+ self._snapshots = AsyncSnapshotManager(self)
1480
+ self._clone = AsyncCloneManager(self)
1481
+ self._moctl = AsyncMoCtlManager(self)
1482
+ self._restore = AsyncRestoreManager(self)
1483
+ self._pitr = AsyncPitrManager(self)
1484
+ self._pubsub = AsyncPubSubManager(self)
1485
+ self._account = AsyncAccountManager(self)
1486
+ self._fulltext_index = None
1487
+ self._metadata = None
1488
+
1489
+ async def connect(
1490
+ self,
1491
+ host: str,
1492
+ port: int,
1493
+ user: str,
1494
+ password: str,
1495
+ database: str = None,
1496
+ account: Optional[str] = None,
1497
+ role: Optional[str] = None,
1498
+ charset: str = "utf8mb4",
1499
+ connection_timeout: int = 30,
1500
+ auto_commit: bool = True,
1501
+ on_connect: Optional[Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]] = None,
1502
+ ):
1503
+ """
1504
+ Connect to MatrixOne database asynchronously
1505
+
1506
+ Args::
1507
+
1508
+ host: Database host
1509
+ port: Database port
1510
+ user: Username or login info in format "user", "account#user", or "account#user#role"
1511
+ password: Password
1512
+ database: Database name
1513
+ account: Optional account name (will be combined with user if user doesn't contain '#')
1514
+ role: Optional role name (will be combined with user if user doesn't contain '#')
1515
+ charset: Character set for the connection (default: utf8mb4)
1516
+ connection_timeout: Connection timeout in seconds (default: 30)
1517
+ auto_commit: Enable autocommit (default: True)
1518
+ on_connect: Connection hook to execute after successful connection.
1519
+ Can be:
1520
+ - ConnectionHook instance
1521
+ - List of ConnectionAction or string action names
1522
+ - Custom callback function (async or sync)
1523
+
1524
+ Examples::
1525
+
1526
+ # Enable all features after connection
1527
+ await client.connect(host, port, user, password, database,
1528
+ on_connect=[ConnectionAction.ENABLE_ALL])
1529
+
1530
+ # Enable only vector operations with custom charset
1531
+ await client.connect(host, port, user, password, database,
1532
+ charset="utf8mb4",
1533
+ on_connect=[ConnectionAction.ENABLE_VECTOR])
1534
+
1535
+ # Custom async callback
1536
+ async def my_callback(client):
1537
+ print(f"Connected to {client._connection_params['host']}")
1538
+
1539
+ await client.connect(host, port, user, password, database,
1540
+ on_connect=my_callback)
1541
+ """
1542
+ try:
1543
+ # Build final login info based on user parameter and optional account/role
1544
+ final_user, parsed_info = self._build_login_info(user, account, role)
1545
+
1546
+ # Store parsed info for later use
1547
+ self._login_info = parsed_info
1548
+
1549
+ # Store connection parameters for engine creation
1550
+ self._connection_params = {
1551
+ "host": host,
1552
+ "port": port,
1553
+ "user": final_user,
1554
+ "password": password,
1555
+ "database": database,
1556
+ "charset": charset,
1557
+ "autocommit": auto_commit,
1558
+ "connect_timeout": connection_timeout,
1559
+ }
1560
+
1561
+ # Create SQLAlchemy async engine instead of direct aiomysql connection
1562
+ self._engine = self._create_async_engine()
1563
+
1564
+ # Test the connection by executing a simple query
1565
+ async with self._engine.begin() as conn:
1566
+ await conn.execute(text("SELECT 1"))
1567
+
1568
+ # Initialize vector managers after successful connection
1569
+ self._initialize_vector_managers()
1570
+
1571
+ # Initialize metadata manager after successful connection
1572
+ self._metadata = AsyncMetadataManager(self)
1573
+
1574
+ self.logger.log_connection(host, port, final_user, database or "default", success=True)
1575
+
1576
+ # Setup connection hook if provided
1577
+ if on_connect:
1578
+ self._setup_connection_hook(on_connect)
1579
+ # Execute the hook once immediately for the initial connection
1580
+ await self._execute_connection_hook_immediately(on_connect)
1581
+
1582
+ except Exception as e:
1583
+ self.logger.log_connection(host, port, final_user, database or "default", success=False)
1584
+ self.logger.log_error(e, context="Async connection")
1585
+ raise ConnectionError(f"Failed to connect to MatrixOne: {e}")
1586
+
1587
+ def _setup_connection_hook(
1588
+ self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
1589
+ ) -> None:
1590
+ """Setup connection hook to be executed on each new connection"""
1591
+ try:
1592
+ if isinstance(on_connect, ConnectionHook):
1593
+ # Direct ConnectionHook instance
1594
+ hook = on_connect
1595
+ elif isinstance(on_connect, list):
1596
+ # List of actions - create a hook
1597
+ hook = create_connection_hook(actions=on_connect)
1598
+ elif callable(on_connect):
1599
+ # Custom callback function
1600
+ hook = create_connection_hook(custom_hook=on_connect)
1601
+ else:
1602
+ self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
1603
+ return
1604
+
1605
+ # Set the client reference and attach to engine
1606
+ hook.set_client(self)
1607
+ hook.attach_to_engine(self._engine)
1608
+
1609
+ except Exception as e:
1610
+ self.logger.warning(f"Connection hook setup failed: {e}")
1611
+
1612
async def _execute_connection_hook_immediately(
    self, on_connect: Union[ConnectionHook, List[Union[ConnectionAction, str]], Callable]
) -> None:
    """
    Run *on_connect* once, right now, for the initial connection.

    Accepts the same forms as ``_setup_connection_hook``: a ``ConnectionHook``
    instance, a list of actions, or a callable (the latter two are wrapped via
    ``create_connection_hook``). The hook is awaited through
    ``execute_async(self)``.

    Never raises: invalid input or a failing hook only produces a
    ``self.logger.warning`` message.
    """
    try:
        # A ConnectionHook may also be callable, so test for it first.
        if isinstance(on_connect, ConnectionHook):
            resolved_hook = on_connect
        elif isinstance(on_connect, list):
            resolved_hook = create_connection_hook(actions=on_connect)
        elif callable(on_connect):
            resolved_hook = create_connection_hook(custom_hook=on_connect)
        else:
            self.logger.warning(f"Invalid on_connect parameter type: {type(on_connect)}")
            return

        await resolved_hook.execute_async(self)
    except Exception as exc:
        self.logger.warning(f"Immediate connection hook execution failed: {exc}")
1635
+
1636
@classmethod
def from_engine(cls, engine: AsyncEngine, **kwargs) -> "AsyncClient":
    """
    Create AsyncClient instance from existing SQLAlchemy AsyncEngine

    The engine is validated with ``_is_mysql_async_engine``, a client is built
    from ``**kwargs`` and the engine is installed on it. The vector/fulltext
    managers and the metadata manager are then initialized, as ``connect()``
    does after a successful connection.

    Args::

        engine: SQLAlchemy AsyncEngine instance (must use MySQL driver)
        **kwargs: Additional client configuration options

    Returns::

        AsyncClient: Configured async client instance

    Raises::

        ConnectionError: If engine doesn't use MySQL driver

    Examples

        Basic usage::

            from sqlalchemy.ext.asyncio import create_async_engine
            from matrixone import AsyncClient

            engine = create_async_engine("mysql+aiomysql://user:pass@host:port/db")
            client = AsyncClient.from_engine(engine)

        With custom configuration::

            engine = create_async_engine("mysql+aiomysql://user:pass@host:port/db")
            client = AsyncClient.from_engine(
                engine,
                sql_log_mode='auto',
                slow_query_threshold=0.5
            )
    """
    # Check if engine uses MySQL driver
    if not cls._is_mysql_async_engine(engine):
        raise ConnectionError(
            "MatrixOne AsyncClient only supports MySQL drivers. "
            "Please use mysql+aiomysql:// connection strings. "
            f"Current engine uses: {engine.dialect.name}"
        )

    # Create client instance with default parameters
    client = cls(**kwargs)

    # Set the provided engine
    client._engine = engine

    # Initialize vector managers after engine is set
    client._initialize_vector_managers()

    # Mirror connect(): the metadata manager is created once an engine exists,
    # otherwise engine-built clients would lack it.
    client._metadata = AsyncMetadataManager(client)

    return client
1691
+
1692
@staticmethod
def _is_mysql_async_engine(engine: "AsyncEngine") -> bool:
    """
    Check if the async engine uses a MySQL driver

    The engine qualifies if its dialect name is ``mysql`` (case-insensitive).
    Otherwise its URL driver name must be exactly one of the supported MySQL
    async drivers. The driver name is taken from ``engine.url.drivername`` when
    available, else from the text before ``://`` in ``str(engine.url)``. Exact
    matching avoids false positives from a substring search over the whole URL,
    which could match credentials, host or database names.

    Args::

        engine: SQLAlchemy AsyncEngine instance

    Returns::

        bool: True if engine uses MySQL driver, False otherwise
    """
    if engine.dialect.name.lower() == "mysql":
        return True

    url = engine.url
    drivername = getattr(url, "drivername", None)
    if drivername is None:
        # Plain string URL (or URL-like object without drivername)
        drivername = str(url).split("://", 1)[0]

    return str(drivername).lower() in ("mysql+aiomysql", "mysql+asyncmy")
1721
+
1722
def _build_async_connection_url(self) -> str:
    """
    Build the ``mysql+aiomysql://`` URL for ``self._connection_params``.

    Username and password are percent-encoded with ``urllib.parse.quote`` and
    ``safe=""``. Characters such as ``@ : / % ? #`` (including MatrixOne's
    ``account#user#role`` login form) therefore cannot be misread as URL
    delimiters. SQLAlchemy unquotes both fields when it parses the URL. A
    ``None`` credential renders as an empty string.
    """
    from urllib.parse import quote

    params = self._connection_params

    def _quote(value) -> str:
        return quote("" if value is None else str(value), safe="")

    return (
        f"mysql+aiomysql://{_quote(params['user'])}:"
        f"{_quote(params['password'])}@"
        f"{params['host']}:"
        f"{params['port']}/"
        f"{params['database'] or ''}"
        f"?charset={params['charset']}"
    )

def _create_async_engine(self) -> "AsyncEngine":
    """
    Create SQLAlchemy async engine with connection pooling

    Returns::

        AsyncEngine built from ``_build_async_connection_url()``

    Raises::

        ConnectionError: If SQLAlchemy's asyncio extension is unavailable
    """
    if not create_async_engine:
        raise ConnectionError("SQLAlchemy async engine not available. Please install sqlalchemy[asyncio]")

    connection_string = self._build_async_connection_url()

    engine = create_async_engine(
        connection_string,
        pool_size=5,  # Connections kept open in the pool
        max_overflow=10,  # Extra connections allowed beyond pool_size under load
        pool_timeout=30,  # Seconds to wait for a free connection
        pool_recycle=3600,  # Recycle connections after 1 hour
        pool_pre_ping=True,  # Verify connections before use
        pool_reset_on_return="commit",  # Reset connections on return
        echo=False,  # Set to True for SQL logging
    )

    return engine
1750
+
1751
def _initialize_vector_managers(self) -> None:
    """
    Create the vector and fulltext-index managers for this client.

    Sets ``self._vector`` to an ``AsyncVectorManager`` and ``self._fulltext_index``
    to an ``AsyncFulltextIndexManager``. If an ``ImportError`` occurs anywhere in
    that process, both attributes are reset to ``None`` instead.
    """
    try:
        from .async_vector_index_manager import AsyncVectorManager

        self._vector = AsyncVectorManager(self)
        self._fulltext_index = AsyncFulltextIndexManager(self)
    except ImportError:
        # Optional feature unavailable: expose "no manager" rather than failing
        self._vector = self._fulltext_index = None
1762
+
1763
async def disconnect(self):
    """
    Disconnect from MatrixOne database asynchronously.

    Does nothing when no engine is set. Otherwise it closes and awaits any
    ``_pool`` exposed by the engine, disposes the engine, and forces a garbage
    collection pass. Success or failure is reported through the logger; errors
    are logged, not raised. The engine, connection and fulltext-index manager
    references are always reset to ``None`` afterwards.
    """
    if not self._engine:
        return

    try:
        # NOTE(review): an awaitable ``_pool.close()``/``wait_closed()`` looks like a
        # driver-level pool API rather than a public AsyncEngine attribute; confirm
        # this branch is ever reached.
        if getattr(self._engine, "_pool", None) is not None:
            await self._engine._pool.close()
            await self._engine._pool.wait_closed()

        # Dispose of the engine, closing the connections held in its pool
        await self._engine.dispose()

        import gc

        gc.collect()

        self.logger.log_disconnection(success=True)
    except Exception as exc:
        self.logger.log_disconnection(success=False)
        self.logger.log_error(exc, context="Async disconnection")
    finally:
        self._engine = None
        self._connection = None
        self._fulltext_index = None
1792
+
1793
def disconnect_sync(self):
    """
    Synchronous disconnect for cleanup when event loop is closed.

    Does nothing without an engine. If the engine exposes a non-None
    ``sync_engine``, that is disposed. Otherwise a non-None ``_pool`` is closed,
    with any error from that close ignored. The outcome is logged, errors are
    not raised, and the engine, connection and fulltext-index manager
    references are cleared in every case.
    """
    if not self._engine:
        return

    try:
        if getattr(self._engine, "sync_engine", None) is not None:
            # AsyncEngine wraps a synchronous Engine that can be disposed without a loop
            self._engine.sync_engine.dispose()
        elif getattr(self._engine, "_pool", None) is not None:
            # Best effort only: a failing pool close must not abort cleanup
            try:
                self._engine._pool.close()
            except Exception:
                pass

        self.logger.log_disconnection(success=True)
    except Exception as exc:
        self.logger.log_disconnection(success=False)
        self.logger.log_error(exc, context="Sync disconnection")
    finally:
        self._engine = None
        self._connection = None
        self._fulltext_index = None
1819
+
1820
def __del__(self):
    """
    Intentionally a no-op.

    No cleanup is attempted during garbage collection, because doing so can
    cause issues with event loops. Release resources explicitly with
    ``disconnect()`` / ``disconnect_sync()`` or by using the client in
    ``async with``.
    """
1825
+
1826
def get_sqlalchemy_engine(self) -> AsyncEngine:
    """
    Return the SQLAlchemy ``AsyncEngine`` backing this client.

    Returns::

        SQLAlchemy AsyncEngine

    Raises::

        ConnectionError: If no engine is set (client not connected)
    """
    engine = self._engine
    if engine:
        return engine
    raise ConnectionError("Not connected to database")
1837
+
1838
async def create_all(self, base_class=None):
    """
    Create all tables defined in the given base class or default Base.

    Runs ``base_class.metadata.create_all`` on a connection taken from
    ``self._engine.begin()``.

    Args::

        base_class: SQLAlchemy declarative base class. If None, uses the default Base.

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If not connected to database
    """
    # Same guard as the other engine-based methods; previously this failed with
    # an AttributeError on None.begin().
    if not self._engine:
        raise ConnectionError("Not connected to database")

    if base_class is None:
        from matrixone.orm import declarative_base

        # NOTE(review): if declarative_base() returns a fresh base on every call,
        # its metadata holds no tables and nothing is created here; confirm.
        base_class = declarative_base()

    async with self._engine.begin() as conn:
        await conn.run_sync(base_class.metadata.create_all)
    return self
1854
+
1855
async def drop_all(self, base_class=None):
    """
    Drop all tables defined in the given base class or default Base.

    Tables are dropped one ``DROP TABLE IF EXISTS`` at a time, in reverse
    ``MetaData.sorted_tables`` order. Dependent tables therefore go before the
    tables they reference, so foreign keys do not block the drops. Names are
    backtick-quoted, including the schema when one is set. A failure on one
    table is logged as a warning and the remaining tables are still attempted.

    Args::

        base_class: SQLAlchemy declarative base class. If None, uses the default Base.

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If not connected to database
    """
    # Without this guard every DROP would fail inside execute() and be swallowed
    # by the per-table handler below, making a disconnected call look successful.
    if not self._engine:
        raise ConnectionError("Not connected to database")

    if base_class is None:
        from matrixone.orm import declarative_base

        base_class = declarative_base()

    def _quote(name) -> str:
        return "`" + str(name).replace("`", "``") + "`"

    for table in reversed(list(base_class.metadata.sorted_tables)):
        schema = getattr(table, "schema", None)
        if schema:
            qualified = f"{_quote(schema)}.{_quote(table.name)}"
            display_name = f"{schema}.{table.name}"
        else:
            qualified = _quote(table.name)
            display_name = table.name

        try:
            await self.execute(f"DROP TABLE IF EXISTS {qualified}")
        except Exception as e:
            # Keep going so one bad table does not leave the rest behind
            self.logger.warning(f"Failed to drop table {display_name}: {e}")

    return self
1880
+
1881
async def _execute_with_logging(
    self, connection, sql: str, context: str = "Async SQL execution", override_sql_log_mode: str = None
):
    """
    Execute SQL asynchronously with proper logging through the client's logger.

    Internal helper shared by SDK components (AsyncVectorManager,
    AsyncTransactionWrapper, ...) so that their SQL is logged consistently.
    External users should use ``execute()`` instead.

    Args::

        connection: SQLAlchemy async connection object
        sql: SQL query string (wrapped in ``sqlalchemy.text``)
        context: Context description for error logging (default: "Async SQL execution")
        override_sql_log_mode: Temporarily override sql_log_mode for this query only

    Returns::

        SQLAlchemy result object, unconsumed

    Raises::

        Whatever ``connection.execute`` raises, re-raised after a failed-query
        log entry and an error log entry are written.
    """
    import time
    from sqlalchemy import text

    began = time.time()
    try:
        result = await connection.execute(text(sql))
        elapsed = time.time() - began

        try:
            # Row-returning results are left unconsumed, so no count is logged for
            # them; for other statements the driver's rowcount is reported.
            row_count = None if result.returns_rows else result.rowcount
            self.logger.log_query(
                sql, elapsed, row_count, success=True, override_sql_log_mode=override_sql_log_mode
            )
        except Exception:
            self.logger.log_query(sql, elapsed, None, success=True, override_sql_log_mode=override_sql_log_mode)

        return result
    except Exception as exc:
        elapsed = time.time() - began
        self.logger.log_query(sql, elapsed, success=False, override_sql_log_mode=override_sql_log_mode)
        self.logger.log_error(exc, context=context)
        raise
1937
+
1938
async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
    """
    Execute SQL query asynchronously using SQLAlchemy async engine

    ``?`` placeholders in *sql* are filled client-side from *params* via
    ``_substitute_parameters``. The statement runs inside
    ``self._engine.begin()``, using ``exec_driver_sql`` when the connection
    provides it.

    Args::

        sql: SQL query string
        params: Query parameters

    Returns::

        AsyncResultSet with query results (columns/rows for row-returning
        statements, otherwise ``affected_rows`` from the driver rowcount)

    Raises::

        ConnectionError: If not connected to database
        QueryError: If execution fails; common failure classes get a rewritten,
            more helpful message and the original exception context is suppressed
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    import time

    start_time = time.time()
    # Bound before ``try`` so the error path can always reference it, even when
    # parameter substitution itself is what failed.
    final_sql = sql

    try:
        # Handle parameter substitution for MatrixOne compatibility
        final_sql = self._substitute_parameters(sql, params)

        async with self._engine.begin() as conn:
            # exec_driver_sql() bypasses SQLAlchemy's bind parameter parsing, so
            # text such as the JSON {"a":1} is not parsed as a ":1" bind param.
            if hasattr(conn, 'exec_driver_sql'):
                # Escape % to %% for the driver's format-string handling
                escaped_sql = final_sql.replace('%', '%%')
                result = await conn.exec_driver_sql(escaped_sql)
            else:
                # Fallback for testing or older SQLAlchemy versions
                from sqlalchemy import text

                result = await conn.execute(text(final_sql))

            execution_time = time.time() - start_time

            if result.returns_rows:
                rows = result.fetchall()
                columns = list(result.keys()) if hasattr(result, "keys") else []
                async_result = AsyncResultSet(columns, rows)
                self.logger.log_query(final_sql, execution_time, len(rows), success=True)
                return async_result

            async_result = AsyncResultSet([], [], affected_rows=result.rowcount)
            self.logger.log_query(final_sql, execution_time, result.rowcount, success=True)
            return async_result

    except Exception as e:
        execution_time = time.time() - start_time

        # Log first; a logging failure must not hide the real error.
        try:
            self.logger.log_query(final_sql, execution_time, success=False)
            self.logger.log_error(e, context="Async query execution")
        except Exception as log_err:
            import sys

            print(f"Warning: Error logging failed: {log_err}", file=sys.stderr)

        error_msg = str(e)
        lowered = error_msg.lower()

        # "does not exist" is deliberately checked before "syntax error"
        if 'does not exist' in lowered or 'no such table' in lowered or 'doesn\'t exist' in lowered:
            import re

            match = re.search(r"(?:table|database)\s+[\"']?(\w+)[\"']?\s+does not exist", error_msg, re.IGNORECASE)
            if match:
                obj_name = match.group(1)
                raise QueryError(
                    f"Table or database '{obj_name}' does not exist. "
                    f"Create it first using client.create_table() or CREATE TABLE/DATABASE statement."
                ) from None
            raise QueryError(f"Object not found: {error_msg}") from None

        if 'already exists' in lowered and '1050' in error_msg:
            import re

            match = re.search(r"table\s+(\w+)\s+already\s+exists", error_msg, re.IGNORECASE)
            if match:
                table_name = match.group(1)
                raise QueryError(
                    f"Table '{table_name}' already exists. "
                    f"Use DROP TABLE {table_name} or client.drop_table() to remove it first."
                ) from None
            raise QueryError(f"Object already exists: {error_msg}") from None

        if 'duplicate' in lowered and ('1062' in error_msg or '1061' in error_msg):
            raise QueryError(
                f"Duplicate entry error: {error_msg}. "
                f"Check for duplicate primary key or unique constraint violations."
            ) from None

        if 'syntax error' in lowered or '1064' in error_msg:
            sql_preview = final_sql[:200] + '...' if len(final_sql) > 200 else final_sql
            raise QueryError(f"SQL syntax error: {error_msg}\n" f"Query: {sql_preview}") from None

        if 'column' in lowered and ('unknown' in lowered or 'not found' in lowered):
            raise QueryError(f"Column not found: {error_msg}. " f"Check your column names and table schema.") from None

        if 'cannot be null' in lowered or '1048' in error_msg:
            raise QueryError(
                f"NULL constraint violation: {error_msg}. " f"Some columns require non-NULL values."
            ) from None

        if 'not supported' in lowered and '20105' in error_msg:
            # MatrixOne-specific: feature not supported
            raise QueryError(
                f"MatrixOne feature limitation: {error_msg}. "
                f"This feature may require additional configuration or is not yet supported."
            ) from None

        if (
            'bind parameter' in lowered
            or 'InvalidRequestError' in error_msg
            or type(e).__name__ == 'InvalidRequestError'
        ):
            raise QueryError(
                f"Parameter binding error: {error_msg}. "
                f"This might be caused by special characters in your data (colons in JSON, etc.)"
            ) from None

        # Generic error - cleaner message without full SQLAlchemy stack
        raise QueryError(f"Query execution failed: {error_msg}") from None
2078
+
2079
def _substitute_parameters(self, sql: str, params: Optional[Tuple] = None) -> str:
    """
    Substitute ? placeholders with actual values since MatrixOne doesn't support prepared statements

    The template is split on ``?`` once and reassembled. A ``?`` contained in an
    already-substituted value is therefore never mistaken for another
    placeholder. String values are wrapped in single quotes, with embedded
    single quotes doubled and backslashes escaped, because MySQL-style string
    literals treat the backslash as an escape character. ``None`` renders as
    ``NULL``; any other value renders via ``str()``.

    Placeholders without a matching parameter stay as ``?`` and surplus
    parameters are ignored. A ``?`` inside a string literal of *sql* itself is
    still treated as a placeholder.

    Args::

        sql: SQL query string with ? placeholders
        params: Tuple of parameter values

    Returns::

        SQL string with parameters substituted
    """
    if not params:
        return sql

    def _render(value) -> str:
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "''")
            return f"'{escaped}'"
        if value is None:
            return "NULL"
        return str(value)

    rendered = [_render(value) for value in params]
    segments = sql.split("?")

    pieces = [segments[0]]
    for index, segment in enumerate(segments[1:]):
        pieces.append(rendered[index] if index < len(rendered) else "?")
        pieces.append(segment)
    return "".join(pieces)
2107
+
2108
def _build_login_info(self, user: str, account: Optional[str] = None, role: Optional[str] = None) -> tuple[str, dict]:
    """
    Resolve the MatrixOne login string and its parsed components.

    Args::

        user: Username or login info in format "user", "account#user", or "account#user#role"
        account: Optional account name
        role: Optional role name

    Returns::

        tuple: (final_user_string, parsed_info_dict) where the dict has the keys
        "account", "user" and "role"

    Raises::

        ValueError: If *user* already embeds account/role while *account* or *role*
            is also given, or if it has more than three '#'-separated parts

    Rules:
        1. If user contains '#', it's already in format "account#user" or "account#user#role"
           - If account or role is also provided, raise error (conflict)
        2. If user doesn't contain '#', combine with optional account/role:
           - No account/role: use user as-is (account reported as "sys")
           - Only role: use "sys#user#role"
           - Only account: use "account#user"
           - Both: use "account#user#role"
    """
    if "#" not in user:
        # Plain username: build the login string from the optional pieces
        effective_account = "sys" if account is None else account
        if account is None and role is None:
            login = user
        elif role is None:
            login = f"{account}#{user}"
        else:
            login = f"{effective_account}#{user}#{role}"
        return login, {"account": effective_account, "user": user, "role": role}

    if account is not None or role is not None:
        raise ValueError(
            f"Conflict: user parameter '{user}' already contains account/role info, "
            f"but account='{account}' and role='{role}' are also provided. "
            f"Use either user format or separate account/role parameters, not both."
        )

    segments = user.split("#")
    if len(segments) not in (2, 3):
        raise ValueError(f"Invalid user format: '{user}'. Expected 'user', 'account#user', or 'account#user#role'")

    parsed_role = segments[2] if len(segments) == 3 else None
    return user, {"account": segments[0], "user": segments[1], "role": parsed_role}
2176
+
2177
def get_login_info(self) -> Optional[dict]:
    """Return the parsed login information stored on this client (``self._login_info``)."""
    login_info = self._login_info
    return login_info
2180
+
2181
def _escape_identifier(self, identifier: str) -> str:
    """
    Quote *identifier* with backticks for use in SQL.

    Embedded backticks are doubled so the identifier cannot close the quoting
    early, which the previous plain wrapping allowed.
    """
    return "`" + identifier.replace("`", "``") + "`"
2184
+
2185
def _escape_string(self, value: str) -> str:
    """
    Quote *value* as a single-quoted SQL string literal.

    Backslashes are escaped and single quotes are doubled, matching MySQL-style
    literal rules. Previously the value was wrapped without any escaping.
    NOTE(review): callers that pre-escaped values themselves would now be
    double-escaped; confirm none do.
    """
    escaped = value.replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"
2188
+
2189
def query(self, *columns, snapshot: str = None):
    """Get async MatrixOne query builder - SQLAlchemy style

    Args::

        *columns: Can be:
            - Single model class: query(Article) - returns all columns from model
            - Table name string: query("articles")
            - A CTE object: selects all columns from the CTE
            - Multiple columns: query(Article.id, Article.title) - returns specific columns
            - Mixed: query(Article, Article.id, some_expression.label('alias')) - the
              model supplies the table, the other items become the selected columns
        snapshot: Optional snapshot name for snapshot queries

    Examples

        # Traditional model query (all columns)
        await client.query(Article).filter(...).all()

        # Column-specific query
        await client.query(Article.id, Article.title).filter(...).all()

        # With fulltext score
        await client.query(Article.id, boolean_match("title", "content").must("python").label("score"))

        # Snapshot query
        await client.query(Article, snapshot="my_snapshot").filter(...).all()

    Returns::

        AsyncMatrixOneQuery instance configured for the specified columns
    """
    from .async_orm import AsyncMatrixOneQuery

    if len(columns) == 1:
        column = columns[0]

        if isinstance(column, str) or hasattr(column, '__tablename__'):
            # Table name or model class: query all of its columns
            return AsyncMatrixOneQuery(column, self, None, None, snapshot)

        if hasattr(column, 'name') and hasattr(column, 'as_sql'):
            from .orm import CTE

            if isinstance(column, CTE):
                query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
                query._table_name = column.name
                query._select_columns = ["*"]  # Default to select all from CTE
                query._ctes = [column]  # Add the CTE to the query
                return query
            # Not a CTE after all: fall through and treat it as a single
            # column/expression (this path used to return None).

        # Single column/expression
        query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
        query._select_columns = [column]
        # Try to infer table name from column
        if hasattr(column, 'table') and hasattr(column.table, 'name'):
            query._table_name = column.table.name
        return query

    # Multiple columns/expressions (or none)
    model_class = None
    select_columns = []

    for column in columns:
        if hasattr(column, '__tablename__'):
            # Model class - use its table (the last one given wins)
            model_class = column
        else:
            select_columns.append(column)

    if model_class:
        query = AsyncMatrixOneQuery(model_class, self, None, None, snapshot)
        if select_columns:
            # Replace the model's default column list with the explicit columns
            query._select_columns = select_columns
        return query

    # No model class provided, need to infer table from columns
    query = AsyncMatrixOneQuery(None, self, None, None, snapshot)
    query._select_columns = select_columns

    # Use the first column that carries table information
    for col in select_columns:
        if hasattr(col, 'table') and hasattr(col.table, 'name'):
            query._table_name = col.table.name
            break
        elif isinstance(col, str) and '.' in col:
            # "table.column" -> "table"; "db.table.column" -> "db.table"
            query._table_name = col.rsplit('.', 1)[0]
            break

    return query
2288
+
2289
@asynccontextmanager
async def snapshot(self, snapshot_name: str):
    """
    Snapshot context manager

    Yields a ``SnapshotClient`` wrapping this client for *snapshot_name*.

    Usage

        async with client.snapshot("daily_backup") as snapshot_client:
            result = await snapshot_client.execute("SELECT * FROM users")

    Raises::

        ConnectionError: If not connected to database
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # NOTE(review): SnapshotClient is imported from the synchronous client module;
    # confirm it works with this async client's coroutine-based execute().
    from .client import SnapshotClient

    yield SnapshotClient(self, snapshot_name)
2307
+
2308
async def insert(self, table_name_or_model, data: dict) -> "AsyncResultSet":
    """
    Insert data into a table asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            (its ``__tablename__`` is used)
        data: Data to insert (dict with column names as keys)

    Returns::

        AsyncResultSet object produced by ``AsyncClientExecutor.insert``
    """
    target_table = getattr(table_name_or_model, '__tablename__', table_name_or_model)
    return await AsyncClientExecutor(self).insert(target_table, data)
2331
+
2332
async def batch_insert(self, table_name_or_model, data_list: list) -> "AsyncResultSet":
    """
    Batch insert data into a table asynchronously.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            (its ``__tablename__`` is used)
        data_list: List of data dictionaries to insert

    Returns::

        AsyncResultSet object produced by ``AsyncClientExecutor.batch_insert``
    """
    target_table = getattr(table_name_or_model, '__tablename__', table_name_or_model)
    return await AsyncClientExecutor(self).batch_insert(target_table, data_list)
2355
+
2356
@asynccontextmanager
async def transaction(self):
    """
    Async transaction context manager

    Opens ``self._engine.begin()`` and yields an ``AsyncTransactionWrapper`` bound
    to that connection. Exceptions propagate out of the block, and SQLAlchemy's
    ``begin()`` context rolls the transaction back in that case. The wrapper's
    ``close_sqlalchemy()`` is always awaited at the end.

    Usage

        async with client.transaction() as tx:
            await tx.execute("INSERT INTO users ...")
            await tx.execute("UPDATE users ...")
            # Snapshot and clone operations within transaction
            await tx.snapshots.create("snap1", "table", database="db1", table="t1")
            await tx.clone.clone_database("target_db", "source_db")

    Raises::

        ConnectionError: If not connected to database
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    wrapper = None
    try:
        async with self._engine.begin() as conn:
            wrapper = AsyncTransactionWrapper(conn, self)
            yield wrapper
    finally:
        if wrapper:
            await wrapper.close_sqlalchemy()
2387
+
2388
@property
def snapshots(self) -> AsyncSnapshotManager:
    """Async snapshot manager attached to this client (``self._snapshots``)."""
    manager = self._snapshots
    return manager
2392
+
2393
@property
def clone(self) -> AsyncCloneManager:
    """Async clone manager attached to this client (``self._clone``)."""
    manager = self._clone
    return manager
2397
+
2398
@property
def moctl(self) -> AsyncMoCtlManager:
    """Async mo_ctl manager attached to this client (``self._moctl``)."""
    manager = self._moctl
    return manager
2402
+
2403
@property
def restore(self) -> AsyncRestoreManager:
    """Async restore manager attached to this client (``self._restore``)."""
    manager = self._restore
    return manager
2407
+
2408
@property
def pitr(self) -> AsyncPitrManager:
    """Async PITR manager attached to this client (``self._pitr``)."""
    manager = self._pitr
    return manager
2412
+
2413
@property
def pubsub(self) -> AsyncPubSubManager:
    """Async publish-subscribe manager attached to this client (``self._pubsub``)."""
    manager = self._pubsub
    return manager
2417
+
2418
@property
def account(self) -> AsyncAccountManager:
    """Async account manager attached to this client (``self._account``)."""
    manager = self._account
    return manager
2422
+
2423
def connected(self) -> bool:
    """Return True when an engine is set on this client, False otherwise."""
    has_engine = self._engine is not None
    return has_engine
2426
+
2427
@property
def vector_ops(self):
    """Unified vector operations manager (index and data), or None if unavailable."""
    manager = self._vector
    return manager
2431
+
2432
def get_pinecone_index(self, table_name_or_model, vector_column: str):
    """
    Get a PineconeCompatibleIndex object for vector search operations.

    Builds a Pinecone-compatible vector search interface for the table. The
    table schema and vector index configuration are parsed automatically. The
    primary key column is detected, and every column except the vector column
    is exposed as metadata.

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            (its ``__tablename__`` is used)
        vector_column: Name of the vector column

    Returns::

        PineconeCompatibleIndex object with Pinecone-compatible API

    Example::

        index = client.get_pinecone_index("documents", "embedding")
        results = await index.query_async([0.1, 0.2, 0.3], top_k=5)
        for match in results.matches:
            print(f"ID: {match.id}, Score: {match.score}")
    """
    from .search_vector_index import PineconeCompatibleIndex

    target_table = getattr(table_name_or_model, '__tablename__', table_name_or_model)
    return PineconeCompatibleIndex(client=self, table_name=target_table, vector_column=vector_column)
2470
+
2471
@property
def fulltext_index(self):
    """Fulltext index manager for fulltext index operations, or None if unavailable."""
    manager = self._fulltext_index
    return manager
2475
+
2476
async def version(self) -> str:
    """
    Get MatrixOne server version asynchronously

    Runs ``SELECT VERSION()`` and returns the first column of the first row.

    Returns::

        str: MatrixOne server version string

    Raises::

        ConnectionError: If not connected to MatrixOne
        QueryError: If the query fails ("Failed to get version: ...", chained to
            the original error) or returns no rows ("Failed to get version information")

    Example

        >>> client = AsyncClient()
        >>> await client.connect('localhost', 6001, 'root', '111', 'test')
        >>> version = await client.version()
        >>> print(f"MatrixOne version: {version}")
    """
    if not self.connected():
        raise ConnectionError("Not connected to MatrixOne")

    try:
        result = await self.execute("SELECT VERSION()")
        rows = result.rows
        value = rows[0][0] if rows else None
    except Exception as e:
        raise QueryError(f"Failed to get version: {e}") from e

    # Raised outside the try so it is not re-wrapped as "Failed to get version: ..."
    if not rows:
        raise QueryError("Failed to get version information")
    return value
2507
+
2508
async def git_version(self) -> str:
    """
    Get MatrixOne git version information asynchronously

    Runs MatrixOne's built-in ``git_version()`` function and returns the first
    column of the first row.

    Returns::

        str: MatrixOne git version string

    Raises::

        ConnectionError: If not connected to MatrixOne
        QueryError: If the git version query fails (original error chained as
            ``__cause__``) or returns no rows

    Example::

        >>> client = AsyncClient()
        >>> await client.connect('localhost', 6001, 'root', '111', 'test')
        >>> git_version = await client.git_version()
        >>> print(f"MatrixOne git version: {git_version}")
    """
    if not self.connected():
        raise ConnectionError("Not connected to MatrixOne")

    try:
        # Use MatrixOne's built-in git_version() function
        result = await self.execute("SELECT git_version()")
    except Exception as e:
        # Wrap execution failures exactly once and keep the original as the cause.
        raise QueryError(f"Failed to get git version: {e}") from e

    # Checked outside the try so this error is not re-wrapped a second time.
    rows = result.rows
    if not rows or not rows[0]:
        raise QueryError("Failed to get git version information")
    return rows[0][0]
2540
+
2541
+ async def __aenter__(self):
2542
+ return self
2543
+
2544
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
2545
+ # Only disconnect if we're actually connected
2546
+ if self.connected():
2547
+ await self.disconnect()
2548
+
2549
async def create_table(self, table_name_or_model, columns: dict = None, **kwargs) -> "AsyncClient":
    """
    Create a table asynchronously

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class
        columns: Dictionary mapping column names to their definitions (required if table_name_or_model is str)
        **kwargs: Additional table creation options (currently unused)

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If the client has no engine
        ValueError: If a table name is given without ``columns``

    Example::

        >>> await client.create_table("users", {
        ...     "id": "int primary key",
        ...     "name": "varchar(100)",
        ...     "email": "varchar(255)"
        ... })
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Model class input: let SQLAlchemy emit the DDL from the model's Table.
    if hasattr(table_name_or_model, '__tablename__'):
        async with self._engine.begin() as conn:
            await conn.run_sync(table_name_or_model.__table__.create)
        return self

    # It's a table name string
    table_name = table_name_or_model
    if columns is None:
        raise ValueError("columns parameter is required when table_name_or_model is a string")

    column_definitions = []
    for column_name, column_def in columns.items():
        if column_def.lower().startswith(("vecf32(", "vecf64(")):
            # Lowercase only the vector type token (e.g. "VECF32(3)" -> "vecf32(3)").
            # The rest of the definition (DEFAULT/COMMENT literals etc.) keeps its case.
            type_token, paren, remainder = column_def.partition(")")
            column_def = type_token.lower() + paren + remainder
        column_definitions.append(f"{column_name} {column_def}")

    sql = f"CREATE TABLE {table_name} ({', '.join(column_definitions)})"

    await self.execute(sql)
    return self
2602
+
2603
async def drop_table(self, table_name_or_model) -> "AsyncClient":
    """
    Drop a table asynchronously (``DROP TABLE IF EXISTS``).

    Args::

        table_name_or_model: Either a table name (str) or a SQLAlchemy model class

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If the client has no engine

    Example::

        >>> await client.drop_table("users")
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    # Model classes carry their name in __tablename__; anything else is used verbatim.
    target = getattr(table_name_or_model, '__tablename__', table_name_or_model)

    await self.execute(f"DROP TABLE IF EXISTS {target}")
    return self
2633
+
2634
async def create_table_with_index(self, table_name: str, columns: dict, indexes: list = None, **kwargs) -> "AsyncClient":
    """
    Create a table with indexes asynchronously

    Args::

        table_name: Name of the table to create
        columns: Dictionary mapping column names to their definitions
        indexes: List of index definitions. Each is a dict with ``name``,
            ``columns`` (a list of column names, or a single column name as
            a str) and an optional ``unique`` flag
        **kwargs: Additional table creation options (currently unused)

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If the client has no engine
        KeyError: If an index definition lacks ``name`` or ``columns``

    Example::

        >>> await client.create_table_with_index("users", {
        ...     "id": "int primary key",
        ...     "name": "varchar(100)",
        ...     "email": "varchar(255)"
        ... }, [
        ...     {"name": "idx_name", "columns": ["name"]},
        ...     {"name": "idx_email", "columns": ["email"], "unique": True}
        ... ])
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    column_definitions = [f"{column_name} {column_def}" for column_name, column_def in columns.items()]
    sql = f"CREATE TABLE {table_name} ({', '.join(column_definitions)})"
    await self.execute(sql)

    # Create indexes if provided (one CREATE [UNIQUE] INDEX statement each)
    for index_def in indexes or ():
        index_name = index_def["name"]
        index_columns = index_def["columns"]
        if isinstance(index_columns, str):
            # A bare string would otherwise be joined character by character.
            index_columns = [index_columns]
        unique = "UNIQUE " if index_def.get("unique", False) else ""
        index_sql = f"CREATE {unique}INDEX {index_name} ON {table_name} ({', '.join(index_columns)})"
        await self.execute(index_sql)

    return self
2681
+
2682
async def create_table_orm(self, table_name: str, *columns, **kwargs) -> "AsyncClient":
    """
    Create a table from SQLAlchemy ``Column`` objects asynchronously.

    Args::

        table_name: Name of the table to create
        *columns: SQLAlchemy column definitions
        **kwargs: Additional table creation options (currently unused)

    Returns::

        AsyncClient: Self for chaining

    Raises::

        ConnectionError: If the client has no engine

    Example::

        >>> from sqlalchemy import Column, Integer, String
        >>> await client.create_table_orm("users",
        ...     Column("id", Integer, primary_key=True),
        ...     Column("name", String(100)),
        ...     Column("email", String(255))
        ... )
    """
    if not self._engine:
        raise ConnectionError("Not connected to database")

    from sqlalchemy import MetaData, Table

    # A throwaway MetaData holding just this one Table definition.
    table_registry = MetaData()
    Table(table_name, table_registry, *columns)

    # run_sync lets the synchronous DDL helper run on the async connection.
    async with self._engine.begin() as conn:
        await conn.run_sync(table_registry.create_all)

    return self
2718
+
2719
@property
def metadata(self) -> Optional["AsyncMetadataManager"]:
    """
    Metadata manager for table metadata operations.

    Returns::

        The object stored on ``self._metadata`` (may be ``None``).
    """
    manager = self._metadata
    return manager
2723
+
2724
+
2725
class AsyncTransactionVectorIndexManager(AsyncVectorManager):
    """Async transaction-aware vector index manager

    SQL issued through :meth:`execute`, and the schema lookup done by
    :meth:`get_ivf_stats`, is sent through ``transaction_wrapper`` rather than
    the client's own connection.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
        """Execute SQL within transaction"""
        return await self.transaction_wrapper.execute(sql, params)

    @staticmethod
    def _sql_string_literal(value) -> str:
        """Render *value* as a single-quoted SQL string literal, escaping backslashes and quotes."""
        escaped = str(value).replace("\\", "\\\\").replace("'", "''")
        return f"'{escaped}'"

    async def _infer_vector_column(self, database: str, table_name: str) -> str:
        """Return the single vector-typed column of *table_name*, looked up inside the transaction.

        Raises a plain ``Exception`` if the table has no vector column or more than one.
        """
        # Plain SQL string: transaction_wrapper.execute() works on str (it calls
        # str.replace on the final SQL), so a SQLAlchemy text() object cannot be passed.
        schema_sql = (
            "SELECT column_name, data_type "
            "FROM information_schema.columns "
            f"WHERE table_schema = {self._sql_string_literal(database)} "
            f"AND table_name = {self._sql_string_literal(table_name)} "
            "AND (data_type LIKE '%VEC%' OR data_type LIKE '%vec%')"
        )
        result = await self.transaction_wrapper.execute(schema_sql)
        vector_columns = result.rows

        if not vector_columns:
            raise Exception(f"No vector columns found in table {table_name}")
        if len(vector_columns) == 1:
            return vector_columns[0][0]

        # Multiple vector columns found, raise error asking user to specify
        column_names = [col[0] for col in vector_columns]
        raise Exception(
            f"Multiple vector columns found in table {table_name}: {column_names}. "
            "Please specify the column_name parameter."
        )

    async def get_ivf_stats(self, table_name_or_model, column_name: str = None) -> Dict[str, Any]:
        """
        Get IVF index statistics for a table within transaction.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            column_name: Name of the vector column (optional, will be inferred if not provided)

        Returns::

            Dict containing IVF index statistics including:
            - index_tables: Dictionary mapping table types to table names
            - distribution: Dictionary containing bucket distribution data
            - database: Database name
            - table_name: Table name
            - column_name: Vector column name

        Raises::

            Exception: If no database is set, the vector column cannot be
                inferred, the IVF index or its entries table is not found, or
                there are errors retrieving stats

        Examples::

            # Get stats for a table with vector column within transaction
            async with client.transaction() as tx:
                stats = await tx.vector_ops.get_ivf_stats("my_table", "embedding")
                print(f"Index tables: {stats['index_tables']}")
                print(f"Distribution: {stats['distribution']}")
        """
        # Handle model class input
        if hasattr(table_name_or_model, '__tablename__'):
            table_name = table_name_or_model.__tablename__
        else:
            table_name = table_name_or_model

        # Get database name from connection params
        database = self.client._connection_params.get('database')
        if not database:
            raise Exception("No database connection found. Please connect to a database first.")

        # If column_name is not provided, infer it from the table schema
        if not column_name:
            column_name = await self._infer_vector_column(database, table_name)

        # Helpers below take the transaction's own connection
        connection = self.transaction_wrapper.connection

        index_tables = await self._get_ivf_index_table_names(database, table_name, column_name, connection)
        if not index_tables:
            raise Exception(f"No IVF index found for table {table_name}, column {column_name}")

        # The entries table is used for the bucket distribution analysis
        entries_table = index_tables.get('entries')
        if not entries_table:
            raise Exception("No entries table found in IVF index")

        distribution = await self._get_ivf_buckets_distribution(database, entries_table, connection)

        return {
            'index_tables': index_tables,
            'distribution': distribution,
            'database': database,
            'table_name': table_name,
            'column_name': column_name,
        }
2830
+
2831
+
2832
class AsyncTransactionWrapper:
    """Async transaction wrapper for executing queries within a transaction

    Holds the SQLAlchemy async connection of an open transaction and exposes
    transaction-scoped managers (``snapshots``, ``clone``, ``restore``,
    ``pitr``, ``pubsub``, ``account``, ``vector_ops``, ``fulltext_index``)
    that are constructed with a reference to this wrapper.
    """

    def __init__(self, connection, client):
        self.connection = connection
        self.client = client
        # Create snapshot, clone, restore, PITR, pubsub, account, and fulltext managers that use this transaction
        self.snapshots = AsyncTransactionSnapshotManager(client, self)
        self.clone = AsyncTransactionCloneManager(client, self)
        self.restore = AsyncTransactionRestoreManager(client, self)
        self.pitr = AsyncTransactionPitrManager(client, self)
        self.pubsub = AsyncTransactionPubSubManager(client, self)
        self.account = AsyncTransactionAccountManager(self)
        self.vector_ops = AsyncTransactionVectorIndexManager(client, self)
        self.fulltext_index = AsyncTransactionFulltextIndexManager(client, self)
        # SQLAlchemy integration, created lazily by get_sqlalchemy_session()
        self._sqlalchemy_session = None
        self._sqlalchemy_engine = None

    @staticmethod
    def _resolve_table_name(table_name_or_model):
        """Return ``__tablename__`` for a model class, otherwise the argument unchanged."""
        if hasattr(table_name_or_model, '__tablename__'):
            return table_name_or_model.__tablename__
        return table_name_or_model

    async def execute(self, sql: str, params: Optional[Tuple] = None) -> AsyncResultSet:
        """
        Execute SQL within transaction asynchronously

        Args::

            sql: SQL text; parameters are substituted client-side via
                ``client._substitute_parameters``
            params: Optional parameters for substitution

        Returns::

            AsyncResultSet with columns/rows for row-returning statements,
            otherwise an empty result carrying ``affected_rows``

        Raises::

            QueryError: If substitution or execution fails (the original error
                is chained as ``__cause__``)
        """
        import time

        start_time = time.time()

        try:
            # Handle parameter substitution for MatrixOne compatibility
            final_sql = self.client._substitute_parameters(sql, params)
            # Use exec_driver_sql() to bypass SQLAlchemy's bind parameter parsing
            # This prevents JSON strings like {"a":1} from being parsed as :1 bind params
            if hasattr(self.connection, 'exec_driver_sql'):
                # Escape % to %% for pymysql's format string handling
                result = await self.connection.exec_driver_sql(final_sql.replace('%', '%%'))
            else:
                # Fallback for testing or older SQLAlchemy versions
                from sqlalchemy import text

                result = await self.connection.execute(text(final_sql))
            execution_time = time.time() - start_time

            if result.returns_rows:
                rows = result.fetchall()
                columns = list(result.keys()) if hasattr(result, "keys") else []
                self.client.logger.log_query(sql, execution_time, len(rows), success=True)
                return AsyncResultSet(columns, rows)

            self.client.logger.log_query(sql, execution_time, result.rowcount, success=True)
            return AsyncResultSet([], [], affected_rows=result.rowcount)

        except Exception as e:
            execution_time = time.time() - start_time
            self.client.logger.log_query(sql, execution_time, success=False)
            self.client.logger.log_error(e, context="Async transaction query execution")
            raise QueryError(f"Transaction query execution failed: {e}") from e

    def get_connection(self):
        """
        Get the underlying SQLAlchemy async connection for direct use

        Returns::

            SQLAlchemy AsyncConnection instance bound to this transaction
        """
        return self.connection

    async def get_sqlalchemy_session(self):
        """
        Get (creating on first call) an async SQLAlchemy session.

        NOTE: the session is bound to a *new* async engine built from the
        client's connection parameters. It therefore runs on its own
        connection, separate from ``self.connection``. Its changes are
        committed or rolled back via :meth:`commit_sqlalchemy` /
        :meth:`rollback_sqlalchemy`, not by this wrapper's transaction.

        Returns::

            Async SQLAlchemy Session instance with a transaction already begun

        Raises::

            ConnectionError: If the client has no connection parameters
        """
        if self._sqlalchemy_session is None:
            try:
                from sqlalchemy.ext.asyncio import (
                    AsyncSession,
                    async_sessionmaker,
                    create_async_engine,
                )
            except ImportError:
                # Fallback for older SQLAlchemy versions
                from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
                from sqlalchemy.orm import sessionmaker as sync_sessionmaker

                # Create a simple async sessionmaker equivalent
                def async_sessionmaker(bind=None, **kwargs):
                    # Remove class_ from kwargs if present to avoid conflicts
                    kwargs.pop('class_', None)
                    return sync_sessionmaker(bind=bind, class_=AsyncSession, **kwargs)

            from urllib.parse import quote

            params = self.client._connection_params
            if not params:
                raise ConnectionError("Not connected to database")

            # User and password are percent-encoded so characters such as
            # '@', ':' or '/' in credentials cannot corrupt the URL.
            connection_string = (
                f"mysql+aiomysql://{quote(str(params['user']), safe='')}:"
                f"{quote(str(params['password']), safe='')}@"
                f"{params['host']}:"
                f"{params['port']}/"
                f"{params['database']}"
            )

            # Separate async engine (and connection pool) for the ORM session
            self._sqlalchemy_engine = create_async_engine(connection_string, pool_pre_ping=True, pool_recycle=300)

            AsyncSessionLocal = async_sessionmaker(bind=self._sqlalchemy_engine, class_=AsyncSession, expire_on_commit=False)
            self._sqlalchemy_session = AsyncSessionLocal()

            # Begin async SQLAlchemy transaction
            await self._sqlalchemy_session.begin()

        return self._sqlalchemy_session

    async def commit_sqlalchemy(self):
        """Commit async SQLAlchemy session asynchronously (no-op if none was created)"""
        if self._sqlalchemy_session is not None:
            await self._sqlalchemy_session.commit()

    async def rollback_sqlalchemy(self):
        """Rollback async SQLAlchemy session asynchronously (no-op if none was created)"""
        if self._sqlalchemy_session is not None:
            await self._sqlalchemy_session.rollback()

    async def close_sqlalchemy(self):
        """
        Close async SQLAlchemy session and dispose its engine asynchronously.

        The engine is disposed even if closing the session raises.
        """
        session, self._sqlalchemy_session = self._sqlalchemy_session, None
        engine, self._sqlalchemy_engine = self._sqlalchemy_engine, None
        try:
            if session is not None:
                await session.close()
        finally:
            if engine is not None:
                await engine.dispose()

    async def insert(self, table_name_or_model, data: dict) -> AsyncResultSet:
        """
        Insert data into a table within transaction asynchronously.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data: Data to insert (dict with column names as keys)

        Returns::

            AsyncResultSet object
        """
        table_name = self._resolve_table_name(table_name_or_model)
        sql = self.client._build_insert_sql(table_name, data)
        return await self.execute(sql)

    async def batch_insert(self, table_name_or_model, data_list: list) -> AsyncResultSet:
        """
        Batch insert data into a table within transaction asynchronously.

        Args::

            table_name_or_model: Either a table name (str) or a SQLAlchemy model class
            data_list: List of data dictionaries to insert

        Returns::

            AsyncResultSet object (``affected_rows=0`` without executing
            anything when ``data_list`` is empty)
        """
        if not data_list:
            return AsyncResultSet([], [], affected_rows=0)

        table_name = self._resolve_table_name(table_name_or_model)
        sql = self.client._build_batch_insert_sql(table_name, data_list)
        return await self.execute(sql)

    def query(self, *columns, snapshot: str = None):
        """Get async MatrixOne query builder within transaction - SQLAlchemy style

        Args::

            *columns: Can be:
                - Single model class: query(Article) - returns all columns from model
                - Multiple columns: query(Article.id, Article.title) - returns specific columns
                - Mixed: query(Article, Article.id, some_expression.label('alias')) - model + additional columns
            snapshot: Optional snapshot name for snapshot queries

        Returns::

            AsyncMatrixOneQuery instance configured for the specified columns within transaction
        """
        from .async_orm import AsyncMatrixOneQuery

        if len(columns) == 1:
            column = columns[0]
            if isinstance(column, str) or hasattr(column, '__tablename__'):
                # String table name or model class
                return AsyncMatrixOneQuery(column, self.client, None, transaction_wrapper=self, snapshot=snapshot)

            if hasattr(column, 'name') and hasattr(column, 'as_sql'):
                from .orm import CTE

                if isinstance(column, CTE):
                    query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
                    query._table_name = column.name
                    query._select_columns = ["*"]  # Default to select all from CTE
                    query._ctes = [column]  # Add the CTE to the query
                    return query

            # Any other single column/expression. Previously some of these fell
            # through every branch and the method returned None.
            query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
            query._select_columns = [column]
            # Try to infer table name from column
            if hasattr(column, 'table') and hasattr(column.table, 'name'):
                query._table_name = column.table.name
            return query

        # Multiple columns/expressions
        model_class = None
        select_columns = []

        for column in columns:
            if hasattr(column, '__tablename__'):
                # This is a model class - use its table
                model_class = column
            else:
                # This is a column or expression
                select_columns.append(column)

        if model_class:
            query = AsyncMatrixOneQuery(model_class, self.client, None, transaction_wrapper=self, snapshot=snapshot)
            if select_columns:
                # Explicit columns replace the model's default column list
                query._select_columns = select_columns
            return query

        # No model class provided, need to infer table from columns
        query = AsyncMatrixOneQuery(None, self.client, None, transaction_wrapper=self, snapshot=snapshot)
        query._select_columns = select_columns

        # Try to infer table name from first column that has table info
        for col in select_columns:
            if hasattr(col, 'table') and hasattr(col.table, 'name'):
                query._table_name = col.table.name
                break
            elif isinstance(col, str) and '.' in col:
                # String column like "table.column" - extract table name
                parts = col.split('.')
                if len(parts) >= 2:
                    # For "db.table.column" format, use "db.table"
                    # For "table.column" format, use "table"
                    query._table_name = '.'.join(parts[:-1])
                    break

        return query
3103
+
3104
+
3105
class AsyncTransactionSnapshotManager(AsyncSnapshotManager):
    """Async snapshot manager that executes operations within a transaction

    Each override below forwards its arguments unchanged to
    :class:`AsyncSnapshotManager`.

    NOTE(review): none of these overrides use ``self.transaction_wrapper``
    directly. Whether the snapshot SQL actually runs inside the transaction
    depends on the base-class implementation; confirm.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        # Owning transaction wrapper (not read by the overrides in this class)
        self.transaction_wrapper = transaction_wrapper

    async def create(
        self,
        name: str,
        level: Union[str, SnapshotLevel],
        database: Optional[str] = None,
        table: Optional[str] = None,
        description: Optional[str] = None,
    ) -> Snapshot:
        """Create snapshot within transaction asynchronously (delegates to the base manager)"""
        created = await super().create(name, level, database, table, description)
        return created

    async def get(self, name: str) -> Snapshot:
        """Get snapshot within transaction asynchronously (delegates to the base manager)"""
        found = await super().get(name)
        return found

    async def delete(self, name: str) -> None:
        """Delete snapshot within transaction asynchronously (delegates to the base manager)"""
        outcome = await super().delete(name)
        return outcome
3130
+
3131
+
3132
class AsyncTransactionPubSubManager(AsyncPubSubManager):
    """Async publish-subscribe manager for use within transactions

    Every statement is sent through ``transaction_wrapper.execute`` so it runs
    on the transaction's connection. Failures surface as :class:`PubSubError`.
    A ``PubSubError`` raised by a nested step (e.g. the ``get_publication``
    lookup after a CREATE) propagates unchanged instead of being wrapped again.
    Other errors are wrapped once and chained as ``__cause__``.
    """

    def __init__(self, client, transaction_wrapper):
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    async def create_database_publication(self, name: str, database: str, account: str) -> Publication:
        """Create database publication within transaction asynchronously"""
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create database publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create database publication '{name}': {e}") from e

    async def create_table_publication(self, name: str, database: str, table: str, account: str) -> Publication:
        """Create table publication within transaction asynchronously"""
        try:
            sql = (
                f"CREATE PUBLICATION {self.client._escape_identifier(name)} "
                f"DATABASE {self.client._escape_identifier(database)} "
                f"TABLE {self.client._escape_identifier(table)} "
                f"ACCOUNT {self.client._escape_identifier(account)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create table publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create table publication '{name}': {e}") from e

    async def get_publication(self, name: str) -> Publication:
        """Get publication within transaction asynchronously

        Raises::

            PubSubError: If the publication does not exist or the lookup fails
        """
        try:
            sql = f"SHOW PUBLICATIONS WHERE pub_name = {self.client._escape_string(name)}"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Publication '{name}' not found")

            return self._row_to_publication(result.rows[0])

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to get publication '{name}': {e}") from e

    async def list_publications(self, account: Optional[str] = None, database: Optional[str] = None) -> List[Publication]:
        """List publications within transaction asynchronously, optionally filtered by account/database"""
        try:
            # SHOW PUBLICATIONS doesn't support WHERE clause, so we need to list all and filter
            # NOTE(review): get_publication() above does send "SHOW PUBLICATIONS WHERE ...";
            # one of the two assumptions is wrong, so confirm against the server.
            sql = "SHOW PUBLICATIONS"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                return []

            publications = []
            for row in result.rows:
                pub = self._row_to_publication(row)

                # Apply filters
                if account and account not in pub.sub_account:
                    continue
                if database and pub.database != database:
                    continue

                publications.append(pub)

            return publications

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to list publications: {e}") from e

    async def alter_publication(
        self,
        name: str,
        account: Optional[str] = None,
        database: Optional[str] = None,
        table: Optional[str] = None,
    ) -> Publication:
        """Alter publication within transaction asynchronously (only the given clauses are emitted)"""
        try:
            parts = [f"ALTER PUBLICATION {self.client._escape_identifier(name)}"]

            if account:
                parts.append(f"ACCOUNT {self.client._escape_identifier(account)}")
            if database:
                parts.append(f"DATABASE {self.client._escape_identifier(database)}")
            if table:
                parts.append(f"TABLE {self.client._escape_identifier(table)}")

            sql = " ".join(parts)
            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to alter publication '{name}'")

            return await self.get_publication(name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to alter publication '{name}': {e}") from e

    async def drop_publication(self, name: str) -> bool:
        """Drop publication within transaction asynchronously"""
        try:
            sql = f"DROP PUBLICATION {self.client._escape_identifier(name)}"
            result = await self.transaction_wrapper.execute(sql)
            return result is not None

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to drop publication '{name}': {e}") from e

    async def create_subscription(
        self, subscription_name: str, publication_name: str, publisher_account: str
    ) -> Subscription:
        """Create subscription within transaction asynchronously"""
        try:
            sql = (
                f"CREATE DATABASE {self.client._escape_identifier(subscription_name)} "
                f"FROM {self.client._escape_identifier(publisher_account)} "
                f"PUBLICATION {self.client._escape_identifier(publication_name)}"
            )

            result = await self.transaction_wrapper.execute(sql)
            if result is None:
                raise PubSubError(f"Failed to create subscription '{subscription_name}'")

            return await self.get_subscription(subscription_name)

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to create subscription '{subscription_name}': {e}") from e

    async def get_subscription(self, name: str) -> Subscription:
        """Get subscription within transaction asynchronously

        Raises::

            PubSubError: If the subscription does not exist or the lookup fails
        """
        try:
            # SHOW SUBSCRIPTIONS doesn't support WHERE clause, so we need to list all and filter
            sql = "SHOW SUBSCRIPTIONS"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                raise PubSubError(f"Subscription '{name}' not found")

            # Find subscription with matching name
            for row in result.rows:
                if row[6] == name:  # sub_name is in 7th column (index 6)
                    return self._row_to_subscription(row)

            raise PubSubError(f"Subscription '{name}' not found")

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to get subscription '{name}': {e}") from e

    async def list_subscriptions(
        self, pub_account: Optional[str] = None, pub_database: Optional[str] = None
    ) -> List[Subscription]:
        """List subscriptions within transaction asynchronously"""
        try:
            conditions = []

            if pub_account:
                conditions.append(f"pub_account = {self.client._escape_string(pub_account)}")
            if pub_database:
                conditions.append(f"pub_database = {self.client._escape_string(pub_database)}")

            # NOTE(review): get_subscription() states SHOW SUBSCRIPTIONS does not
            # support a WHERE clause, yet the filters are sent as WHERE here;
            # confirm server support before relying on pub_account/pub_database.
            where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

            sql = f"SHOW SUBSCRIPTIONS{where_clause}"
            result = await self.transaction_wrapper.execute(sql)

            if not result or not result.rows:
                return []

            return [self._row_to_subscription(row) for row in result.rows]

        except PubSubError:
            raise
        except Exception as e:
            raise PubSubError(f"Failed to list subscriptions: {e}") from e
3324
+
3325
+
3326
+ class AsyncTransactionPitrManager(AsyncPitrManager):
3327
+ """Async PITR manager for use within transactions"""
3328
+
3329
def __init__(self, client, transaction_wrapper):
    """Initialise the base PITR manager, then keep the transaction whose ``execute`` the create methods use."""
    super().__init__(client)
    # Used by the create_* overrides to run their SQL inside the transaction
    self.transaction_wrapper = transaction_wrapper
3332
+
3333
async def create_cluster_pitr(self, name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
    """
    Create cluster PITR within transaction asynchronously

    Args::

        name: PITR name
        range_value: Length of the retention range
        range_unit: Unit of the range (checked by ``self._validate_range``)

    Returns::

        Pitr: The created PITR, re-read via ``self.get``

    Raises::

        PitrError: If validation, creation or the follow-up lookup fails. A
            ``PitrError`` from a nested step propagates unchanged; other
            errors are wrapped once with the original chained as ``__cause__``.
    """
    try:
        self._validate_range(range_value, range_unit)

        sql = f"CREATE PITR {self.client._escape_identifier(name)} FOR CLUSTER RANGE {range_value} '{range_unit}'"

        result = await self.transaction_wrapper.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create cluster PITR '{name}'")

        return await self.get(name)

    except PitrError:
        raise
    except Exception as e:
        raise PitrError(f"Failed to create cluster PITR '{name}': {e}") from e
3348
+
3349
async def create_account_pitr(
    self,
    name: str,
    account_name: Optional[str] = None,
    range_value: int = 1,
    range_unit: str = "d",
) -> Pitr:
    """
    Create account PITR within transaction asynchronously

    Args::

        name: PITR name
        account_name: Target account; when omitted the ACCOUNT clause names no account
        range_value: Length of the retention range
        range_unit: Unit of the range (checked by ``self._validate_range``)

    Returns::

        Pitr: The created PITR, re-read via ``self.get``

    Raises::

        PitrError: If validation, creation or the follow-up lookup fails. A
            ``PitrError`` from a nested step propagates unchanged; other
            errors are wrapped once with the original chained as ``__cause__``.
    """
    try:
        self._validate_range(range_value, range_unit)

        pitr_name = self.client._escape_identifier(name)
        if account_name:
            sql = (
                f"CREATE PITR {pitr_name} "
                f"FOR ACCOUNT {self.client._escape_identifier(account_name)} "
                f"RANGE {range_value} '{range_unit}'"
            )
        else:
            sql = f"CREATE PITR {pitr_name} FOR ACCOUNT RANGE {range_value} '{range_unit}'"

        result = await self.transaction_wrapper.execute(sql)
        if result is None:
            raise PitrError(f"Failed to create account PITR '{name}'")

        return await self.get(name)

    except PitrError:
        raise
    except Exception as e:
        raise PitrError(f"Failed to create account PITR '{name}': {e}") from e
3379
+
3380
+ async def create_database_pitr(self, name: str, database_name: str, range_value: int = 1, range_unit: str = "d") -> Pitr:
3381
+ """Create database PITR within transaction asynchronously"""
3382
+ try:
3383
+ self._validate_range(range_value, range_unit)
3384
+
3385
+ sql = (
3386
+ f"CREATE PITR {self.client._escape_identifier(name)} "
3387
+ f"FOR DATABASE {self.client._escape_identifier(database_name)} "
3388
+ f"RANGE {range_value} '{range_unit}'"
3389
+ )
3390
+
3391
+ result = await self.transaction_wrapper.execute(sql)
3392
+ if result is None:
3393
+ raise PitrError(f"Failed to create database PITR '{name}'")
3394
+
3395
+ return await self.get(name)
3396
+
3397
+ except Exception as e:
3398
+ raise PitrError(f"Failed to create database PITR '{name}': {e}")
3399
+
3400
+ async def create_table_pitr(
3401
+ self,
3402
+ name: str,
3403
+ database_name: str,
3404
+ table_name: str,
3405
+ range_value: int = 1,
3406
+ range_unit: str = "d",
3407
+ ) -> Pitr:
3408
+ """Create table PITR within transaction asynchronously"""
3409
+ try:
3410
+ self._validate_range(range_value, range_unit)
3411
+
3412
+ sql = (
3413
+ f"CREATE PITR {self.client._escape_identifier(name)} "
3414
+ f"FOR TABLE {self.client._escape_identifier(database_name)} "
3415
+ f"{self.client._escape_identifier(table_name)} "
3416
+ f"RANGE {range_value} '{range_unit}'"
3417
+ )
3418
+
3419
+ result = await self.transaction_wrapper.execute(sql)
3420
+ if result is None:
3421
+ raise PitrError(f"Failed to create table PITR '{name}'")
3422
+
3423
+ return await self.get(name)
3424
+
3425
+ except Exception as e:
3426
+ raise PitrError(f"Failed to create table PITR '{name}': {e}")
3427
+
3428
+ async def get(self, name: str) -> Pitr:
3429
+ """Get PITR within transaction asynchronously"""
3430
+ try:
3431
+ sql = f"SHOW PITR WHERE pitr_name = {self.client._escape_string(name)}"
3432
+ result = await self.transaction_wrapper.execute(sql)
3433
+
3434
+ if not result or not result.rows:
3435
+ raise PitrError(f"PITR '{name}' not found")
3436
+
3437
+ row = result.rows[0]
3438
+ return self._row_to_pitr(row)
3439
+
3440
+ except Exception as e:
3441
+ raise PitrError(f"Failed to get PITR '{name}': {e}")
3442
+
3443
+ async def list(
3444
+ self,
3445
+ level: Optional[str] = None,
3446
+ account_name: Optional[str] = None,
3447
+ database_name: Optional[str] = None,
3448
+ table_name: Optional[str] = None,
3449
+ ) -> List[Pitr]:
3450
+ """List PITRs within transaction asynchronously"""
3451
+ try:
3452
+ conditions = []
3453
+
3454
+ if level:
3455
+ conditions.append(f"pitr_level = {self.client._escape_string(level)}")
3456
+ if account_name:
3457
+ conditions.append(f"account_name = {self.client._escape_string(account_name)}")
3458
+ if database_name:
3459
+ conditions.append(f"database_name = {self.client._escape_string(database_name)}")
3460
+ if table_name:
3461
+ conditions.append(f"table_name = {self.client._escape_string(table_name)}")
3462
+
3463
+ if conditions:
3464
+ where_clause = " WHERE " + " AND ".join(conditions)
3465
+ else:
3466
+ where_clause = ""
3467
+
3468
+ sql = f"SHOW PITR{where_clause}"
3469
+ result = await self.transaction_wrapper.execute(sql)
3470
+
3471
+ if not result or not result.rows:
3472
+ return []
3473
+
3474
+ return [self._row_to_pitr(row) for row in result.rows]
3475
+
3476
+ except Exception as e:
3477
+ raise PitrError(f"Failed to list PITRs: {e}")
3478
+
3479
+ async def alter(self, name: str, range_value: int, range_unit: str) -> Pitr:
3480
+ """Alter PITR within transaction asynchronously"""
3481
+ try:
3482
+ self._validate_range(range_value, range_unit)
3483
+
3484
+ sql = f"ALTER PITR {self.client._escape_identifier(name)} " f"RANGE {range_value} '{range_unit}'"
3485
+
3486
+ result = await self.transaction_wrapper.execute(sql)
3487
+ if result is None:
3488
+ raise PitrError(f"Failed to alter PITR '{name}'")
3489
+
3490
+ return await self.get(name)
3491
+
3492
+ except Exception as e:
3493
+ raise PitrError(f"Failed to alter PITR '{name}': {e}")
3494
+
3495
+ async def delete(self, name: str) -> bool:
3496
+ """Delete PITR within transaction asynchronously"""
3497
+ try:
3498
+ sql = f"DROP PITR {self.client._escape_identifier(name)}"
3499
+ result = await self.transaction_wrapper.execute(sql)
3500
+ return result is not None
3501
+
3502
+ except Exception as e:
3503
+ raise PitrError(f"Failed to delete PITR '{name}': {e}")
3504
+
3505
+
3506
class AsyncTransactionRestoreManager(AsyncRestoreManager):
    """Async restore manager that runs its RESTORE statements inside a transaction.

    Statements are executed through ``self.transaction_wrapper.execute``. Each
    method returns True when the wrapper returns a non-None result.
    """

    def __init__(self, client, transaction_wrapper):
        """Bind the manager to ``client`` and the transaction wrapper used for execution."""
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    def _restore_target_suffix_in_tx(self, to_account: Optional[str]) -> str:
        """Return `` TO ACCOUNT <name>`` for a truthy ``to_account``, else an empty string."""
        if to_account:
            return f" TO ACCOUNT {self.client._escape_identifier(to_account)}"
        return ""

    async def _run_restore_in_tx(self, sql: str) -> bool:
        """Execute ``sql`` on the transaction wrapper; True if it returned a non-None result."""
        result = await self.transaction_wrapper.execute(sql)
        return result is not None

    async def restore_cluster(self, snapshot_name: str) -> bool:
        """Restore the whole cluster from ``snapshot_name`` within the transaction.

        Raises::

            RestoreError: When execution fails (original error chained)
        """
        try:
            sql = f"RESTORE CLUSTER FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
            return await self._run_restore_in_tx(sql)
        except Exception as e:
            raise RestoreError(f"Failed to restore cluster from snapshot '{snapshot_name}': {e}") from e

    async def restore_tenant(self, snapshot_name: str, account_name: str, to_account: Optional[str] = None) -> bool:
        """Restore an account from a snapshot, optionally into a different account.

        Raises::

            RestoreError: When execution fails (original error chained)
        """
        try:
            sql = (
                f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
                f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
                f"{self._restore_target_suffix_in_tx(to_account)}"
            )
            return await self._run_restore_in_tx(sql)
        except Exception as e:
            raise RestoreError(f"Failed to restore tenant '{account_name}' from snapshot '{snapshot_name}': {e}") from e

    async def restore_database(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore one database of an account from a snapshot, optionally into another account.

        Raises::

            RestoreError: When execution fails (original error chained)
        """
        try:
            sql = (
                f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
                f"DATABASE {self.client._escape_identifier(database_name)} "
                f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
                f"{self._restore_target_suffix_in_tx(to_account)}"
            )
            return await self._run_restore_in_tx(sql)
        except Exception as e:
            raise RestoreError(f"Failed to restore database '{database_name}' from snapshot '{snapshot_name}': {e}") from e

    async def restore_table(
        self,
        snapshot_name: str,
        account_name: str,
        database_name: str,
        table_name: str,
        to_account: Optional[str] = None,
    ) -> bool:
        """Restore one table of an account's database from a snapshot, optionally into another account.

        Raises::

            RestoreError: When execution fails (original error chained)
        """
        try:
            sql = (
                f"RESTORE ACCOUNT {self.client._escape_identifier(account_name)} "
                f"DATABASE {self.client._escape_identifier(database_name)} "
                f"TABLE {self.client._escape_identifier(table_name)} "
                f"FROM SNAPSHOT {self.client._escape_identifier(snapshot_name)}"
                f"{self._restore_target_suffix_in_tx(to_account)}"
            )
            return await self._run_restore_in_tx(sql)
        except Exception as e:
            raise RestoreError(f"Failed to restore table '{table_name}' from snapshot '{snapshot_name}': {e}") from e
3600
+
3601
+
3602
class AsyncTransactionCloneManager(AsyncCloneManager):
    """Clone manager variant created for use inside an async transaction.

    It stores the transaction wrapper and forwards every clone call, with the
    same positional arguments, to the ``AsyncCloneManager`` implementation.

    NOTE(review): nothing in this class reads ``self.transaction_wrapper``.
    Whether clone statements actually run inside the transaction depends on
    how ``AsyncCloneManager`` executes SQL; confirm before relying on it.
    """

    def __init__(self, client, transaction_wrapper):
        """Initialise the parent with ``client`` and keep ``transaction_wrapper``."""
        super().__init__(client)
        self.transaction_wrapper = transaction_wrapper

    async def clone_database(
        self,
        target_db: str,
        source_db: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Forward to ``AsyncCloneManager.clone_database`` unchanged."""
        parent_impl = super().clone_database
        return await parent_impl(target_db, source_db, snapshot_name, if_not_exists)

    async def clone_table(
        self,
        target_table: str,
        source_table: str,
        snapshot_name: Optional[str] = None,
        if_not_exists: bool = False,
    ) -> None:
        """Forward to ``AsyncCloneManager.clone_table`` unchanged."""
        parent_impl = super().clone_table
        return await parent_impl(target_table, source_table, snapshot_name, if_not_exists)

    async def clone_table_with_snapshot(
        self, target_table: str, source_table: str, snapshot_name: str, if_not_exists: bool = False
    ) -> None:
        """Forward to ``AsyncCloneManager.clone_table_with_snapshot`` unchanged."""
        parent_impl = super().clone_table_with_snapshot
        return await parent_impl(target_table, source_table, snapshot_name, if_not_exists)
3634
+
3635
+
3636
class AsyncTransactionAccountManager:
    """Account and user management scoped to an async transaction.

    Every statement is issued through ``self.transaction.execute``. The
    transaction's client is used only for identifier/string escaping.
    All failures surface as ``AccountError`` with the original exception
    chained as ``__cause__``.
    """

    def __init__(self, transaction):
        """Bind to ``transaction`` and cache its client for escaping helpers."""
        self.transaction = transaction
        self.client = transaction.client

    async def create_account(
        self,
        account_name: str,
        admin_name: str,
        password: str,
        comment: Optional[str] = None,
        admin_comment: Optional[str] = None,
        admin_host: str = "%",
        admin_identified_by: Optional[str] = None,
    ) -> Account:
        """Create an account within the transaction and return it.

        Args::

            account_name: Name of the account to create
            admin_name: Name of the account's admin user
            password: Admin password (``IDENTIFIED BY``)
            comment: Optional account comment (added only when truthy)
            admin_comment: Optional admin comment (added only when truthy)
            admin_host: Admin host; an ``ADMIN_HOST`` clause is added only when not "%"
            admin_identified_by: Optional ``ADMIN_IDENTIFIED BY`` value (added only when truthy)

        Returns::

            Account: The account as returned by ``get_account`` after creation

        Raises::

            AccountError: When execution or the follow-up lookup fails
        """
        try:
            sql_parts = [
                f"CREATE ACCOUNT {self.client._escape_identifier(account_name)}",
                f"ADMIN_NAME {self.client._escape_string(admin_name)}",
                f"IDENTIFIED BY {self.client._escape_string(password)}",
            ]

            if admin_host != "%":
                sql_parts.append(f"ADMIN_HOST {self.client._escape_string(admin_host)}")
            if comment:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")
            if admin_comment:
                sql_parts.append(f"ADMIN_COMMENT {self.client._escape_string(admin_comment)}")
            if admin_identified_by:
                sql_parts.append(f"ADMIN_IDENTIFIED BY {self.client._escape_string(admin_identified_by)}")

            await self.transaction.execute(" ".join(sql_parts))
            return await self.get_account(account_name)
        except Exception as e:
            raise AccountError(f"Failed to create account '{account_name}': {e}") from e

    async def drop_account(self, account_name: str) -> None:
        """Drop an account within the transaction.

        Raises::

            AccountError: When execution fails
        """
        try:
            sql = f"DROP ACCOUNT {self.client._escape_identifier(account_name)}"
            await self.transaction.execute(sql)
        except Exception as e:
            raise AccountError(f"Failed to drop account '{account_name}': {e}") from e

    async def alter_account(
        self,
        account_name: str,
        comment: Optional[str] = None,
        suspend: Optional[bool] = None,
        suspend_reason: Optional[str] = None,
    ) -> Account:
        """Alter an account's comment and/or suspension state, then return it.

        Args::

            account_name: Account to alter
            comment: New comment; None leaves it unchanged
            suspend: True appends SUSPEND, False appends OPEN, None adds neither
            suspend_reason: Comment attached to SUSPEND when suspending

        Returns::

            Account: The account as returned by ``get_account`` after the change

        Raises::

            AccountError: When execution or the follow-up lookup fails
        """
        try:
            sql_parts = [f"ALTER ACCOUNT {self.client._escape_identifier(account_name)}"]

            if comment is not None:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

            if suspend is not None:
                if not suspend:
                    sql_parts.append("OPEN")
                elif suspend_reason:
                    sql_parts.append(f"SUSPEND COMMENT {self.client._escape_string(suspend_reason)}")
                else:
                    sql_parts.append("SUSPEND")

            await self.transaction.execute(" ".join(sql_parts))
            return await self.get_account(account_name)
        except Exception as e:
            raise AccountError(f"Failed to alter account '{account_name}': {e}") from e

    async def get_account(self, account_name: str) -> Account:
        """Return the account whose first ``SHOW ACCOUNTS`` column equals ``account_name``.

        Raises::

            AccountError: When no row matches or the query fails
        """
        try:
            result = await self.transaction.execute("SHOW ACCOUNTS")

            if not result or not result.rows:
                raise AccountError(f"Account '{account_name}' not found")

            for row in result.rows:
                if row[0] == account_name:
                    return self._row_to_account(row)

            raise AccountError(f"Account '{account_name}' not found")
        except Exception as e:
            raise AccountError(f"Failed to get account '{account_name}': {e}") from e

    async def list_accounts(self) -> List[Account]:
        """Return every account reported by ``SHOW ACCOUNTS`` (empty list if none).

        Raises::

            AccountError: When the query fails
        """
        try:
            result = await self.transaction.execute("SHOW ACCOUNTS")

            if not result or not result.rows:
                return []

            return [self._row_to_account(row) for row in result.rows]
        except Exception as e:
            raise AccountError(f"Failed to list accounts: {e}") from e

    async def create_user(self, user_name: str, password: str, comment: Optional[str] = None) -> User:
        """
        Create user within async transaction according to MatrixOne CREATE USER syntax:
        CREATE USER [IF NOT EXISTS] user auth_option [, user auth_option] ...
        [DEFAULT ROLE rolename] [COMMENT 'comment_string' | ATTRIBUTE 'json_object']

        Only ``CREATE USER <name> IDENTIFIED BY '<password>'`` is emitted.
        ``comment`` is not sent to the server; it is only copied onto the
        returned object.

        Args::

            user_name: Name of the user to create
            password: Password for the user
            comment: Comment for the user (not supported in MatrixOne)

        Returns::

            User: Locally constructed user object (not re-read from the server)

        Raises::

            AccountError: When execution fails
        """
        try:
            sql = (
                f"CREATE USER {self.client._escape_identifier(user_name)} "
                f"IDENTIFIED BY {self.client._escape_string(password)}"
            )
            await self.transaction.execute(sql)

            # NOTE(review): host/account/status are fixed placeholders rather than
            # values read back from the server; account is always "sys" even when
            # the transaction runs under another account — confirm intended.
            return User(
                name=user_name,
                host="%",  # Default host
                account="sys",  # Default account
                created_time=datetime.now(),
                status="ACTIVE",
                comment=comment,
            )
        except Exception as e:
            raise AccountError(f"Failed to create user '{user_name}': {e}") from e

    async def drop_user(self, user_name: str, if_exists: bool = False) -> None:
        """
        Drop user within async transaction according to MatrixOne DROP USER syntax:
        DROP USER [IF EXISTS] user [, user] ...

        Args::

            user_name: Name of the user to drop
            if_exists: If True, add IF EXISTS clause to avoid errors when user doesn't exist

        Raises::

            AccountError: When execution fails
        """
        try:
            sql_parts = ["DROP USER"]
            if if_exists:
                sql_parts.append("IF EXISTS")
            sql_parts.append(self.client._escape_identifier(user_name))

            await self.transaction.execute(" ".join(sql_parts))
        except Exception as e:
            raise AccountError(f"Failed to drop user '{user_name}': {e}") from e

    async def alter_user(
        self,
        user_name: str,
        password: Optional[str] = None,
        comment: Optional[str] = None,
        lock: Optional[bool] = None,
        lock_reason: Optional[str] = None,
    ) -> User:
        """Alter a user's password, comment, and/or lock state, then return it.

        Args::

            user_name: User to alter
            password: New password; None leaves it unchanged
            comment: New comment; None leaves it unchanged
            lock: True appends ACCOUNT LOCK, False appends ACCOUNT UNLOCK, None adds neither
            lock_reason: Comment attached to ACCOUNT LOCK when locking

        Returns::

            User: The user as returned by ``get_user`` after the change

        Raises::

            AccountError: When execution or the follow-up lookup fails
        """
        try:
            sql_parts = [f"ALTER USER {self.client._escape_identifier(user_name)}"]

            if password is not None:
                sql_parts.append(f"IDENTIFIED BY {self.client._escape_string(password)}")

            if comment is not None:
                sql_parts.append(f"COMMENT {self.client._escape_string(comment)}")

            if lock is not None:
                if not lock:
                    sql_parts.append("ACCOUNT UNLOCK")
                elif lock_reason:
                    sql_parts.append(f"ACCOUNT LOCK COMMENT {self.client._escape_string(lock_reason)}")
                else:
                    sql_parts.append("ACCOUNT LOCK")

            await self.transaction.execute(" ".join(sql_parts))
            return await self.get_user(user_name)
        except Exception as e:
            raise AccountError(f"Failed to alter user '{user_name}': {e}") from e

    async def get_user(self, user_name: str) -> User:
        """Return the user whose first ``SHOW GRANTS`` column equals ``user_name``.

        NOTE(review): this assumes ``SHOW GRANTS`` rows start with the user name
        followed by host and account (see ``_row_to_user``); that shape is not
        established here — verify against the server's actual output.

        Raises::

            AccountError: When no row matches or the query fails
        """
        try:
            result = await self.transaction.execute("SHOW GRANTS")

            if not result or not result.rows:
                raise AccountError(f"User '{user_name}' not found")

            for row in result.rows:
                if row[0] == user_name:
                    return self._row_to_user(row)

            raise AccountError(f"User '{user_name}' not found")
        except Exception as e:
            raise AccountError(f"Failed to get user '{user_name}': {e}") from e

    async def list_users(self, account_name: Optional[str] = None) -> List[User]:
        """Return users built from ``SHOW GRANTS`` rows.

        When ``account_name`` is truthy, only users whose ``account`` equals
        it are kept.

        Raises::

            AccountError: When the query fails
        """
        try:
            result = await self.transaction.execute("SHOW GRANTS")

            if not result or not result.rows:
                return []

            users = [self._row_to_user(row) for row in result.rows]

            if account_name:
                users = [user for user in users if user.account == account_name]

            return users
        except Exception as e:
            raise AccountError(f"Failed to list users: {e}") from e

    def _row_to_account(self, row: tuple) -> Account:
        """Convert a ``SHOW ACCOUNTS`` row to an Account.

        Columns 0-1 are required; columns 2-6 default to None when absent.
        """
        return Account(
            name=row[0],
            admin_name=row[1],
            created_time=row[2] if len(row) > 2 else None,
            status=row[3] if len(row) > 3 else None,
            comment=row[4] if len(row) > 4 else None,
            suspended_time=row[5] if len(row) > 5 else None,
            suspended_reason=row[6] if len(row) > 6 else None,
        )

    def _row_to_user(self, row: tuple) -> User:
        """Convert a result row to a User.

        Columns 0-2 are required; columns 3-7 default to None when absent.
        """
        return User(
            name=row[0],
            host=row[1],
            account=row[2],
            created_time=row[3] if len(row) > 3 else None,
            status=row[4] if len(row) > 4 else None,
            comment=row[5] if len(row) > 5 else None,
            locked_time=row[6] if len(row) > 6 else None,
            locked_reason=row[7] if len(row) > 7 else None,
        )