matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,1218 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Comprehensive Async Client Query Tests
|
17
|
+
|
18
|
+
This file consolidates all async client and query-related tests from:
|
19
|
+
- test_async.py (comprehensive async client tests)
|
20
|
+
- test_async_simple.py (basic async client tests)
|
21
|
+
- test_async_sqlalchemy_transaction.py (SQLAlchemy integration tests)
|
22
|
+
|
23
|
+
The merged file eliminates redundancy while maintaining full test coverage.
|
24
|
+
"""
|
25
|
+
|
26
|
+
import unittest
|
27
|
+
import asyncio
|
28
|
+
import pytest
|
29
|
+
from unittest.mock import Mock, patch, AsyncMock
|
30
|
+
import sys
|
31
|
+
import os
|
32
|
+
from datetime import datetime
|
33
|
+
|
34
|
+
# Store original modules to restore later
_original_modules = {}

# Every module name that setup_sqlalchemy_mocks() replaces in sys.modules.
_MOCKED_MODULE_NAMES = (
    'pymysql',
    'aiomysql',
    'sqlalchemy',
    'sqlalchemy.engine',
    'sqlalchemy.orm',
    'sqlalchemy.ext',
    'sqlalchemy.ext.asyncio',
)


def setup_sqlalchemy_mocks():
    """Replace the driver and SQLAlchemy entries in ``sys.modules`` with mocks.

    The entry that was present before mocking (or ``None`` when the module was
    not loaded) is remembered in ``_original_modules`` so that
    ``teardown_sqlalchemy_mocks()`` can undo the change.

    Calling this function again before a teardown is safe: only the first
    call records the pre-mock entries, so an already-installed mock is never
    mistaken for the original module.
    """
    # setdefault keeps the first snapshot; a repeated setup must not overwrite
    # the real module with the mock installed by the previous call.
    for module_name in _MOCKED_MODULE_NAMES:
        _original_modules.setdefault(module_name, sys.modules.get(module_name))

    # Mock the external dependencies
    sys.modules['pymysql'] = Mock()
    sys.modules['aiomysql'] = Mock()

    # Create a more sophisticated SQLAlchemy mock that supports submodules
    sqlalchemy_mock = Mock()
    sqlalchemy_mock.create_engine = Mock()
    sqlalchemy_mock.text = Mock()
    sqlalchemy_mock.Column = Mock()
    sqlalchemy_mock.Integer = Mock()
    sqlalchemy_mock.String = Mock()
    sqlalchemy_mock.DateTime = Mock()

    sys.modules['sqlalchemy'] = sqlalchemy_mock

    engine_mock = Mock()
    engine_mock.Engine = Mock()
    sys.modules['sqlalchemy.engine'] = engine_mock

    orm_mock = Mock()
    orm_mock.sessionmaker = Mock()
    orm_mock.declarative_base = Mock()
    sys.modules['sqlalchemy.orm'] = orm_mock

    # Mock SQLAlchemy async engine
    sys.modules['sqlalchemy.ext'] = Mock()
    asyncio_mock = Mock()
    asyncio_mock.create_async_engine = Mock()
    asyncio_mock.AsyncEngine = Mock()
    asyncio_mock.AsyncSession = Mock()
    asyncio_mock.async_sessionmaker = Mock()
    sys.modules['sqlalchemy.ext.asyncio'] = asyncio_mock


def teardown_sqlalchemy_mocks():
    """Restore the ``sys.modules`` entries saved by ``setup_sqlalchemy_mocks()``.

    Modules that were loaded before mocking are put back; modules that were
    absent are removed again. The saved state is cleared afterwards so the
    next setup takes a fresh snapshot.
    """
    for module_name, original_module in _original_modules.items():
        if original_module is not None:
            sys.modules[module_name] = original_module
        else:
            # The module was not loaded before mocking: drop the mock entry.
            sys.modules.pop(module_name, None)
    _original_modules.clear()
|
86
|
+
|
87
|
+
|
88
|
+
# Add the matrixone package to the path
|
89
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
90
|
+
|
91
|
+
from matrixone.async_client import (
|
92
|
+
AsyncClient,
|
93
|
+
AsyncResultSet,
|
94
|
+
AsyncSnapshotManager,
|
95
|
+
AsyncCloneManager,
|
96
|
+
AsyncMoCtlManager,
|
97
|
+
AsyncTransactionWrapper,
|
98
|
+
)
|
99
|
+
from matrixone.snapshot import SnapshotLevel, Snapshot
|
100
|
+
from matrixone.exceptions import MoCtlError, ConnectionError
|
101
|
+
|
102
|
+
|
103
|
+
class TestAsyncResultSet(unittest.TestCase):
    """Checks for the AsyncResultSet container."""

    @classmethod
    def setUpClass(cls):
        """Install the module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo the module mocks installed in setUpClass."""
        teardown_sqlalchemy_mocks()

    def test_async_result_set(self):
        """Columns, rows, length, fetchone, scalar and iteration of a two-row result."""
        column_names = ['id', 'name', 'email']
        data = [(1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')]

        result_set = AsyncResultSet(column_names, data)

        self.assertEqual(result_set.columns, column_names)
        self.assertEqual(result_set.rows, data)
        self.assertEqual(len(result_set), 2)
        # fetchone() is checked before scalar(), in the same order as always.
        self.assertEqual(result_set.fetchone(), (1, 'Alice', 'alice@example.com'))
        self.assertEqual(result_set.scalar(), 1)

        # Iterating the result set must yield both rows.
        iterated_rows = list(result_set)
        self.assertEqual(len(iterated_rows), 2)
|
132
|
+
|
133
|
+
|
134
|
+
class TestAsyncClientBasic(unittest.IsolatedAsyncioTestCase):
    """Construction defaults, manager properties, imports and context-manager use of AsyncClient."""

    @classmethod
    def setUpClass(cls):
        """Install the module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo the module mocks installed in setUpClass."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Create a fresh, unconnected client for every test."""
        self.client = AsyncClient()

    def test_init(self):
        """A default-constructed client exposes the expected settings and managers."""
        expected_settings = {
            'connection_timeout': 30,
            'query_timeout': 300,
            'auto_commit': True,
            'charset': 'utf8mb4',
        }
        for attribute, expected in expected_settings.items():
            self.assertEqual(getattr(self.client, attribute), expected)

        # Nothing is connected yet.
        self.assertIsNone(self.client._engine)
        self.assertIsNone(self.client._connection)

        self.assertIsInstance(self.client._snapshots, AsyncSnapshotManager)
        self.assertIsInstance(self.client._clone, AsyncCloneManager)
        self.assertIsInstance(self.client._moctl, AsyncMoCtlManager)

    def test_managers_properties(self):
        """The public manager properties return the matching manager types."""
        self.assertIsInstance(self.client.snapshots, AsyncSnapshotManager)
        self.assertIsInstance(self.client.clone, AsyncCloneManager)
        self.assertIsInstance(self.client.moctl, AsyncMoCtlManager)

    def test_imports(self):
        """The async names are importable from the package and its submodules."""
        try:
            from matrixone import AsyncClient, AsyncResultSet
            from matrixone.async_client import (
                AsyncSnapshotManager,
                AsyncCloneManager,
                AsyncMoCtlManager,
            )
            from matrixone.snapshot import SnapshotLevel
        except ImportError as e:
            self.fail(f"Import error: {e}")
        else:
            # The enum must be importable and iterable.
            self.assertIsNotNone(list(SnapshotLevel))

    async def test_context_manager(self):
        """``async with client`` yields the client and disconnects on exit."""
        with patch.object(self.client, 'connect'):
            with patch.object(self.client, 'disconnect') as disconnect_stub:
                with patch.object(self.client, 'connected') as connected_stub:
                    # Report "connected" so that leaving the context triggers disconnect.
                    connected_stub.return_value = True

                    async with self.client as entered:
                        self.assertEqual(entered, self.client)

                    disconnect_stub.assert_called_once()
