matrixone-python-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (122)
  1. matrixone/__init__.py +155 -0
  2. matrixone/account.py +723 -0
  3. matrixone/async_client.py +3913 -0
  4. matrixone/async_metadata_manager.py +311 -0
  5. matrixone/async_orm.py +123 -0
  6. matrixone/async_vector_index_manager.py +633 -0
  7. matrixone/base_client.py +208 -0
  8. matrixone/client.py +4672 -0
  9. matrixone/config.py +452 -0
  10. matrixone/connection_hooks.py +286 -0
  11. matrixone/exceptions.py +89 -0
  12. matrixone/logger.py +782 -0
  13. matrixone/metadata.py +820 -0
  14. matrixone/moctl.py +219 -0
  15. matrixone/orm.py +2277 -0
  16. matrixone/pitr.py +646 -0
  17. matrixone/pubsub.py +771 -0
  18. matrixone/restore.py +411 -0
  19. matrixone/search_vector_index.py +1176 -0
  20. matrixone/snapshot.py +550 -0
  21. matrixone/sql_builder.py +844 -0
  22. matrixone/sqlalchemy_ext/__init__.py +161 -0
  23. matrixone/sqlalchemy_ext/adapters.py +163 -0
  24. matrixone/sqlalchemy_ext/dialect.py +534 -0
  25. matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
  26. matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
  27. matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
  28. matrixone/sqlalchemy_ext/ivf_config.py +252 -0
  29. matrixone/sqlalchemy_ext/table_builder.py +351 -0
  30. matrixone/sqlalchemy_ext/vector_index.py +1721 -0
  31. matrixone/sqlalchemy_ext/vector_type.py +948 -0
  32. matrixone/version.py +580 -0
  33. matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
  34. matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
  35. matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
  36. matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
  37. matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
  38. matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
  39. tests/__init__.py +19 -0
  40. tests/offline/__init__.py +20 -0
  41. tests/offline/conftest.py +77 -0
  42. tests/offline/test_account.py +703 -0
  43. tests/offline/test_async_client_query_comprehensive.py +1218 -0
  44. tests/offline/test_basic.py +54 -0
  45. tests/offline/test_case_sensitivity.py +227 -0
  46. tests/offline/test_connection_hooks_offline.py +287 -0
  47. tests/offline/test_dialect_schema_handling.py +609 -0
  48. tests/offline/test_explain_methods.py +346 -0
  49. tests/offline/test_filter_logical_in.py +237 -0
  50. tests/offline/test_fulltext_search_comprehensive.py +795 -0
  51. tests/offline/test_ivf_config.py +249 -0
  52. tests/offline/test_join_methods.py +281 -0
  53. tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
  54. tests/offline/test_logical_in_method.py +237 -0
  55. tests/offline/test_matrixone_version_parsing.py +264 -0
  56. tests/offline/test_metadata_offline.py +557 -0
  57. tests/offline/test_moctl.py +300 -0
  58. tests/offline/test_moctl_simple.py +251 -0
  59. tests/offline/test_model_support_offline.py +359 -0
  60. tests/offline/test_model_support_simple.py +225 -0
  61. tests/offline/test_pinecone_filter_offline.py +377 -0
  62. tests/offline/test_pitr.py +585 -0
  63. tests/offline/test_pubsub.py +712 -0
  64. tests/offline/test_query_update.py +283 -0
  65. tests/offline/test_restore.py +445 -0
  66. tests/offline/test_snapshot_comprehensive.py +384 -0
  67. tests/offline/test_sql_escaping_edge_cases.py +551 -0
  68. tests/offline/test_sqlalchemy_integration.py +382 -0
  69. tests/offline/test_sqlalchemy_vector_integration.py +434 -0
  70. tests/offline/test_table_builder.py +198 -0
  71. tests/offline/test_unified_filter.py +398 -0
  72. tests/offline/test_unified_transaction.py +495 -0
  73. tests/offline/test_vector_index.py +238 -0
  74. tests/offline/test_vector_operations.py +688 -0
  75. tests/offline/test_vector_type.py +174 -0
  76. tests/offline/test_version_core.py +328 -0
  77. tests/offline/test_version_management.py +372 -0
  78. tests/offline/test_version_standalone.py +652 -0
  79. tests/online/__init__.py +20 -0
  80. tests/online/conftest.py +216 -0
  81. tests/online/test_account_management.py +194 -0
  82. tests/online/test_advanced_features.py +344 -0
  83. tests/online/test_async_client_interfaces.py +330 -0
  84. tests/online/test_async_client_online.py +285 -0
  85. tests/online/test_async_model_insert_online.py +293 -0
  86. tests/online/test_async_orm_online.py +300 -0
  87. tests/online/test_async_simple_query_online.py +802 -0
  88. tests/online/test_async_transaction_simple_query.py +300 -0
  89. tests/online/test_basic_connection.py +130 -0
  90. tests/online/test_client_online.py +238 -0
  91. tests/online/test_config.py +90 -0
  92. tests/online/test_config_validation.py +123 -0
  93. tests/online/test_connection_hooks_new_online.py +217 -0
  94. tests/online/test_dialect_schema_handling_online.py +331 -0
  95. tests/online/test_filter_logical_in_online.py +374 -0
  96. tests/online/test_fulltext_comprehensive.py +1773 -0
  97. tests/online/test_fulltext_label_online.py +433 -0
  98. tests/online/test_fulltext_search_online.py +842 -0
  99. tests/online/test_ivf_stats_online.py +506 -0
  100. tests/online/test_logger_integration.py +311 -0
  101. tests/online/test_matrixone_query_orm.py +540 -0
  102. tests/online/test_metadata_online.py +579 -0
  103. tests/online/test_model_insert_online.py +255 -0
  104. tests/online/test_mysql_driver_validation.py +213 -0
  105. tests/online/test_orm_advanced_features.py +2022 -0
  106. tests/online/test_orm_cte_integration.py +269 -0
  107. tests/online/test_orm_online.py +270 -0
  108. tests/online/test_pinecone_filter.py +708 -0
  109. tests/online/test_pubsub_operations.py +352 -0
  110. tests/online/test_query_methods.py +225 -0
  111. tests/online/test_query_update_online.py +433 -0
  112. tests/online/test_search_vector_index.py +557 -0
  113. tests/online/test_simple_fulltext_online.py +915 -0
  114. tests/online/test_snapshot_comprehensive.py +998 -0
  115. tests/online/test_sqlalchemy_engine_integration.py +336 -0
  116. tests/online/test_sqlalchemy_integration.py +425 -0
  117. tests/online/test_transaction_contexts.py +1219 -0
  118. tests/online/test_transaction_insert_methods.py +356 -0
  119. tests/online/test_transaction_query_methods.py +288 -0
  120. tests/online/test_unified_filter_online.py +529 -0
  121. tests/online/test_vector_comprehensive.py +706 -0
  122. tests/online/test_version_management.py +291 -0
