matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Basic test for MatrixOne Python SDK
|
17
|
+
"""
|
18
|
+
|
19
|
+
from matrixone import Client, ConnectionError, QueryError
|
20
|
+
|
21
|
+
|
22
|
+
def test_basic_connection():
    """Smoke-test a live MatrixOne connection, a plain query, the engine and a transaction.

    Needs a server on ``localhost:6001`` (user ``root``, password ``111``,
    database ``test``). If connecting fails, the test is *skipped* instead of
    silently reported as passed. Any failure after a successful connection
    propagates, so pytest reports it. The client is always disconnected.
    """
    import pytest  # local import: nothing else in this module needs pytest

    client = Client()

    try:
        try:
            # Connect to MatrixOne. An unreachable server is an environment
            # problem rather than a test failure, so skip in that case.
            client.connect(host="localhost", port=6001, user="root", password="111", database="test")
        except Exception as e:
            pytest.skip(f"MatrixOne server not available: {e}")

        print("✓ Connected to MatrixOne successfully")

        # Test basic query
        result = client.execute("SELECT 1 as test_value")
        row = result.fetchone()
        assert row is not None
        print(f"✓ Query executed successfully: {row}")

        # Test SQLAlchemy engine
        engine = client.get_sqlalchemy_engine()
        assert engine is not None
        print("✓ SQLAlchemy engine created successfully")

        # Test transaction
        with client.transaction() as tx:
            result = tx.execute("SELECT 2 as transaction_value")
            row = result.fetchone()
            assert row is not None
            print(f"✓ Transaction executed successfully: {row}")

        print("✓ All basic tests passed!")
    finally:
        # Runs on success, on failure and on skip (pytest.skip raises).
        client.disconnect()
|
51
|
+
|
52
|
+
|
53
|
+
if __name__ == "__main__":
    # Allow running this smoke test directly as a script, outside pytest.
    test_basic_connection()
|
@@ -0,0 +1,227 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Tests for case sensitivity handling in MatrixOne SQLAlchemy extensions.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import pytest
|
20
|
+
import sys
|
21
|
+
from unittest.mock import Mock
|
22
|
+
from matrixone.sqlalchemy_ext import MatrixOneDialect, VectorType, Vectorf32, Vectorf64
|
23
|
+
|
24
|
+
# Apply the ``vector`` marker to every test in this module (select/deselect with ``-m``).
pytestmark = pytest.mark.vector
|
27
|
+
|
28
|
+
|
29
|
+
class TestCaseSensitivity:
    """Test case-insensitive handling of MatrixOne vector type strings."""

    def test_vector_type_case_insensitive_parsing(self):
        """Vector type strings with a dimension parse regardless of letter case."""
        dialect = MatrixOneDialect()

        # (type string, expected precision, expected dimension)
        test_cases = [
            ("vecf32(128)", "f32", 128),
            ("VECF32(128)", "f32", 128),
            ("VecF32(128)", "f32", 128),
            ("VECf32(128)", "f32", 128),
            ("vecf64(256)", "f64", 256),
            ("VECF64(256)", "f64", 256),
            ("VecF64(256)", "f64", 256),
            ("VECf64(256)", "f64", 256),
        ]

        for type_str, expected_precision, expected_dimension in test_cases:
            vector_type = dialect._create_vector_type(expected_precision, type_str)

            # Messages name the failing input, since every case shares one loop.
            assert vector_type.precision == expected_precision, type_str
            assert vector_type.dimension == expected_dimension, type_str

    def test_vector_type_no_dimension(self):
        """Vector type strings without a dimension yield a type with ``dimension is None``."""
        dialect = MatrixOneDialect()

        test_cases = [
            ("vecf32", "f32"),
            ("VECF32", "f32"),
            ("VecF32", "f32"),
            ("vecf64", "f64"),
            ("VECF64", "f64"),
            ("VecF64", "f64"),
        ]

        for type_str, expected_precision in test_cases:
            vector_type = dialect._create_vector_type(expected_precision, type_str)

            assert vector_type.precision == expected_precision, type_str
            assert vector_type.dimension is None, type_str

    def test_vector_type_invalid_cases(self):
        """Malformed type strings must not raise.

        The requested precision is kept and the dimension is left unset.
        """
        dialect = MatrixOneDialect()

        invalid_cases = [
            "vecf32()",  # Empty parentheses
            "vecf32(abc)",  # Non-numeric dimension
            "vecf32(128",  # Missing closing parenthesis
            "vecf32(128,64)",  # Multiple parameters
            "unknown_type",  # Unknown type
        ]

        for invalid_type in invalid_cases:
            # Should not raise exception
            vector_type = dialect._create_vector_type("f32", invalid_type)
            assert vector_type.precision == "f32", invalid_type
            assert vector_type.dimension is None, invalid_type

    def test_vector_type_edge_cases(self):
        """Surrounding whitespace and extreme dimensions are parsed correctly."""
        dialect = MatrixOneDialect()

        # (precision, type string, expected dimension)
        edge_cases = [
            ("f32", "  vecf32(128)  ", 128),  # whitespace
            ("f64", "  VECF64(256)  ", 256),  # mixed case and whitespace
            ("f32", "vecf32(65535)", 65535),  # very large dimension
            ("f32", "vecf32(1)", 1),  # smallest dimension
        ]

        for precision, type_str, expected_dimension in edge_cases:
            vector_type = dialect._create_vector_type(precision, type_str)
            assert vector_type.dimension == expected_dimension, repr(type_str)
            assert vector_type.precision == precision, repr(type_str)

    def test_vector_type_compiler_case_handling(self):
        """The compiler emits lowercase vector type names."""
        from matrixone.sqlalchemy_ext.dialect import MatrixOneCompiler

        # Create a compiler instance
        compiler = MatrixOneCompiler(MatrixOneDialect(), None)

        # Test vector type compilation
        vector_type = Vectorf32(dimension=128)
        result = compiler.visit_user_defined_type(vector_type)
        assert result == "vecf32(128)"

        vector_type = Vectorf64(dimension=256)
        result = compiler.visit_user_defined_type(vector_type)
        assert result == "vecf64(256)"

    def test_column_processing_case_insensitive(self):
        """Reflected column type strings are converted to VectorType regardless of case."""
        dialect = MatrixOneDialect()

        # Mock column data
        columns = [
            {'name': 'col1', 'type': 'vecf32(128)'},
            {'name': 'col2', 'type': 'VECF32(256)'},
            {'name': 'col3', 'type': 'VecF32(512)'},
            {'name': 'col4', 'type': 'VECF64(1024)'},
            {'name': 'col5', 'type': 'vecf64(2048)'},
            {'name': 'col6', 'type': 'INTEGER'},  # Non-vector type
        ]

        # Process columns (simulating get_columns behavior)
        for column in columns:
            if isinstance(column['type'], str):
                type_str_lower = column['type'].lower()
                if type_str_lower.startswith('vecf32'):
                    column['type'] = dialect._create_vector_type('f32', column['type'])
                elif type_str_lower.startswith('vecf64'):
                    column['type'] = dialect._create_vector_type('f64', column['type'])

        # Expected (dimension, precision) for the five vector columns, in order.
        expected = [
            (128, "f32"),
            (256, "f32"),
            (512, "f32"),
            (1024, "f64"),
            (2048, "f64"),
        ]
        for column, (dimension, precision) in zip(columns, expected):
            assert isinstance(column['type'], VectorType), column['name']
            assert column['type'].dimension == dimension, column['name']
            assert column['type'].precision == precision, column['name']

        # Non-vector type should remain unchanged
        assert columns[5]['type'] == 'INTEGER'

    def test_sql_generation_case_consistency(self):
        """CREATE TABLE output uses lowercase vector type names."""
        from sqlalchemy import MetaData
        from sqlalchemy.schema import CreateTable

        from matrixone.sqlalchemy_ext import VectorTableBuilder

        metadata = MetaData()

        # Create table with vector columns
        builder = VectorTableBuilder("case_test_table", metadata)
        builder.add_int_column("id", primary_key=True)
        builder.add_vecf32_column("embedding_32", dimension=128)
        builder.add_vecf64_column("embedding_64", dimension=256)

        table = builder.build()

        # Generate SQL using the module-level MatrixOneDialect import
        # (the previous function-local re-import was redundant).
        create_sql = str(CreateTable(table).compile(dialect=MatrixOneDialect(), compile_kwargs={"literal_binds": True}))

        # Check that vector types are in lowercase
        assert "vecf32(128)" in create_sql
        assert "vecf64(256)" in create_sql
        assert "VECF32" not in create_sql
        assert "VECF64" not in create_sql

    def test_real_world_case_scenarios(self):
        """Type strings in assorted casings all produce a dimensioned VectorType."""
        dialect = MatrixOneDialect()

        # Simulate what might come from MatrixOne database introspection
        real_world_types = [
            "vecf32(384)",  # Standard lowercase
            "VECF32(768)",  # All uppercase (some databases)
            "VecF32(128)",  # Mixed case (some tools)
            "vecF64(1024)",  # Mixed case with different pattern
            "VECf64(512)",  # Another mixed case pattern
        ]

        for type_str in real_world_types:
            # Extract precision from type string
            precision = 'f32' if 'f32' in type_str.lower() else 'f64'

            vector_type = dialect._create_vector_type(precision, type_str)

            # Should successfully create vector type regardless of case
            assert isinstance(vector_type, VectorType), type_str
            assert vector_type.precision == precision, type_str
            assert vector_type.dimension is not None, type_str
|
@@ -0,0 +1,287 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
#!/usr/bin/env python3
|
16
|
+
"""
|
17
|
+
Offline tests for connection hooks functionality
|
18
|
+
"""
|
19
|
+
|
20
|
+
import pytest
|
21
|
+
from unittest.mock import Mock, call, patch, AsyncMock
|
22
|
+
from matrixone.connection_hooks import ConnectionAction, ConnectionHook, create_connection_hook
|
23
|
+
|
24
|
+
|
25
|
+
class TestConnectionAction:
    """Checks the string value backing each ConnectionAction member."""

    def test_connection_action_values(self):
        """Every ConnectionAction member exposes its expected lowercase value."""
        member_to_value = [
            (ConnectionAction.ENABLE_IVF, "enable_ivf"),
            (ConnectionAction.ENABLE_HNSW, "enable_hnsw"),
            (ConnectionAction.ENABLE_FULLTEXT, "enable_fulltext"),
            (ConnectionAction.ENABLE_VECTOR, "enable_vector"),
            (ConnectionAction.ENABLE_ALL, "enable_all"),
        ]
        for member, expected in member_to_value:
            assert member.value == expected
|
35
|
+
|
36
|
+
|
37
|
+
class TestConnectionHook:
    """Test ConnectionHook construction and execution against a mocked client."""

    def setup_method(self):
        """Build a mock client whose engine hands back a mock raw connection.

        The same ``mock_cursor`` object serves as the raw DB-API connection and
        as the cursor that connection returns. Tests can therefore assert on
        ``mock_cursor.execute`` / ``mock_cursor.close`` directly.
        """
        self.mock_client = Mock()
        self.mock_client.logger = Mock()
        self.mock_client.vector_ops = Mock()
        self.mock_client.fulltext_index = Mock()

        # Mock engine for connection hook execution
        self.mock_engine = Mock()
        self.mock_connection = Mock()
        self.mock_cursor = Mock()

        # Sync path: engine.connect() can be used as a context manager yielding
        # mock_connection, whose .connection attribute is the raw connection.
        self.mock_engine.connect.return_value.__enter__ = Mock(return_value=self.mock_connection)
        self.mock_engine.connect.return_value.__exit__ = Mock(return_value=None)
        self.mock_connection.connection = self.mock_cursor

        # Async path: awaitable engine.begin() and connection.get_raw_connection().
        self.mock_engine.begin = AsyncMock(return_value=self.mock_connection)
        self.mock_connection.get_raw_connection = AsyncMock(return_value=self.mock_cursor)

        # cursor() on the raw connection returns the same mock (shared by both
        # paths; previously this identical assignment was made twice).
        self.mock_cursor.cursor.return_value = self.mock_cursor

        self.mock_client._engine = self.mock_engine

    def test_connection_hook_init_with_actions(self):
        """Test ConnectionHook initialization with actions"""
        actions = [ConnectionAction.ENABLE_IVF, ConnectionAction.ENABLE_FULLTEXT]
        hook = ConnectionHook(actions=actions)

        assert hook.actions == actions
        assert hook.custom_hook is None

    def test_connection_hook_init_with_custom_hook(self):
        """Test ConnectionHook initialization with custom hook"""

        def custom_hook(client):
            pass

        hook = ConnectionHook(custom_hook=custom_hook)

        assert hook.actions == []
        assert hook.custom_hook == custom_hook

    def test_connection_hook_init_with_both(self):
        """Test ConnectionHook initialization with both actions and custom hook"""
        actions = [ConnectionAction.ENABLE_IVF]

        def custom_hook(client):
            pass

        hook = ConnectionHook(actions=actions, custom_hook=custom_hook)

        assert hook.actions == actions
        assert hook.custom_hook == custom_hook

    def test_execute_sync_with_ivf_action(self):
        """Test synchronous execution with IVF action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_ivf_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled IVF vector operations")

    def test_execute_sync_with_hnsw_action(self):
        """Test synchronous execution with HNSW action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_HNSW])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_hnsw_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled HNSW vector operations")

    def test_execute_sync_with_fulltext_action(self):
        """Test synchronous execution with fulltext action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_FULLTEXT])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_fulltext_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled fulltext search operations")

    def test_execute_sync_with_vector_action(self):
        """Test synchronous execution with vector action (enables both IVF and HNSW)"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_VECTOR])

        hook.execute_sync(self.mock_client)

        # Verify that both SQL statements were executed
        expected_calls = [call("SET experimental_ivf_index = 1"), call("SET experimental_hnsw_index = 1")]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()
        self.mock_client.logger.info.assert_any_call("✓ Enabled IVF vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled HNSW vector operations")

    def test_execute_sync_with_all_action(self):
        """Test synchronous execution with all action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_ALL])

        hook.execute_sync(self.mock_client)

        # Verify that all SQL statements were executed
        expected_calls = [
            call("SET experimental_ivf_index = 1"),
            call("SET experimental_hnsw_index = 1"),
            call("SET experimental_fulltext_index = 1"),
        ]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()
        self.mock_client.logger.info.assert_any_call("✓ Enabled IVF vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled HNSW vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled fulltext search operations")

    def test_execute_sync_with_custom_hook(self):
        """Test synchronous execution with custom hook"""
        custom_hook_called = False

        def custom_hook(client):
            nonlocal custom_hook_called
            custom_hook_called = True
            assert client == self.mock_client

        hook = ConnectionHook(custom_hook=custom_hook)
        hook.execute_sync(self.mock_client)

        assert custom_hook_called

    def test_execute_sync_with_string_actions(self):
        """Test synchronous execution with string action names"""
        hook = ConnectionHook(actions=["enable_ivf", "enable_fulltext"])

        hook.execute_sync(self.mock_client)

        # Verify that both SQL statements were executed
        expected_calls = [call("SET experimental_ivf_index = 1"), call("SET experimental_fulltext_index = 1")]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()

    def test_execute_sync_with_unknown_action(self):
        """Test synchronous execution with unknown action"""
        hook = ConnectionHook(actions=["unknown_action"])

        hook.execute_sync(self.mock_client)

        # The error message should contain information about the unknown action
        warning_calls = self.mock_client.logger.warning.call_args_list
        assert len(warning_calls) > 0
        assert "unknown_action" in str(warning_calls[-1])

    def test_execute_sync_handles_exceptions(self):
        """Test that execute_sync handles exceptions gracefully"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])

        # Make cursor.execute raise an exception
        self.mock_cursor.execute.side_effect = Exception("Test error")

        # Should not raise exception
        hook.execute_sync(self.mock_client)

        self.mock_client.logger.warning.assert_called_with("Failed to enable IVF: Test error")

    @pytest.mark.asyncio
    async def test_execute_async_with_ivf_action(self):
        """Test asynchronous execution with IVF action.

        NOTE(review): ``execute_async`` is patched with an AsyncMock and then the
        patch itself is awaited. This only proves the mock was awaited; it does
        not exercise ConnectionHook.execute_async. A real test needs an async
        engine double that matches how execute_async uses ``client._engine``.
        """
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])
        assert hook.actions == [ConnectionAction.ENABLE_IVF]

        # Mock the execute_async method to avoid complex async mock setup
        with patch.object(hook, 'execute_async', new_callable=AsyncMock) as mock_execute:
            await hook.execute_async(self.mock_client)

            # Verify that execute_async was called
            mock_execute.assert_called_once_with(self.mock_client)

    @pytest.mark.asyncio
    async def test_execute_async_with_custom_async_hook(self):
        """Test asynchronous execution with custom async hook.

        NOTE(review): as in the IVF async test, ``execute_async`` is patched, so
        the custom hook is never awaited here. Only hook registration and the
        mock call are actually checked.
        """

        async def async_custom_hook(client):
            # Not awaited in this test because execute_async is patched below.
            pass

        hook = ConnectionHook(custom_hook=async_custom_hook)
        assert hook.custom_hook is async_custom_hook

        # Mock the execute_async method to avoid complex async mock setup
        with patch.object(hook, 'execute_async', new_callable=AsyncMock) as mock_execute:
            await hook.execute_async(self.mock_client)

            # Verify that execute_async was called
            mock_execute.assert_called_once_with(self.mock_client)
|
242
|
+
|
243
|
+
|
244
|
+
class TestCreateConnectionHook:
    """Checks that the create_connection_hook factory returns a configured ConnectionHook."""

    def test_create_connection_hook_with_actions(self):
        """Factory keeps the enum actions and leaves the custom hook unset."""
        requested = [ConnectionAction.ENABLE_IVF, ConnectionAction.ENABLE_FULLTEXT]
        created = create_connection_hook(actions=requested)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
        assert created.custom_hook is None

    def test_create_connection_hook_with_custom_hook(self):
        """Factory stores the custom hook and defaults actions to an empty list."""

        def noop_hook(client):
            return None

        created = create_connection_hook(custom_hook=noop_hook)

        assert isinstance(created, ConnectionHook)
        assert created.actions == []
        assert created.custom_hook == noop_hook

    def test_create_connection_hook_with_both(self):
        """Factory stores both the actions and the custom hook when given both."""
        requested = [ConnectionAction.ENABLE_IVF]

        def noop_hook(client):
            return None

        created = create_connection_hook(actions=requested, custom_hook=noop_hook)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
        assert created.custom_hook == noop_hook

    def test_create_connection_hook_with_string_actions(self):
        """Factory keeps plain-string action names unchanged."""
        requested = ["enable_ivf", "enable_fulltext"]
        created = create_connection_hook(actions=requested)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
|