|
208
|
+
|
209
|
+
|
210
|
+
class TestAsyncClientConnection(unittest.IsolatedAsyncioTestCase):
    """connect() / disconnect() behaviour of AsyncClient against a mocked engine."""

    @classmethod
    def setUpClass(cls):
        """Install the module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo the module mocks installed in setUpClass."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Create a fresh, unconnected client for every test."""
        self.client = AsyncClient()

    @patch('matrixone.async_client.create_async_engine')
    async def test_connect(self, mock_create_async_engine):
        """connect() builds one engine and records the connection parameters."""
        execute_result = AsyncMock()
        execute_result.returns_rows = False

        # Awaitable stand-in for connection.execute.
        async def fake_execute(sql, params=None):
            return execute_result

        fake_connection = AsyncMock()
        fake_connection.execute = fake_execute

        class BeginScope:
            """Async context manager handing out the fake connection."""

            def __init__(self, connection):
                self.connection = connection

            async def __aenter__(self):
                return self.connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        class EngineDouble:
            """Engine stand-in whose begin() returns a BeginScope."""

            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                return BeginScope(self.connection)

        engine_double = EngineDouble(fake_connection)
        mock_create_async_engine.return_value = engine_double

        connect_kwargs = {
            'host': "localhost",
            'port': 6001,
            'user': "root",
            'password': "111",
            'database': "test",
        }
        await self.client.connect(**connect_kwargs)

        mock_create_async_engine.assert_called_once()
        self.assertEqual(self.client._engine, engine_double)
        # Each argument passed to connect() must be stored unchanged.
        for key in ('host', 'port', 'user', 'password', 'database'):
            self.assertEqual(self.client._connection_params[key], connect_kwargs[key])

    @patch('matrixone.async_client.create_async_engine')
    async def test_connect_failure(self, mock_create_async_engine):
        """An error raised while creating the engine surfaces from connect()."""
        mock_create_async_engine.side_effect = Exception("Connection failed")

        with self.assertRaises(Exception):
            await self.client.connect(
                host="localhost",
                port=6001,
                user="root",
                password="111",
                database="test",
            )

    async def test_disconnect(self):
        """disconnect() disposes the engine and clears the reference to it."""
        engine_double = AsyncMock()
        engine_double.dispose = AsyncMock()
        self.client._engine = engine_double

        await self.client.disconnect()

        engine_double.dispose.assert_called_once()
        self.assertIsNone(self.client._engine)