@@ -0,0 +1,1218 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Comprehensive Async Client Query Tests
17
+
18
+ This file consolidates all async client and query-related tests from:
19
+ - test_async.py (comprehensive async client tests)
20
+ - test_async_simple.py (basic async client tests)
21
+ - test_async_sqlalchemy_transaction.py (SQLAlchemy integration tests)
22
+
23
+ The merged file eliminates redundancy while maintaining full test coverage.
24
+ """
25
+
26
+ import unittest
27
+ import asyncio
28
+ import pytest
29
+ from unittest.mock import Mock, patch, AsyncMock
30
+ import sys
31
+ import os
32
+ from datetime import datetime
33
+
34
# sys.modules entries that setup_sqlalchemy_mocks() replaces with Mock objects.
_MOCKED_MODULE_NAMES = (
    'pymysql',
    'aiomysql',
    'sqlalchemy',
    'sqlalchemy.engine',
    'sqlalchemy.orm',
    'sqlalchemy.ext',
    'sqlalchemy.ext.asyncio',
)

# Store original modules to restore later (None means "was not imported").
_original_modules = {}


def setup_sqlalchemy_mocks():
    """Replace the DB driver and SQLAlchemy modules in ``sys.modules`` with mocks.

    The pre-existing entries are recorded in ``_original_modules`` so that
    ``teardown_sqlalchemy_mocks()`` can restore them. If this is called again
    before a teardown, the first snapshot is kept, so an installed mock is
    never mistaken for an original module.

    Note: replacing ``sys.modules`` entries only affects imports performed
    after this call; modules that already imported these packages keep their
    references to the real objects.
    """
    for module_name in _MOCKED_MODULE_NAMES:
        # setdefault: keep the genuine original if setup runs twice in a row.
        _original_modules.setdefault(module_name, sys.modules.get(module_name))

    # Mock the external dependencies
    pymysql_mock = Mock()
    aiomysql_mock = Mock()

    sqlalchemy_mock = Mock()
    sqlalchemy_mock.create_engine = Mock()
    sqlalchemy_mock.text = Mock()
    sqlalchemy_mock.Column = Mock()
    sqlalchemy_mock.Integer = Mock()
    sqlalchemy_mock.String = Mock()
    sqlalchemy_mock.DateTime = Mock()

    engine_mock = Mock()
    engine_mock.Engine = Mock()

    orm_mock = Mock()
    orm_mock.sessionmaker = Mock()
    orm_mock.declarative_base = Mock()

    # Mock SQLAlchemy async engine
    ext_mock = Mock()
    asyncio_mock = Mock()
    asyncio_mock.create_async_engine = Mock()
    asyncio_mock.AsyncEngine = Mock()
    asyncio_mock.AsyncSession = Mock()
    asyncio_mock.async_sessionmaker = Mock()

    # Attach each submodule to its parent so attribute access
    # (``sqlalchemy.ext.asyncio``) and ``sys.modules`` lookups resolve to the
    # same mock, as they would for real packages.
    sqlalchemy_mock.engine = engine_mock
    sqlalchemy_mock.orm = orm_mock
    sqlalchemy_mock.ext = ext_mock
    ext_mock.asyncio = asyncio_mock

    sys.modules.update(
        {
            'pymysql': pymysql_mock,
            'aiomysql': aiomysql_mock,
            'sqlalchemy': sqlalchemy_mock,
            'sqlalchemy.engine': engine_mock,
            'sqlalchemy.orm': orm_mock,
            'sqlalchemy.ext': ext_mock,
            'sqlalchemy.ext.asyncio': asyncio_mock,
        }
    )


def teardown_sqlalchemy_mocks():
    """Undo ``setup_sqlalchemy_mocks()``.

    Every recorded module is put back into ``sys.modules``; modules that were
    not imported before the setup are removed again. The record is cleared
    afterwards, so a repeated teardown is a no-op and the next setup captures
    fresh originals.
    """
    for module_name, original_module in _original_modules.items():
        if original_module is not None:
            sys.modules[module_name] = original_module
        else:
            sys.modules.pop(module_name, None)
    _original_modules.clear()
86
+
87
+
88
+ # Add the matrixone package to the path
89
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
90
+
91
+ from matrixone.async_client import (
92
+ AsyncClient,
93
+ AsyncResultSet,
94
+ AsyncSnapshotManager,
95
+ AsyncCloneManager,
96
+ AsyncMoCtlManager,
97
+ AsyncTransactionWrapper,
98
+ )
99
+ from matrixone.snapshot import SnapshotLevel, Snapshot
100
+ from matrixone.exceptions import MoCtlError, ConnectionError
101
+
102
+
103
class TestAsyncResultSet(unittest.TestCase):
    """Unit tests for the in-memory AsyncResultSet container."""

    @classmethod
    def setUpClass(cls):
        """Install the driver/SQLAlchemy module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Put the real modules back once the class has finished."""
        teardown_sqlalchemy_mocks()

    def test_async_result_set(self):
        """AsyncResultSet exposes columns/rows, len(), fetchone(), scalar() and iteration."""
        expected_columns = ['id', 'name', 'email']
        alice = (1, 'Alice', 'alice@example.com')
        bob = (2, 'Bob', 'bob@example.com')
        expected_rows = [alice, bob]

        result_set = AsyncResultSet(expected_columns, expected_rows)

        self.assertEqual(result_set.columns, expected_columns)
        self.assertEqual(result_set.rows, expected_rows)
        self.assertEqual(len(result_set), 2)
        self.assertEqual(result_set.fetchone(), alice)
        self.assertEqual(result_set.scalar(), 1)

        # Iterating the result set yields every row
        iterated = [row for row in result_set]
        self.assertEqual(len(iterated), 2)
132
+
133
+
134
class TestAsyncClientBasic(unittest.IsolatedAsyncioTestCase):
    """Test AsyncClient basic functionality and imports"""

    @classmethod
    def setUpClass(cls):
        """Setup mocks for the entire test class"""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore original modules after tests"""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Set up test fixtures"""
        self.client = AsyncClient()

    def test_init(self):
        """Test AsyncClient initialization"""
        self.assertEqual(self.client.connection_timeout, 30)
        self.assertEqual(self.client.query_timeout, 300)
        self.assertEqual(self.client.auto_commit, True)
        self.assertEqual(self.client.charset, 'utf8mb4')
        self.assertIsNone(self.client._engine)
        self.assertIsNone(self.client._connection)
        self.assertIsInstance(self.client._snapshots, AsyncSnapshotManager)
        self.assertIsInstance(self.client._clone, AsyncCloneManager)
        self.assertIsInstance(self.client._moctl, AsyncMoCtlManager)

    def test_managers_properties(self):
        """Test async managers properties"""
        # Test snapshot manager
        snapshot_manager = self.client.snapshots
        self.assertIsInstance(snapshot_manager, AsyncSnapshotManager)

        # Test clone manager
        clone_manager = self.client.clone
        self.assertIsInstance(clone_manager, AsyncCloneManager)

        # Test moctl manager
        moctl_manager = self.client.moctl
        self.assertIsInstance(moctl_manager, AsyncMoCtlManager)

    def test_imports(self):
        """Test that the public async API can be imported and SnapshotLevel has members."""
        try:
            from matrixone import AsyncClient, AsyncResultSet  # noqa: F401
            from matrixone.async_client import (  # noqa: F401
                AsyncSnapshotManager,
                AsyncCloneManager,
                AsyncMoCtlManager,
            )
            from matrixone.snapshot import SnapshotLevel
        except ImportError as e:
            self.fail(f"Import error: {e}")

        # list(...) is never None, so the previous assertIsNotNone check could
        # not fail; require the enum to actually define members instead.
        self.assertGreater(len(list(SnapshotLevel)), 0)

    async def test_context_manager(self):
        """Test async context manager"""
        # connect is patched defensively so entering the context cannot open a
        # real connection; its calls are not asserted.
        with patch.object(self.client, 'connect'), patch.object(
            self.client, 'disconnect'
        ) as mock_disconnect, patch.object(self.client, 'connected') as mock_connected:

            # Mock connected to return True so disconnect will be called
            mock_connected.return_value = True

            async with self.client as client:
                self.assertIs(client, self.client)

            # disconnect should be called since connected() returns True
            mock_disconnect.assert_called_once()
208
+
209
+
210
class TestAsyncClientConnection(unittest.IsolatedAsyncioTestCase):
    """Tests for AsyncClient.connect() and AsyncClient.disconnect()."""

    # Keyword arguments used for every connect() call in this class.
    _CONNECT_ARGS = {
        'host': "localhost",
        'port': 6001,
        'user': "root",
        'password': "111",
        'database': "test",
    }

    @classmethod
    def setUpClass(cls):
        """Install the driver/SQLAlchemy module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Put the real modules back once the class has finished."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Give each test a fresh, unconnected client."""
        self.client = AsyncClient()

    @staticmethod
    def _build_engine(connection):
        """Return an engine stand-in whose ``begin()`` is an async context
        manager yielding *connection*; nothing else is implemented."""

        class _BeginContext:
            async def __aenter__(self):
                return connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return None

        class _Engine:
            def __init__(self):
                self.connection = connection

            def begin(self):
                return _BeginContext()

        return _Engine()

    @patch('matrixone.async_client.create_async_engine')
    async def test_connect(self, mock_create_async_engine):
        """connect() creates the engine and stores the connection parameters."""
        fake_result = AsyncMock()
        fake_result.returns_rows = False

        async def fake_execute(sql, params=None):
            return fake_result

        fake_connection = AsyncMock()
        fake_connection.execute = fake_execute

        fake_engine = self._build_engine(fake_connection)
        mock_create_async_engine.return_value = fake_engine

        await self.client.connect(**self._CONNECT_ARGS)

        # The engine factory is used exactly once and its engine is kept
        mock_create_async_engine.assert_called_once()
        self.assertEqual(self.client._engine, fake_engine)
        for param_name, expected_value in self._CONNECT_ARGS.items():
            self.assertEqual(self.client._connection_params[param_name], expected_value)

    @patch('matrixone.async_client.create_async_engine')
    async def test_connect_failure(self, mock_create_async_engine):
        """An error while creating the engine propagates out of connect()."""
        mock_create_async_engine.side_effect = Exception("Connection failed")

        with self.assertRaises(Exception):
            await self.client.connect(**self._CONNECT_ARGS)

    async def test_disconnect(self):
        """disconnect() disposes the engine and forgets it."""
        fake_engine = AsyncMock()
        fake_engine.dispose = AsyncMock()
        self.client._engine = fake_engine

        await self.client.disconnect()

        fake_engine.dispose.assert_called_once()
        self.assertIsNone(self.client._engine)