|
296
|
+
|
297
|
+
|
298
|
+
class _MockBeginContext:
    """Async context manager standing in for the value returned by ``engine.begin()``."""

    def __init__(self, connection):
        self.connection = connection

    async def __aenter__(self):
        return self.connection

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass


class _MockEngine:
    """Engine stand-in whose ``begin()`` hands out the supplied connection."""

    def __init__(self, connection):
        self.connection = connection

    def begin(self):
        return _MockBeginContext(self.connection)


class _RecordingConnection:
    """Connection stand-in that records the ``execute()`` call and returns a fixed result."""

    def __init__(self, result):
        self.result = result
        self.execute_called = False
        self.execute_args = None

    async def execute(self, sql, params=None):
        self.execute_called = True
        self.execute_args = (sql, params)
        return self.result


class TestAsyncClientQuery(unittest.IsolatedAsyncioTestCase):
    """Test AsyncClient query functionality.

    The engine/connection doubles shared by these tests are the private
    ``_MockEngine``, ``_MockBeginContext`` and ``_RecordingConnection``
    classes defined directly above this class.
    """

    @classmethod
    def setUpClass(cls):
        """Setup mocks for the entire test class"""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore original modules after tests"""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Set up test fixtures"""
        self.client = AsyncClient()

    async def test_execute_success(self):
        """execute() wraps a row-returning result in an AsyncResultSet."""

        # Plain class instead of Mock so only these attributes exist on the result.
        class MockResult:
            def __init__(self):
                self.returns_rows = True

            def fetchall(self):
                return [(1, 'Alice'), (2, 'Bob')]

            def keys(self):
                return ['id', 'name']

        mock_connection = _RecordingConnection(MockResult())
        self.client._engine = _MockEngine(mock_connection)

        result = await self.client.execute("SELECT id, name FROM users")

        self.assertTrue(mock_connection.execute_called)
        self.assertIsInstance(result, AsyncResultSet)
        self.assertEqual(result.columns, ['id', 'name'])
        self.assertEqual(result.rows, [(1, 'Alice'), (2, 'Bob')])

    async def test_execute_with_params(self):
        """execute() with parameters reports the driver's rowcount as affected_rows."""

        # Plain class instead of Mock so only these attributes exist on the result.
        class MockResult:
            def __init__(self):
                self.returns_rows = False
                self.rowcount = 1

        mock_connection = _RecordingConnection(MockResult())
        self.client._engine = _MockEngine(mock_connection)

        result = await self.client.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))

        self.assertTrue(mock_connection.execute_called)
        self.assertEqual(result.affected_rows, 1)

    async def test_execute_not_connected(self):
        """execute() raises when no engine has been set on the client."""
        with self.assertRaises(Exception):
            await self.client.execute("SELECT 1")

    async def test_snapshot_query(self):
        """A snapshot query built through query(...).select(...).execute() returns an AsyncResultSet."""
        mock_result = Mock()
        mock_result.returns_rows = True
        mock_result.fetchall.return_value = [(1, 'Alice')]
        mock_result.keys.return_value = ['id', 'name']

        # Both entry points return coroutines, so either code path can be awaited.
        async def mock_exec_driver_sql(sql):
            return mock_result

        async def mock_execute(sql, params=None):
            return mock_result

        mock_connection = Mock()
        mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql)
        mock_connection.execute = Mock(side_effect=mock_execute)

        # Mock the text function
        sys.modules['sqlalchemy'].text = Mock(return_value="SELECT id, name FROM users {SNAPSHOT = 'test_snapshot'}")

        self.client._engine = _MockEngine(mock_connection)

        result = await self.client.query("users", snapshot="test_snapshot").select("id", "name").execute()

        # Either exec_driver_sql or the execute fallback must have been used.
        # assertTrue (not a bare assert) so the check survives ``python -O``.
        self.assertTrue(mock_connection.exec_driver_sql.called or mock_connection.execute.called)
        self.assertIsInstance(result, AsyncResultSet)
|
486
|
+
|
487
|
+
|
488
|
+
class TestAsyncSnapshotManager(unittest.IsolatedAsyncioTestCase):
    """AsyncSnapshotManager driven through a mocked client."""

    # Column layout used by the snapshot listing results in these tests.
    _SNAPSHOT_COLUMNS = ('sname', 'ts', 'level', 'account_name', 'database_name', 'table_name')

    @classmethod
    def setUpClass(cls):
        """Install the module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo the module mocks installed in setUpClass."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Build a manager around an AsyncMock client."""
        self.mock_client = AsyncMock()
        self.snapshot_manager = AsyncSnapshotManager(self.mock_client)

    async def test_create_cluster_snapshot(self):
        """create() at CLUSTER level issues the expected SQL and returns the snapshot."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        created = await self.snapshot_manager.create("test_snap", SnapshotLevel.CLUSTER)

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR CLUSTER")
        self.assertEqual(created.name, "test_snap")
        self.assertEqual(created.level, SnapshotLevel.CLUSTER)

    async def test_create_database_snapshot(self):
        """create() at DATABASE level issues the expected SQL and keeps the database name."""
        # The canned result also carries the database name.
        canned = AsyncResultSet([], [])
        canned.database = "test_db"
        self.mock_client.execute.return_value = canned

        created = await self.snapshot_manager.create("test_snap", SnapshotLevel.DATABASE, database="test_db")

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR DATABASE test_db")
        self.assertEqual(created.name, "test_snap")
        self.assertEqual(created.level, SnapshotLevel.DATABASE)
        self.assertEqual(created.database, "test_db")

    async def test_create_table_snapshot(self):
        """create() at TABLE level issues the expected SQL and keeps database and table names."""
        # The canned result also carries the database and table names.
        canned = AsyncResultSet([], [])
        canned.database = "test_db"
        canned.table = "test_table"
        self.mock_client.execute.return_value = canned

        created = await self.snapshot_manager.create(
            "test_snap",
            SnapshotLevel.TABLE,
            database="test_db",
            table="test_table",
        )

        self.mock_client.execute.assert_called_once_with("CREATE SNAPSHOT test_snap FOR TABLE test_db test_table")
        self.assertEqual(created.name, "test_snap")
        self.assertEqual(created.level, SnapshotLevel.TABLE)
        self.assertEqual(created.database, "test_db")
        self.assertEqual(created.table, "test_table")

    async def test_get_snapshot(self):
        """get() turns a single result row into a snapshot object."""
        # The timestamp is supplied as an integer rather than a datetime object.
        ts_ns = int(datetime.now().timestamp() * 1_000_000_000)
        row = ('test_snap', ts_ns, 'database', 'sys', 'test_db', None)
        self.mock_client.execute.return_value = AsyncResultSet(list(self._SNAPSHOT_COLUMNS), [row])

        fetched = await self.snapshot_manager.get("test_snap")

        self.assertEqual(fetched.name, "test_snap")
        self.assertEqual(fetched.level, SnapshotLevel.DATABASE)
        self.assertEqual(fetched.database, "test_db")

    async def test_list_snapshots(self):
        """list() returns one snapshot per result row, in row order."""
        # The timestamps are supplied as integers rather than datetime objects.
        ts_ns = int(datetime.now().timestamp() * 1_000_000_000)
        listing_rows = [
            ('snap1', ts_ns, 'database', 'sys', 'db1', None),
            ('snap2', ts_ns, 'table', 'sys', 'db2', 'table2'),
        ]
        self.mock_client.execute.return_value = AsyncResultSet(list(self._SNAPSHOT_COLUMNS), listing_rows)

        listed = await self.snapshot_manager.list()

        self.assertEqual(len(listed), 2)
        self.assertEqual([snap.name for snap in listed], ["snap1", "snap2"])

    async def test_delete_snapshot(self):
        """delete() issues a DROP SNAPSHOT statement for the given name."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.snapshot_manager.delete("test_snap")

        self.mock_client.execute.assert_called_once_with("DROP SNAPSHOT test_snap")

    async def test_exists_snapshot(self):
        """exists() is True when a row comes back and False when SnapshotError is raised."""
        from matrixone.exceptions import SnapshotError

        # Present: the lookup returns one matching row.
        ts_ns = int(datetime.now().timestamp() * 1_000_000_000)
        row = ('test_snap', ts_ns, 'database', 'sys', 'test_db', None)
        self.mock_client.execute.return_value = AsyncResultSet(list(self._SNAPSHOT_COLUMNS), [row])

        found = await self.snapshot_manager.exists("test_snap")
        self.assertTrue(found)

        # Absent: the lookup raises SnapshotError.
        self.mock_client.execute.side_effect = SnapshotError("Snapshot 'nonexistent' not found")

        found = await self.snapshot_manager.exists("nonexistent")
        self.assertFalse(found)