296
+
297
+
298
class TestAsyncClientQuery(unittest.IsolatedAsyncioTestCase):
    """Test AsyncClient query functionality"""

    @classmethod
    def setUpClass(cls):
        """Setup mocks for the entire test class"""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore original modules after tests"""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Set up test fixtures"""
        self.client = AsyncClient()

    @staticmethod
    def _make_engine(connection):
        """Return a minimal engine stand-in.

        Only ``begin()`` is implemented: it returns an async context manager
        whose ``__aenter__`` yields *connection*. The previous per-test copies
        of these mock classes are consolidated here.
        """

        class _BeginContext:
            async def __aenter__(self):
                return connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return None

        class _Engine:
            def __init__(self):
                self.connection = connection

            def begin(self):
                return _BeginContext()

        return _Engine()

    @staticmethod
    def _make_recording_connection(result_factory):
        """Return a connection stand-in whose async ``execute`` records its
        arguments and returns ``result_factory()``."""

        class _RecordingConnection:
            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                return result_factory()

        return _RecordingConnection()

    async def test_execute_success(self):
        """Test successful async execution"""

        class MockResult:
            """Result stand-in for a row-returning statement."""

            def __init__(self):
                self.returns_rows = True

            def fetchall(self):
                return [(1, 'Alice'), (2, 'Bob')]

            def keys(self):
                return ['id', 'name']

        mock_connection = self._make_recording_connection(MockResult)
        self.client._engine = self._make_engine(mock_connection)

        result = await self.client.execute("SELECT id, name FROM users")

        self.assertTrue(mock_connection.execute_called)
        self.assertIsInstance(result, AsyncResultSet)
        self.assertEqual(result.columns, ['id', 'name'])
        self.assertEqual(result.rows, [(1, 'Alice'), (2, 'Bob')])

    async def test_execute_with_params(self):
        """Test async execution with parameters"""

        class MockResult:
            """Result stand-in for a statement that returns no rows."""

            def __init__(self):
                self.returns_rows = False
                self.rowcount = 1

        mock_connection = self._make_recording_connection(MockResult)
        self.client._engine = self._make_engine(mock_connection)

        result = await self.client.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))

        self.assertTrue(mock_connection.execute_called)
        self.assertEqual(result.affected_rows, 1)

    async def test_execute_not_connected(self):
        """Test async execution without connection"""
        with self.assertRaises(Exception):
            await self.client.execute("SELECT 1")

    async def test_snapshot_query(self):
        """Test async snapshot query"""
        mock_result = Mock()
        mock_result.returns_rows = True
        mock_result.fetchall.return_value = [(1, 'Alice')]
        mock_result.keys.return_value = ['id', 'name']

        async def mock_exec_driver_sql(sql):
            return mock_result

        async def mock_execute(sql, params=None):
            return mock_result

        mock_connection = Mock()
        # Mock(side_effect=<coroutine function>) returns an awaitable coroutine
        # while still recording whether the method was called.
        mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql)
        mock_connection.execute = Mock(side_effect=mock_execute)

        self.client._engine = self._make_engine(mock_connection)

        # Scope the ``text`` override to this test instead of permanently
        # replacing it on the shared sqlalchemy module entry.
        with patch.object(
            sys.modules['sqlalchemy'],
            'text',
            Mock(return_value="SELECT id, name FROM users {SNAPSHOT = 'test_snapshot'}"),
        ):
            result = await self.client.query("users", snapshot="test_snapshot").select("id", "name").execute()

        # The client may use either execution path; a bare ``assert`` would be
        # stripped under ``python -O``.
        self.assertTrue(mock_connection.exec_driver_sql.called or mock_connection.execute.called)
        self.assertIsInstance(result, AsyncResultSet)
486
+
487
+
488
class TestAsyncSnapshotManager(unittest.IsolatedAsyncioTestCase):
    """Tests for AsyncSnapshotManager driven by a mocked async client."""

    # Column layout of the snapshot listing rows handed to the manager.
    SNAPSHOT_COLUMNS = ('sname', 'ts', 'level', 'account_name', 'database_name', 'table_name')

    @classmethod
    def setUpClass(cls):
        """Install the driver/SQLAlchemy module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Put the real modules back once the class has finished."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire a fresh manager to a fresh AsyncMock client for each test."""
        self.mock_client = AsyncMock()
        self.snapshot_manager = AsyncSnapshotManager(self.mock_client)

    @staticmethod
    def _now_ns():
        """Current time as an integer nanosecond timestamp (ints, not datetimes)."""
        return int(datetime.now().timestamp() * 1000000000)

    def _listing(self, *rows):
        """Build an AsyncResultSet shaped like a snapshot listing."""
        return AsyncResultSet(list(self.SNAPSHOT_COLUMNS), list(rows))

    def _assert_snapshot(self, snapshot, **expected):
        """Assert each given attribute of *snapshot*, in the order provided."""
        for attribute, value in expected.items():
            self.assertEqual(getattr(snapshot, attribute), value)

    async def test_create_cluster_snapshot(self):
        """A cluster-level snapshot issues CREATE SNAPSHOT ... FOR CLUSTER."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        snapshot = await self.snapshot_manager.create("test_snap", SnapshotLevel.CLUSTER)

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR CLUSTER")
        self._assert_snapshot(snapshot, name="test_snap", level=SnapshotLevel.CLUSTER)

    async def test_create_database_snapshot(self):
        """A database-level snapshot names the database in the statement."""
        empty_result = AsyncResultSet([], [])
        empty_result.database = "test_db"
        self.mock_client.execute.return_value = empty_result

        snapshot = await self.snapshot_manager.create("test_snap", SnapshotLevel.DATABASE, database="test_db")

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR DATABASE test_db")
        self._assert_snapshot(
            snapshot,
            name="test_snap",
            level=SnapshotLevel.DATABASE,
            database="test_db",
        )

    async def test_create_table_snapshot(self):
        """A table-level snapshot names both database and table."""
        empty_result = AsyncResultSet([], [])
        empty_result.database = "test_db"
        empty_result.table = "test_table"
        self.mock_client.execute.return_value = empty_result

        snapshot = await self.snapshot_manager.create(
            "test_snap", SnapshotLevel.TABLE, database="test_db", table="test_table"
        )

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR TABLE test_db test_table")
        self._assert_snapshot(
            snapshot,
            name="test_snap",
            level=SnapshotLevel.TABLE,
            database="test_db",
            table="test_table",
        )

    async def test_get_snapshot(self):
        """get() turns a listing row into a Snapshot object."""
        self.mock_client.execute.return_value = self._listing(
            ('test_snap', self._now_ns(), 'database', 'sys', 'test_db', None),
        )

        snapshot = await self.snapshot_manager.get("test_snap")

        self._assert_snapshot(
            snapshot,
            name="test_snap",
            level=SnapshotLevel.DATABASE,
            database="test_db",
        )

    async def test_list_snapshots(self):
        """list() returns one Snapshot per listing row, in order."""
        stamp = self._now_ns()
        self.mock_client.execute.return_value = self._listing(
            ('snap1', stamp, 'database', 'sys', 'db1', None),
            ('snap2', stamp, 'table', 'sys', 'db2', 'table2'),
        )

        snapshots = await self.snapshot_manager.list()

        self.assertEqual(len(snapshots), 2)
        self.assertEqual([s.name for s in snapshots[:2]][0], "snap1")
        self.assertEqual(snapshots[1].name, "snap2")

    async def test_delete_snapshot(self):
        """delete() issues DROP SNAPSHOT."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.snapshot_manager.delete("test_snap")

        self.mock_client.execute.assert_called_once_with("DROP SNAPSHOT test_snap")

    async def test_exists_snapshot(self):
        """exists() is True for a listed snapshot and False when lookup raises SnapshotError."""
        self.mock_client.execute.return_value = self._listing(
            ('test_snap', self._now_ns(), 'database', 'sys', 'test_db', None),
        )
        self.assertTrue(await self.snapshot_manager.exists("test_snap"))

        from matrixone.exceptions import SnapshotError

        self.mock_client.execute.side_effect = SnapshotError("Snapshot 'nonexistent' not found")
        self.assertFalse(await self.snapshot_manager.exists("nonexistent"))