|
611
|
+
|
612
|
+
|
613
|
+
class TestAsyncCloneManager(unittest.IsolatedAsyncioTestCase):
    """SQL statements issued by AsyncCloneManager through a mocked client."""

    @classmethod
    def setUpClass(cls):
        """Install the module mocks once for this class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Undo the module mocks installed in setUpClass."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Build a manager around an AsyncMock client."""
        self.mock_client = AsyncMock()
        self.clone_manager = AsyncCloneManager(self.mock_client)

    async def test_clone_database(self):
        """clone_database() issues a plain CREATE DATABASE ... CLONE statement."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.clone_manager.clone_database("target_db", "source_db")

        expected_sql = "CREATE DATABASE target_db CLONE source_db"
        self.mock_client.execute.assert_called_once_with(expected_sql)

    async def test_clone_database_with_snapshot(self):
        """clone_database_with_snapshot() appends the FOR SNAPSHOT clause."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.clone_manager.clone_database_with_snapshot("target_db", "source_db", "test_snapshot")

        expected_sql = "CREATE DATABASE target_db CLONE source_db FOR SNAPSHOT 'test_snapshot'"
        self.mock_client.execute.assert_called_once_with(expected_sql)

    async def test_clone_database_if_not_exists(self):
        """clone_database(if_not_exists=True) adds IF NOT EXISTS to the statement."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.clone_manager.clone_database("target_db", "source_db", if_not_exists=True)

        expected_sql = "CREATE DATABASE target_db IF NOT EXISTS CLONE source_db"
        self.mock_client.execute.assert_called_once_with(expected_sql)

    async def test_clone_table(self):
        """clone_table() issues a plain CREATE TABLE ... CLONE statement."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.clone_manager.clone_table("target_table", "source_table")

        expected_sql = "CREATE TABLE target_table CLONE source_table"
        self.mock_client.execute.assert_called_once_with(expected_sql)

    async def test_clone_table_with_snapshot(self):
        """clone_table_with_snapshot() appends the FOR SNAPSHOT clause."""
        self.mock_client.execute.return_value = AsyncResultSet([], [])

        await self.clone_manager.clone_table_with_snapshot("target_table", "source_table", "test_snapshot")

        expected_sql = "CREATE TABLE target_table CLONE source_table FOR SNAPSHOT 'test_snapshot'"
        self.mock_client.execute.assert_called_once_with(expected_sql)
|
674
|
+
|
675
|
+
|
676
|
+
class TestAsyncMoCtlManager(unittest.IsolatedAsyncioTestCase):
    """Check the mo_ctl SQL issued by AsyncMoCtlManager and how replies are decoded."""

    @classmethod
    def setUpClass(cls):
        """Call setup_sqlalchemy_mocks() once before any test in the class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Call teardown_sqlalchemy_mocks() once after all tests in the class."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Create a fresh AsyncMock client and a mo_ctl manager wrapping it."""
        self.mock_client = AsyncMock()
        self.moctl_manager = AsyncMoCtlManager(self.mock_client)

    def _stub_reply(self, raw_json):
        """Make execute() resolve to a one-row, one-column result holding ``raw_json``."""
        self.mock_client.execute.return_value = AsyncResultSet(['result'], [(raw_json,)])

    def _assert_ok_reply(self, reply, method, sql):
        """Assert the single SQL statement sent and the decoded OK reply."""
        self.mock_client.execute.assert_called_once_with(sql)
        self.assertEqual(reply['method'], method)
        self.assertEqual(reply['result'][0]['returnStr'], 'OK')

    async def test_flush_table(self):
        """flush_table sends the 'flush' command for ``db.table``."""
        self._stub_reply('{"method": "Flush", "result": [{"returnStr": "OK"}]}')

        reply = await self.moctl_manager.flush_table('db1', 'users')

        self._assert_ok_reply(reply, 'Flush', "SELECT mo_ctl('dn', 'flush', 'db1.users')")

    async def test_increment_checkpoint(self):
        """increment_checkpoint sends the 'checkpoint' command."""
        self._stub_reply('{"method": "Checkpoint", "result": [{"returnStr": "OK"}]}')

        reply = await self.moctl_manager.increment_checkpoint()

        self._assert_ok_reply(reply, 'Checkpoint', "SELECT mo_ctl('dn', 'checkpoint', '')")

    async def test_global_checkpoint(self):
        """global_checkpoint sends the 'globalcheckpoint' command."""
        self._stub_reply('{"method": "GlobalCheckpoint", "result": [{"returnStr": "OK"}]}')

        reply = await self.moctl_manager.global_checkpoint()

        self._assert_ok_reply(reply, 'GlobalCheckpoint', "SELECT mo_ctl('dn', 'globalcheckpoint', '')")

    async def test_moctl_error_handling(self):
        """A reply whose returnStr carries an error must surface as MoCtlError."""
        self._stub_reply('{"method": "Flush", "result": [{"returnStr": "ERROR: Table not found"}]}')

        with self.assertRaises(MoCtlError) as raised:
            await self.moctl_manager.flush_table('db1', 'nonexistent')

        self.assertIn("ERROR: Table not found", str(raised.exception))
|
742
|
+
|
743
|
+
|
744
|
+
class TestAsyncTransaction(unittest.IsolatedAsyncioTestCase):
    """Test async transaction functionality against a mocked engine."""

    @classmethod
    def setUpClass(cls):
        """Setup mocks for the entire test class"""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Restore original modules after tests"""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Create an AsyncClient whose ``_engine`` is a stand-in engine.

        The stand-in's ``begin()`` returns an async context manager that
        yields ``self.mock_connection`` and does nothing on exit.
        """
        self.client = AsyncClient()
        self.mock_connection = AsyncMock()

        # Create a mock engine class that properly implements begin()
        class MockEngine:
            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                # Create a proper async context manager for engine.begin()
                class MockBeginContext:
                    def __init__(self, connection):
                        self.connection = connection

                    async def __aenter__(self):
                        return self.connection

                    async def __aexit__(self, exc_type, exc_val, exc_tb):
                        # Returning None means exceptions are never suppressed.
                        pass

                return MockBeginContext(self.connection)

        self.mock_engine = MockEngine(self.mock_connection)
        self.client._engine = self.mock_engine

    async def test_transaction_success(self):
        """Test successful async transaction"""
        mock_result = AsyncMock()
        mock_result.returns_rows = False
        mock_result.rowcount = 1

        # Mock the connection.exec_driver_sql method - make it async
        async def mock_exec_driver_sql(sql):
            return mock_result

        self.mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql)

        # Also setup execute for fallback
        async def mock_execute(sql, params=None):
            return mock_result

        self.mock_connection.execute = Mock(side_effect=mock_execute)

        async with self.client.transaction() as tx:
            await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))

        # Verify that the transaction was used. A unittest assertion is used
        # instead of a bare ``assert`` so the check survives ``python -O``
        # and reports through the normal unittest failure path.
        self.assertTrue(
            self.mock_connection.exec_driver_sql.called or self.mock_connection.execute.called,
            "transaction did not issue any statement on the connection",
        )

    async def test_transaction_rollback(self):
        """Test that a failing statement escapes the transaction block.

        NOTE(review): only the propagation of an exception is asserted here;
        no rollback call is checked.
        """

        # Mock the connection.exec_driver_sql method to raise an exception
        async def mock_exec_driver_sql_error(sql):
            raise Exception("Query failed")

        self.mock_connection.exec_driver_sql = Mock(side_effect=mock_exec_driver_sql_error)

        # Also setup execute for fallback
        async def mock_execute_error(sql, params=None):
            raise Exception("Query failed")

        self.mock_connection.execute = mock_execute_error

        with self.assertRaises(Exception):
            async with self.client.transaction() as tx:
                await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