611
+
612
+
613
class TestAsyncCloneManager(unittest.IsolatedAsyncioTestCase):
    """Tests for the SQL statements emitted by AsyncCloneManager."""

    @classmethod
    def setUpClass(cls):
        """Install the driver/SQLAlchemy module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Put the real modules back once the class has finished."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire a fresh manager to a fresh AsyncMock client for each test."""
        self.mock_client = AsyncMock()
        self.clone_manager = AsyncCloneManager(self.mock_client)

    async def _run_and_expect(self, operation, expected_sql):
        """Await *operation* against an empty result and check the single SQL statement sent."""
        await operation
        self.mock_client.execute.assert_called_once_with(expected_sql)

    async def test_clone_database(self):
        """Cloning a database emits CREATE DATABASE ... CLONE ..."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])
        await self._run_and_expect(
            self.clone_manager.clone_database("target_db", "source_db"),
            "CREATE DATABASE target_db CLONE source_db",
        )

    async def test_clone_database_with_snapshot(self):
        """Cloning a database at a snapshot appends FOR SNAPSHOT '<name>'."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])
        await self._run_and_expect(
            self.clone_manager.clone_database_with_snapshot("target_db", "source_db", "test_snapshot"),
            "CREATE DATABASE target_db CLONE source_db FOR SNAPSHOT 'test_snapshot'",
        )

    async def test_clone_database_if_not_exists(self):
        """if_not_exists=True inserts IF NOT EXISTS into the statement."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])
        await self._run_and_expect(
            self.clone_manager.clone_database("target_db", "source_db", if_not_exists=True),
            "CREATE DATABASE target_db IF NOT EXISTS CLONE source_db",
        )

    async def test_clone_table(self):
        """Cloning a table emits CREATE TABLE ... CLONE ..."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])
        await self._run_and_expect(
            self.clone_manager.clone_table("target_table", "source_table"),
            "CREATE TABLE target_table CLONE source_table",
        )

    async def test_clone_table_with_snapshot(self):
        """Cloning a table at a snapshot appends FOR SNAPSHOT '<name>'."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])
        await self._run_and_expect(
            self.clone_manager.clone_table_with_snapshot("target_table", "source_table", "test_snapshot"),
            "CREATE TABLE target_table CLONE source_table FOR SNAPSHOT 'test_snapshot'",
        )
674
+
675
+
676
class TestAsyncMoCtlManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for AsyncMoCtlManager, driven through a mocked async client."""

    @classmethod
    def setUpClass(cls):
        """Install the SQLAlchemy module mocks once, before any test in this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo setup_sqlalchemy_mocks() after the last test in this class."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Give every test its own mocked client and manager instance."""
        self.mock_client = AsyncMock()
        self.moctl_manager = AsyncMoCtlManager(self.mock_client)

    def _stub_moctl_response(self, payload):
        """Make the mocked client answer with one row whose ``result`` column holds *payload*."""
        self.mock_client.execute.return_value = AsyncResultSet(['result'], [(payload,)])

    async def test_flush_table(self):
        """flush_table() sends the 'flush' mo_ctl call and decodes the JSON reply."""
        self._stub_moctl_response('{"method": "Flush", "result": [{"returnStr": "OK"}]}')

        response = await self.moctl_manager.flush_table('db1', 'users')

        self.mock_client.execute.assert_called_once_with("SELECT mo_ctl('dn', 'flush', 'db1.users')")
        self.assertEqual(response['method'], 'Flush')
        self.assertEqual(response['result'][0]['returnStr'], 'OK')

    async def test_increment_checkpoint(self):
        """increment_checkpoint() sends the 'checkpoint' mo_ctl call with an empty argument."""
        self._stub_moctl_response('{"method": "Checkpoint", "result": [{"returnStr": "OK"}]}')

        response = await self.moctl_manager.increment_checkpoint()

        self.mock_client.execute.assert_called_once_with("SELECT mo_ctl('dn', 'checkpoint', '')")
        self.assertEqual(response['method'], 'Checkpoint')
        self.assertEqual(response['result'][0]['returnStr'], 'OK')

    async def test_global_checkpoint(self):
        """global_checkpoint() sends the 'globalcheckpoint' mo_ctl call with an empty argument."""
        self._stub_moctl_response('{"method": "GlobalCheckpoint", "result": [{"returnStr": "OK"}]}')

        response = await self.moctl_manager.global_checkpoint()

        self.mock_client.execute.assert_called_once_with("SELECT mo_ctl('dn', 'globalcheckpoint', '')")
        self.assertEqual(response['method'], 'GlobalCheckpoint')
        self.assertEqual(response['result'][0]['returnStr'], 'OK')

    async def test_moctl_error_handling(self):
        """A returnStr starting with 'ERROR:' surfaces as MoCtlError carrying that text."""
        self._stub_moctl_response('{"method": "Flush", "result": [{"returnStr": "ERROR: Table not found"}]}')

        with self.assertRaises(MoCtlError) as raised:
            await self.moctl_manager.flush_table('db1', 'nonexistent')

        self.assertIn("ERROR: Table not found", str(raised.exception))