|
826
|
+
|
827
|
+
|
828
|
+
class TestAsyncSQLAlchemyTransaction(unittest.IsolatedAsyncioTestCase):
    """Exercise the SQLAlchemy-session helpers of AsyncTransactionWrapper."""

    @classmethod
    def setUpClass(cls):
        """Call setup_sqlalchemy_mocks() once before any test in the class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Call teardown_sqlalchemy_mocks() once after all tests in the class."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire an AsyncClient to a stand-in engine and fake connection params."""
        self.client = AsyncClient()
        self.mock_connection = AsyncMock()

        class MockBeginContext:
            """Async context manager that yields a fixed connection."""

            def __init__(self, connection):
                self.connection = connection

            async def __aenter__(self):
                return self.connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return None

        class MockEngine:
            """Engine stand-in; begin() reads ``connection`` at call time."""

            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                return MockBeginContext(self.connection)

        self.mock_engine = MockEngine(self.mock_connection)
        self.client._engine = self.mock_engine

        # Connection parameters for the SQLAlchemy integration.
        self.client._connection_params = dict(
            user='testuser',
            password='testpass',
            host='localhost',
            port=6001,
            db='testdb',
        )

    @staticmethod
    def _new_session_mock():
        """Return an AsyncMock session with explicit AsyncMock lifecycle methods."""
        session = AsyncMock()
        for hook in ('begin', 'commit', 'rollback', 'close'):
            setattr(session, hook, AsyncMock())
        return session

    async def test_transaction_wrapper_sqlalchemy_session(self):
        """The patched get_sqlalchemy_session() hands back the mocked session."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        fake_session = self._new_session_mock()

        # The method itself is patched on the wrapper instance.
        with patch.object(wrapper, 'get_sqlalchemy_session', return_value=fake_session):
            obtained = await wrapper.get_sqlalchemy_session()
            self.assertEqual(obtained, fake_session)

    async def test_transaction_wrapper_commit_sqlalchemy(self):
        """commit_sqlalchemy() commits the attached session exactly once."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        fake_session = AsyncMock()
        wrapper._sqlalchemy_session = fake_session

        await wrapper.commit_sqlalchemy()

        fake_session.commit.assert_called_once()

    async def test_transaction_wrapper_rollback_sqlalchemy(self):
        """rollback_sqlalchemy() rolls the attached session back exactly once."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        fake_session = AsyncMock()
        wrapper._sqlalchemy_session = fake_session

        await wrapper.rollback_sqlalchemy()

        fake_session.rollback.assert_called_once()

    async def test_transaction_wrapper_close_sqlalchemy(self):
        """close_sqlalchemy() closes the session, disposes the engine, clears both."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)

        fake_session = AsyncMock()
        fake_session.close = AsyncMock()
        fake_engine = Mock()
        fake_engine.dispose = AsyncMock()
        wrapper._sqlalchemy_session = fake_session
        wrapper._sqlalchemy_engine = fake_engine

        await wrapper.close_sqlalchemy()

        fake_session.close.assert_called_once()
        fake_engine.dispose.assert_called_once()
        self.assertIsNone(wrapper._sqlalchemy_session)
        self.assertIsNone(wrapper._sqlalchemy_engine)

    async def test_transaction_success_flow(self):
        """Run session, execute, snapshot and clone calls inside one transaction."""

        class MockResult:
            """Result stand-in reporting one affected row and no result rows."""

            def __init__(self):
                self.returns_rows = False
                self.rowcount = 1

        class MockConnection:
            """Connection stand-in recording the last statement it was given."""

            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def exec_driver_sql(self, sql):
                self.execute_called = True
                self.execute_args = (sql, None)
                return MockResult()

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                return MockResult()

        # Swap the AsyncMock connection for the hand-written one.
        self.mock_engine.connection = MockConnection()

        fake_session = self._new_session_mock()

        # Replace the wrapper's SQLAlchemy methods for the duration of the test.
        with patch.object(AsyncTransactionWrapper, 'get_sqlalchemy_session', return_value=fake_session), \
                patch.object(AsyncTransactionWrapper, 'commit_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'rollback_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'close_sqlalchemy', return_value=None):

            async with self.client.transaction() as tx:
                # SQLAlchemy session
                obtained = await tx.get_sqlalchemy_session()
                self.assertEqual(obtained, fake_session)

                # MatrixOne async operations
                outcome = await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
                self.assertIsNotNone(outcome)

                # Snapshot operations
                created = await tx.snapshots.create("test_snapshot", SnapshotLevel.DATABASE, database="test")
                self.assertIsNotNone(created)

                # Clone operations
                await tx.clone.clone_database("backup", "test")

    async def test_transaction_sqlalchemy_session_not_connected(self):
        """With empty connection params, get_sqlalchemy_session() raises ConnectionError."""
        wrapper = AsyncTransactionWrapper(self.mock_connection, self.client)
        self.client._connection_params = {}

        with self.assertRaises(ConnectionError):
            await wrapper.get_sqlalchemy_session()
|
1003
|
+
|
1004
|
+
|
1005
|
+
class TestAsyncSQLAlchemyIntegration(unittest.IsolatedAsyncioTestCase):
    """Exercise mixed SQLAlchemy-session and MatrixOne calls inside a transaction."""

    @classmethod
    def setUpClass(cls):
        """Call setup_sqlalchemy_mocks() once before any test in the class."""
        setup_sqlalchemy_mocks()

    @classmethod
    def tearDownClass(cls):
        """Call teardown_sqlalchemy_mocks() once after all tests in the class."""
        teardown_sqlalchemy_mocks()

    def setUp(self):
        """Wire an AsyncClient to a stand-in engine and fake connection params."""
        self.client = AsyncClient()
        self.mock_connection = AsyncMock()

        class MockBeginContext:
            """Async context manager that yields a fixed connection."""

            def __init__(self, connection):
                self.connection = connection

            async def __aenter__(self):
                return self.connection

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return None

        class MockEngine:
            """Engine stand-in; begin() reads ``connection`` at call time."""

            def __init__(self, connection):
                self.connection = connection

            def begin(self):
                return MockBeginContext(self.connection)

        self.mock_engine = MockEngine(self.mock_connection)
        self.client._engine = self.mock_engine

        # Connection parameters for the SQLAlchemy integration.
        self.client._connection_params = dict(
            user='testuser',
            password='testpass',
            host='localhost',
            port=6001,
            db='testdb',
        )

    @staticmethod
    def _new_session_mock():
        """Return an AsyncMock session with explicit AsyncMock lifecycle methods."""
        session = AsyncMock()
        for hook in ('begin', 'commit', 'rollback', 'close'):
            setattr(session, hook, AsyncMock())
        return session

    async def test_mixed_operations_pattern(self):
        """Combine a session query, a raw SELECT, a snapshot and a clone."""

        class MockResult:
            """Result stand-in exposing two rows under the columns id and name."""

            def __init__(self):
                self.returns_rows = True

            def fetchall(self):
                return [(1, 'Alice'), (2, 'Bob')]

            def keys(self):
                return ['id', 'name']

        class MockConnection:
            """Connection stand-in recording the last statement it was given."""

            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                return MockResult()

        # Swap the AsyncMock connection for the hand-written one.
        self.mock_engine.connection = MockConnection()

        fake_session = self._new_session_mock()

        # Stand-in for a SQLAlchemy model instance.
        mock_user = Mock()
        mock_user.id = 1
        mock_user.name = "Alice"
        mock_user.email = "alice@example.com"

        # session.query(...) is synchronous here and yields a single user.
        fake_query = Mock()
        fake_query.all.return_value = [mock_user]
        fake_session.query = Mock(return_value=fake_query)

        # Replace the wrapper's SQLAlchemy methods for the duration of the test.
        with patch.object(AsyncTransactionWrapper, 'get_sqlalchemy_session', return_value=fake_session), \
                patch.object(AsyncTransactionWrapper, 'commit_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'rollback_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'close_sqlalchemy', return_value=None):

            async with self.client.transaction() as tx:
                session = await tx.get_sqlalchemy_session()

                # SQLAlchemy operations
                found = session.query(mock_user).all()
                self.assertEqual(len(found), 1)

                # MatrixOne async operations
                selected = await tx.execute("SELECT id, name FROM users")
                self.assertEqual(len(selected.rows), 2)

                # Snapshot operations
                created = await tx.snapshots.create("mixed_snapshot", SnapshotLevel.DATABASE, database="test")
                self.assertIsNotNone(created)

                # Clone operations
                await tx.clone.clone_database("mixed_backup", "test")

    async def test_error_handling_pattern(self):
        """A failing MatrixOne statement must raise out of the transaction block."""

        class MockConnection:
            """Connection stand-in whose execute() always raises."""

            def __init__(self):
                self.execute_called = False
                self.execute_args = None

            async def execute(self, sql, params=None):
                self.execute_called = True
                self.execute_args = (sql, params)
                raise Exception("Database error")

        # Swap the AsyncMock connection for the failing one.
        self.mock_engine.connection = MockConnection()

        fake_session = self._new_session_mock()
        fake_session.add = Mock()
        fake_session.flush = Mock()

        # Replace the wrapper's SQLAlchemy methods for the duration of the test.
        with patch.object(AsyncTransactionWrapper, 'get_sqlalchemy_session', return_value=fake_session), \
                patch.object(AsyncTransactionWrapper, 'commit_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'rollback_sqlalchemy', return_value=None), \
                patch.object(AsyncTransactionWrapper, 'close_sqlalchemy', return_value=None):

            with self.assertRaises(Exception):
                async with self.client.transaction() as tx:
                    session = await tx.get_sqlalchemy_session()

                    # SQLAlchemy operation (should succeed)
                    pending_user = Mock()
                    session.add(pending_user)
                    session.flush()

                    # MatrixOne operation (should fail)
                    await tx.execute("INSERT INTO users (name) VALUES (%s)", ("Alice",))
|
1171
|
+
|
1172
|
+
|
1173
|
+
def run_async_test(test_func):
    """Run an async test callable to completion on a private event loop.

    Args:
        test_func: Zero-argument callable returning an awaitable (typically
            a coroutine function).

    Returns:
        Whatever the awaitable resolves to.

    Raises:
        Any exception raised by the awaitable is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(test_func())
    finally:
        try:
            loop.close()
        finally:
            # Do not leave the closed loop registered as the current loop;
            # otherwise a later asyncio.get_event_loop() would return a dead loop.
            asyncio.set_event_loop(None)
|
1181
|
+
|
1182
|
+
|
1183
|
+
if __name__ == '__main__':
    # Test case classes to run, in execution order.
    case_classes = (
        TestAsyncResultSet,
        TestAsyncClientBasic,
        TestAsyncClientConnection,
        TestAsyncClientQuery,
        TestAsyncSnapshotManager,
        TestAsyncCloneManager,
        TestAsyncMoCtlManager,
        TestAsyncTransaction,
        TestAsyncSQLAlchemyTransaction,
        TestAsyncSQLAlchemyIntegration,
    )

    # Assemble one suite from every listed class.
    test_suite = unittest.TestSuite()
    loader = unittest.TestLoader()
    for case_class in case_classes:
        test_suite.addTests(loader.loadTestsFromTestCase(case_class))

    # Run the suite with verbose per-test output.
    result = unittest.TextTestRunner(verbosity=2).run(test_suite)

    # Print a summary framed by two separator lines.
    separator = '=' * 50
    failed = len(result.failures)
    errored = len(result.errors)
    print(f"\n{separator}")
    print(f"Tests run: {result.testsRun}")
    print(f"Failures: {failed}")
    print(f"Errors: {errored}")
    if result.testsRun > 0:
        success_rate = (result.testsRun - failed - errored) / result.testsRun * 100
        print(f"Success rate: {success_rate:.1f}%")
    print(f"{separator}")

    # Exit status 1 on any failure or error, 0 otherwise.
    sys.exit(1 if result.failures or result.errors else 0)
|