742
+
743
+
744
class TestAsyncTransaction(unittest.IsolatedAsyncioTestCase):
    """Test AsyncClient.transaction() success and error paths against a mocked engine."""

    @classmethod
    def setUpClass(cls):
        """Setup mocks for the entire test class"""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore original modules after tests"""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Build an AsyncClient whose engine hands out ``self.mock_connection``.

        ``engine.begin()`` is replaced with a minimal async context manager that
        returns the mocked connection on entry and does nothing on exit.
        """
        self.client = AsyncClient()
        self.mock_connection = AsyncMock()

        # Create a mock engine class that properly implements begin()
        class MockEngine:
            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                # Create a proper async context manager for engine.begin()
                class MockBeginContext:
                    def __init__(self, connection):
                        self.connection = connection

                    async def __aenter__(self):
                        return self.connection

                    async def __aexit__(self, exc_type, exc_val, exc_tb):
                        # Returning None lets any exception from the body propagate.
                        pass

                return MockBeginContext(self.connection)

        self.mock_engine = MockEngine(self.mock_connection)
        self.client._engine = self.mock_engine

    async def test_transaction_success(self):
        """A statement executed inside the transaction reaches the mocked connection."""
        mock_result = AsyncMock()
        mock_result.returns_rows = False
        mock_result.rowcount = 1

        async def mock_exec_driver_sql(sql):
            return mock_result

        async def mock_execute(sql, params=None):
            return mock_result

        # Wrap the coroutine functions in Mock so ``.called`` can be inspected; the
        # client may use either exec_driver_sql or execute (fallback).
        self.mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql)
        self.mock_connection.execute = Mock(side_effect=mock_execute)

        async with self.client.transaction() as tx:
            await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))

        # Use a unittest assertion rather than a bare ``assert``: bare asserts are
        # removed under ``python -O``, which would let this test pass without
        # checking anything.
        self.assertTrue(
            self.mock_connection.exec_driver_sql.called or self.mock_connection.execute.called,
            "transaction did not run any SQL on the mocked connection",
        )

    async def test_transaction_rollback(self):
        """An error raised by the connection propagates out of the transaction block."""

        async def mock_exec_driver_sql_error(sql):
            raise Exception("Query failed")

        async def mock_execute_error(sql, params=None):
            raise Exception("Query failed")

        self.mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql_error)
        # Also set up execute for the fallback path.
        self.mock_connection.execute = mock_execute_error

        with self.assertRaises(Exception):
            async with self.client.transaction() as tx:
                await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
826
+
827
+
828
class TestAsyncSQLAlchemyTransaction(unittest.IsolatedAsyncioTestCase):
    """Tests for AsyncTransactionWrapper's SQLAlchemy session hooks, using stub engines."""

    @classmethod
    def setUpClass(cls):
        """Install the SQLAlchemy module mocks once for the whole class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore the original SQLAlchemy modules when the class is finished."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire an AsyncClient to a stub engine and give it fake connection parameters."""

        class StubBeginContext:
            # Async context manager returned by ``engine.begin()``: yields the stub
            # connection and does nothing on exit (exceptions propagate).
            def __init__(self, connection):
                self.connection = connection

            async def __aenter__(self):
                return self.connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        class StubEngine:
            # ``begin()`` reads ``self.connection`` at call time, so a test can swap
            # the connection later by reassigning ``self.mock_engine.connection``.
            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                return StubBeginContext(self.connection)

        self.client = AsyncClient()
        self.mock_connection = AsyncMock()
        self.mock_engine = StubEngine(self.mock_connection)
        self.client._engine = self.mock_engine
        self.client._connection_params = dict(
            user='testuser',
            password='testpass',
            host='localhost',
            port=6001,
            db='testdb',
        )

    @staticmethod
    def _new_session_mock():
        """Return an AsyncMock session whose begin/commit/rollback/close are AsyncMocks."""
        session = AsyncMock()
        for hook in ('begin', 'commit', 'rollback', 'close'):
            setattr(session, hook, AsyncMock())
        return session

    @staticmethod
    def _hook_patchers(session):
        """Build (not yet started) patchers for the wrapper's SQLAlchemy hooks.

        ``get_sqlalchemy_session`` is made to return *session*; the commit,
        rollback and close hooks are made to return None.
        """
        return (
            patch.object(AsyncTransactionWrapper, 'get_sqlalchemy_session', return_value=session),
            patch.object(AsyncTransactionWrapper, 'commit_sqlalchemy', return_value=None),
            patch.object(AsyncTransactionWrapper, 'rollback_sqlalchemy', return_value=None),
            patch.object(AsyncTransactionWrapper, 'close_sqlalchemy', return_value=None),
        )

    async def test_transaction_wrapper_sqlalchemy_session(self):
        """get_sqlalchemy_session() hands back the session object."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        session_mock = self._new_session_mock()

        # NOTE(review): this patches the very method it then calls, so it only
        # proves the patch works; it does not exercise the real session creation.
        with patch.object(wrapper, 'get_sqlalchemy_session', return_value=session_mock):
            returned = await wrapper.get_sqlalchemy_session()

        self.assertEqual(returned, session_mock)

    async def test_transaction_wrapper_commit_sqlalchemy(self):
        """commit_sqlalchemy() commits the attached session exactly once."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        session_mock = AsyncMock()
        wrapper._sqlalchemy_session = session_mock

        await wrapper.commit_sqlalchemy()

        session_mock.commit.assert_called_once()

    async def test_transaction_wrapper_rollback_sqlalchemy(self):
        """rollback_sqlalchemy() rolls back the attached session exactly once."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        session_mock = AsyncMock()
        wrapper._sqlalchemy_session = session_mock

        await wrapper.rollback_sqlalchemy()

        session_mock.rollback.assert_called_once()

    async def test_transaction_wrapper_close_sqlalchemy(self):
        """close_sqlalchemy() closes the session, disposes the engine and clears both."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)

        session_mock = AsyncMock()
        session_mock.close = AsyncMock()
        engine_mock = Mock()
        engine_mock.dispose = AsyncMock()
        wrapper._sqlalchemy_session = session_mock
        wrapper._sqlalchemy_engine = engine_mock

        await wrapper.close_sqlalchemy()

        session_mock.close.assert_called_once()
        engine_mock.dispose.assert_called_once()
        self.assertIsNone(wrapper._sqlalchemy_session)
        self.assertIsNone(wrapper._sqlalchemy_engine)

    async def test_transaction_success_flow(self):
        """Session access, SQL, snapshot and clone calls all succeed in one transaction."""

        class StubResult:
            def __init__(self):
                self.returns_rows = False
                self.rowcount = 1

        class RecordingConnection:
            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def exec_driver_sql(self, sql):
                self.execute_called = True
                self.execute_args = (sql, None)
                return StubResult()

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                return StubResult()

        self.mock_engine.connection = RecordingConnection()
        session_mock = self._new_session_mock()

        get_p, commit_p, rollback_p, close_p = self._hook_patchers(session_mock)
        with get_p, commit_p, rollback_p, close_p:
            async with self.client.transaction() as tx:
                session = await tx.get_sqlalchemy_session()
                self.assertEqual(session, session_mock)

                inserted = await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
                self.assertIsNotNone(inserted)

                snapshot = await tx.snapshots.create("test_snapshot", SnapshotLevel.DATABASE, database="test")
                self.assertIsNotNone(snapshot)

                await tx.clone.clone_database("backup", "test")

    async def test_transaction_sqlalchemy_session_not_connected(self):
        """Without connection parameters, get_sqlalchemy_session() raises ConnectionError."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        self.client._connection_params = {}

        with self.assertRaises(ConnectionError):
            await wrapper.get_sqlalchemy_session()
1003
+
1004
+
1005
class TestAsyncSQLAlchemyIntegration(unittest.IsolatedAsyncioTestCase):
    """Mixed SQLAlchemy-session and MatrixOne operations inside one async transaction."""

    @classmethod
    def setUpClass(cls):
        """Install the SQLAlchemy module mocks once for the whole class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore the original SQLAlchemy modules when the class is finished."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire an AsyncClient to a stub engine and give it fake connection parameters."""

        class StubBeginContext:
            # Async context manager returned by ``engine.begin()``: yields the stub
            # connection and does nothing on exit (exceptions propagate).
            def __init__(self, connection):
                self.connection = connection

            async def __aenter__(self):
                return self.connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        class StubEngine:
            # ``begin()`` reads ``self.connection`` at call time, so a test can swap
            # the connection later by reassigning ``self.mock_engine.connection``.
            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                return StubBeginContext(self.connection)

        self.client = AsyncClient()
        self.mock_connection = AsyncMock()
        self.mock_engine = StubEngine(self.mock_connection)
        self.client._engine = self.mock_engine
        self.client._connection_params = dict(
            user='testuser',
            password='testpass',
            host='localhost',
            port=6001,
            db='testdb',
        )

    @staticmethod
    def _new_session_mock():
        """Return an AsyncMock session whose begin/commit/rollback/close are AsyncMocks."""
        session = AsyncMock()
        for hook in ('begin', 'commit', 'rollback', 'close'):
            setattr(session, hook, AsyncMock())
        return session

    @staticmethod
    def _hook_patchers(session):
        """Build (not yet started) patchers for the wrapper's SQLAlchemy hooks.

        ``get_sqlalchemy_session`` is made to return *session*; the commit,
        rollback and close hooks are made to return None.
        """
        return (
            patch.object(AsyncTransactionWrapper, 'get_sqlalchemy_session', return_value=session),
            patch.object(AsyncTransactionWrapper, 'commit_sqlalchemy', return_value=None),
            patch.object(AsyncTransactionWrapper, 'rollback_sqlalchemy', return_value=None),
            patch.object(AsyncTransactionWrapper, 'close_sqlalchemy', return_value=None),
        )

    async def test_mixed_operations_pattern(self):
        """Session queries, SQL reads, snapshots and clones interleave in one transaction."""

        class RowsResult:
            def __init__(self):
                self.returns_rows = True

            def fetchall(self):
                return [(1, 'Alice'), (2, 'Bob')]

            def keys(self):
                return ['id', 'name']

        class RecordingConnection:
            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                return RowsResult()

        self.mock_engine.connection = RecordingConnection()
        session_mock = self._new_session_mock()

        # A stand-in ORM model instance and the query chain that returns it.
        fake_user = Mock()
        fake_user.id = 1
        fake_user.name = "Alice"
        fake_user.email = "alice@example.com"

        query_mock = Mock()
        query_mock.all.return_value = [fake_user]
        session_mock.query = Mock(return_value=query_mock)

        get_p, commit_p, rollback_p, close_p = self._hook_patchers(session_mock)
        with get_p, commit_p, rollback_p, close_p:
            async with self.client.transaction() as tx:
                session = await tx.get_sqlalchemy_session()

                found_users = session.query(fake_user).all()
                self.assertEqual(len(found_users), 1)

                rows_result = await tx.execute("SELECT id, name FROM users")
                self.assertEqual(len(rows_result.rows), 2)

                snapshot = await tx.snapshots.create("mixed_snapshot", SnapshotLevel.DATABASE, database="test")
                self.assertIsNotNone(snapshot)

                await tx.clone.clone_database("mixed_backup", "test")

    async def test_error_handling_pattern(self):
        """A failing MatrixOne statement propagates after session work has already run."""

        class FailingConnection:
            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                raise Exception("Database error")

        self.mock_engine.connection = FailingConnection()

        session_mock = self._new_session_mock()
        session_mock.add = Mock()
        session_mock.flush = Mock()

        get_p, commit_p, rollback_p, close_p = self._hook_patchers(session_mock)
        with get_p, commit_p, rollback_p, close_p:
            with self.assertRaises(Exception):
                async with self.client.transaction() as tx:
                    session = await tx.get_sqlalchemy_session()

                    # Session-side work runs against plain Mocks and cannot fail.
                    session.add(Mock())
                    session.flush()

                    # The stub connection raises here.
                    await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
1171
+
1172
+
1173
def run_async_test(test_func):
    """Run ``test_func()`` to completion on a fresh event loop and return its result.

    Args:
        test_func: Zero-argument callable returning an awaitable, typically an
            ``async def`` function.

    Returns:
        The awaited result of ``test_func()``. Exceptions from it propagate.

    The new loop is installed as the current event loop for the duration of the
    run. Afterwards three things happen in order. Async generators started during
    the run are finalized. The current loop is reset to ``None``, so no closed
    loop is left installed. Finally the loop is closed.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(test_func())
    finally:
        try:
            # Give pending async generators a chance to run their cleanup,
            # mirroring what asyncio.run() does before closing its loop.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            # Do not leave a closed loop registered as the thread's current loop.
            asyncio.set_event_loop(None)
            loop.close()
1181
+
1182
+
1183
if __name__ == '__main__':
    # Collect every test case class of this module, in a fixed order.
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    for case_class in (
        TestAsyncResultSet,
        TestAsyncClientBasic,
        TestAsyncClientConnection,
        TestAsyncClientQuery,
        TestAsyncSnapshotManager,
        TestAsyncCloneManager,
        TestAsyncMoCtlManager,
        TestAsyncTransaction,
        TestAsyncSQLAlchemyTransaction,
        TestAsyncSQLAlchemyIntegration,
    ):
        test_suite.addTests(loader.loadTestsFromTestCase(case_class))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    # Summary banner: counts, plus a pass percentage when anything ran.
    failure_count = len(result.failures)
    error_count = len(result.errors)
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Tests run: {result.testsRun}")
    print(f"Failures: {failure_count}")
    print(f"Errors: {error_count}")
    if result.testsRun > 0:
        passed = result.testsRun - failure_count - error_count
        print(f"Success rate: {passed / result.testsRun * 100:.1f}%")
    print(banner)

    # Non-zero exit status when any test failed or errored.
    sys.exit(1 if (result.failures or result.errors) else 0)