matrixone-python-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. matrixone/__init__.py +155 -0
  2. matrixone/account.py +723 -0
  3. matrixone/async_client.py +3913 -0
  4. matrixone/async_metadata_manager.py +311 -0
  5. matrixone/async_orm.py +123 -0
  6. matrixone/async_vector_index_manager.py +633 -0
  7. matrixone/base_client.py +208 -0
  8. matrixone/client.py +4672 -0
  9. matrixone/config.py +452 -0
  10. matrixone/connection_hooks.py +286 -0
  11. matrixone/exceptions.py +89 -0
  12. matrixone/logger.py +782 -0
  13. matrixone/metadata.py +820 -0
  14. matrixone/moctl.py +219 -0
  15. matrixone/orm.py +2277 -0
  16. matrixone/pitr.py +646 -0
  17. matrixone/pubsub.py +771 -0
  18. matrixone/restore.py +411 -0
  19. matrixone/search_vector_index.py +1176 -0
  20. matrixone/snapshot.py +550 -0
  21. matrixone/sql_builder.py +844 -0
  22. matrixone/sqlalchemy_ext/__init__.py +161 -0
  23. matrixone/sqlalchemy_ext/adapters.py +163 -0
  24. matrixone/sqlalchemy_ext/dialect.py +534 -0
  25. matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
  26. matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
  27. matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
  28. matrixone/sqlalchemy_ext/ivf_config.py +252 -0
  29. matrixone/sqlalchemy_ext/table_builder.py +351 -0
  30. matrixone/sqlalchemy_ext/vector_index.py +1721 -0
  31. matrixone/sqlalchemy_ext/vector_type.py +948 -0
  32. matrixone/version.py +580 -0
  33. matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
  34. matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
  35. matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
  36. matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
  37. matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
  38. matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
  39. tests/__init__.py +19 -0
  40. tests/offline/__init__.py +20 -0
  41. tests/offline/conftest.py +77 -0
  42. tests/offline/test_account.py +703 -0
  43. tests/offline/test_async_client_query_comprehensive.py +1218 -0
  44. tests/offline/test_basic.py +54 -0
  45. tests/offline/test_case_sensitivity.py +227 -0
  46. tests/offline/test_connection_hooks_offline.py +287 -0
  47. tests/offline/test_dialect_schema_handling.py +609 -0
  48. tests/offline/test_explain_methods.py +346 -0
  49. tests/offline/test_filter_logical_in.py +237 -0
  50. tests/offline/test_fulltext_search_comprehensive.py +795 -0
  51. tests/offline/test_ivf_config.py +249 -0
  52. tests/offline/test_join_methods.py +281 -0
  53. tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
  54. tests/offline/test_logical_in_method.py +237 -0
  55. tests/offline/test_matrixone_version_parsing.py +264 -0
  56. tests/offline/test_metadata_offline.py +557 -0
  57. tests/offline/test_moctl.py +300 -0
  58. tests/offline/test_moctl_simple.py +251 -0
  59. tests/offline/test_model_support_offline.py +359 -0
  60. tests/offline/test_model_support_simple.py +225 -0
  61. tests/offline/test_pinecone_filter_offline.py +377 -0
  62. tests/offline/test_pitr.py +585 -0
  63. tests/offline/test_pubsub.py +712 -0
  64. tests/offline/test_query_update.py +283 -0
  65. tests/offline/test_restore.py +445 -0
  66. tests/offline/test_snapshot_comprehensive.py +384 -0
  67. tests/offline/test_sql_escaping_edge_cases.py +551 -0
  68. tests/offline/test_sqlalchemy_integration.py +382 -0
  69. tests/offline/test_sqlalchemy_vector_integration.py +434 -0
  70. tests/offline/test_table_builder.py +198 -0
  71. tests/offline/test_unified_filter.py +398 -0
  72. tests/offline/test_unified_transaction.py +495 -0
  73. tests/offline/test_vector_index.py +238 -0
  74. tests/offline/test_vector_operations.py +688 -0
  75. tests/offline/test_vector_type.py +174 -0
  76. tests/offline/test_version_core.py +328 -0
  77. tests/offline/test_version_management.py +372 -0
  78. tests/offline/test_version_standalone.py +652 -0
  79. tests/online/__init__.py +20 -0
  80. tests/online/conftest.py +216 -0
  81. tests/online/test_account_management.py +194 -0
  82. tests/online/test_advanced_features.py +344 -0
  83. tests/online/test_async_client_interfaces.py +330 -0
  84. tests/online/test_async_client_online.py +285 -0
  85. tests/online/test_async_model_insert_online.py +293 -0
  86. tests/online/test_async_orm_online.py +300 -0
  87. tests/online/test_async_simple_query_online.py +802 -0
  88. tests/online/test_async_transaction_simple_query.py +300 -0
  89. tests/online/test_basic_connection.py +130 -0
  90. tests/online/test_client_online.py +238 -0
  91. tests/online/test_config.py +90 -0
  92. tests/online/test_config_validation.py +123 -0
  93. tests/online/test_connection_hooks_new_online.py +217 -0
  94. tests/online/test_dialect_schema_handling_online.py +331 -0
  95. tests/online/test_filter_logical_in_online.py +374 -0
  96. tests/online/test_fulltext_comprehensive.py +1773 -0
  97. tests/online/test_fulltext_label_online.py +433 -0
  98. tests/online/test_fulltext_search_online.py +842 -0
  99. tests/online/test_ivf_stats_online.py +506 -0
  100. tests/online/test_logger_integration.py +311 -0
  101. tests/online/test_matrixone_query_orm.py +540 -0
  102. tests/online/test_metadata_online.py +579 -0
  103. tests/online/test_model_insert_online.py +255 -0
  104. tests/online/test_mysql_driver_validation.py +213 -0
  105. tests/online/test_orm_advanced_features.py +2022 -0
  106. tests/online/test_orm_cte_integration.py +269 -0
  107. tests/online/test_orm_online.py +270 -0
  108. tests/online/test_pinecone_filter.py +708 -0
  109. tests/online/test_pubsub_operations.py +352 -0
  110. tests/online/test_query_methods.py +225 -0
  111. tests/online/test_query_update_online.py +433 -0
  112. tests/online/test_search_vector_index.py +557 -0
  113. tests/online/test_simple_fulltext_online.py +915 -0
  114. tests/online/test_snapshot_comprehensive.py +998 -0
  115. tests/online/test_sqlalchemy_engine_integration.py +336 -0
  116. tests/online/test_sqlalchemy_integration.py +425 -0
  117. tests/online/test_transaction_contexts.py +1219 -0
  118. tests/online/test_transaction_insert_methods.py +356 -0
  119. tests/online/test_transaction_query_methods.py +288 -0
  120. tests/online/test_unified_filter_online.py +529 -0
  121. tests/online/test_vector_comprehensive.py +706 -0
  122. tests/online/test_version_management.py +291 -0
@@ -0,0 +1,54 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Basic test for MatrixOne Python SDK
17
+ """
18
+
19
+ from matrixone import Client, ConnectionError, QueryError
20
+
21
+
22
def test_basic_connection():
    """Smoke-test connecting to MatrixOne, running a query, and opening a transaction.

    Connects to ``localhost:6001`` as ``root``/``111`` using database ``test``.
    If the connection cannot be established, the test is skipped instead of
    silently passing. Any failure after a successful connection is allowed to
    propagate so that pytest reports it.
    """
    # pytest is not imported at module level in this file; import it locally
    # so the module can still be run directly as a script.
    import pytest

    client = Client()

    try:
        # An unreachable server is an environment problem, not a test failure.
        # The exact exception type raised by connect() is not visible here,
        # so any exception from connect() is treated as "server unavailable".
        try:
            client.connect(host="localhost", port=6001, user="root", password="111", database="test")
        except Exception as exc:
            pytest.skip(f"MatrixOne server not available: {exc}")

        print("✓ Connected to MatrixOne successfully")

        # Test basic query
        result = client.execute("SELECT 1 as test_value")
        row = result.fetchone()
        assert row is not None, "SELECT 1 returned no rows"
        print(f"✓ Query executed successfully: {row}")

        # Test SQLAlchemy engine
        engine = client.get_sqlalchemy_engine()
        assert engine is not None, "get_sqlalchemy_engine() returned None"
        print("✓ SQLAlchemy engine created successfully")

        # Test transaction
        with client.transaction() as tx:
            result = tx.execute("SELECT 2 as transaction_value")
            row = result.fetchone()
            assert row is not None, "SELECT 2 inside a transaction returned no rows"
            print(f"✓ Transaction executed successfully: {row}")

        print("✓ All basic tests passed!")
    finally:
        client.disconnect()


if __name__ == "__main__":
    # Allow running this smoke test directly, outside pytest.
    test_basic_connection()
@@ -0,0 +1,227 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Tests for case sensitivity handling in MatrixOne SQLAlchemy extensions.
17
+ """
18
+
19
+ import pytest
20
+ import sys
21
+ from unittest.mock import Mock
22
+ from matrixone.sqlalchemy_ext import MatrixOneDialect, VectorType, Vectorf32, Vectorf64
23
+
24
pytestmark = pytest.mark.vector


class TestCaseSensitivity:
    """Test case-sensitivity handling of vector types in the MatrixOne dialect.

    Case tables are parametrized so that each type string is reported as its
    own test item. A failing case names the offending input and does not hide
    the cases that follow it.
    """

    @pytest.mark.parametrize(
        ("type_str", "expected_precision", "expected_dimension"),
        [
            ("vecf32(128)", "f32", 128),
            ("VECF32(128)", "f32", 128),
            ("VecF32(128)", "f32", 128),
            ("VECf32(128)", "f32", 128),
            ("vecf64(256)", "f64", 256),
            ("VECF64(256)", "f64", 256),
            ("VecF64(256)", "f64", 256),
            ("VECf64(256)", "f64", 256),
        ],
    )
    def test_vector_type_case_insensitive_parsing(self, type_str, expected_precision, expected_dimension):
        """Vector type strings with a dimension are parsed case-insensitively."""
        vector_type = MatrixOneDialect()._create_vector_type(expected_precision, type_str)

        assert vector_type.precision == expected_precision
        assert vector_type.dimension == expected_dimension

    @pytest.mark.parametrize(
        ("type_str", "expected_precision"),
        [
            ("vecf32", "f32"),
            ("VECF32", "f32"),
            ("VecF32", "f32"),
            ("vecf64", "f64"),
            ("VECF64", "f64"),
            ("VecF64", "f64"),
        ],
    )
    def test_vector_type_no_dimension(self, type_str, expected_precision):
        """Vector type strings without a dimension produce ``dimension is None``."""
        vector_type = MatrixOneDialect()._create_vector_type(expected_precision, type_str)

        assert vector_type.precision == expected_precision
        assert vector_type.dimension is None

    @pytest.mark.parametrize(
        "invalid_type",
        [
            "vecf32()",  # Empty parentheses
            "vecf32(abc)",  # Non-numeric dimension
            "vecf32(128",  # Missing closing parenthesis
            "vecf32(128,64)",  # Multiple parameters
            "unknown_type",  # Unknown type
        ],
    )
    def test_vector_type_invalid_cases(self, invalid_type):
        """Malformed type strings do not raise; they fall back to a dimensionless type."""
        vector_type = MatrixOneDialect()._create_vector_type("f32", invalid_type)

        assert vector_type.precision == "f32"
        assert vector_type.dimension is None

    @pytest.mark.parametrize(
        ("precision", "type_str", "expected_dimension"),
        [
            ("f32", " vecf32(128) ", 128),  # surrounding whitespace
            ("f64", " VECF64(256) ", 256),  # mixed case and whitespace
            ("f32", "vecf32(65535)", 65535),  # very large dimension
            ("f32", "vecf32(1)", 1),  # smallest dimension
        ],
    )
    def test_vector_type_edge_cases(self, precision, type_str, expected_dimension):
        """Edge cases for vector type parsing: whitespace and extreme dimensions."""
        vector_type = MatrixOneDialect()._create_vector_type(precision, type_str)

        assert vector_type.dimension == expected_dimension

    def test_vector_type_compiler_case_handling(self):
        """The compiler emits lowercase vector type names."""
        from matrixone.sqlalchemy_ext.dialect import MatrixOneCompiler

        compiler = MatrixOneCompiler(MatrixOneDialect(), None)

        assert compiler.visit_user_defined_type(Vectorf32(dimension=128)) == "vecf32(128)"
        assert compiler.visit_user_defined_type(Vectorf64(dimension=256)) == "vecf64(256)"

    def test_column_processing_case_insensitive(self):
        """Column type strings are mapped to vector types regardless of case."""
        dialect = MatrixOneDialect()

        # Mock column data
        columns = [
            {'name': 'col1', 'type': 'vecf32(128)'},
            {'name': 'col2', 'type': 'VECF32(256)'},
            {'name': 'col3', 'type': 'VecF32(512)'},
            {'name': 'col4', 'type': 'VECF64(1024)'},
            {'name': 'col5', 'type': 'vecf64(2048)'},
            {'name': 'col6', 'type': 'INTEGER'},  # Non-vector type
        ]

        # Process columns (simulating get_columns behavior)
        for column in columns:
            if isinstance(column['type'], str):
                type_str_lower = column['type'].lower()
                if type_str_lower.startswith('vecf32'):
                    column['type'] = dialect._create_vector_type('f32', column['type'])
                elif type_str_lower.startswith('vecf64'):
                    column['type'] = dialect._create_vector_type('f64', column['type'])

        # Expected (dimension, precision) for the five vector columns, in order.
        expected = [(128, "f32"), (256, "f32"), (512, "f32"), (1024, "f64"), (2048, "f64")]
        for column, (dimension, precision) in zip(columns[:5], expected):
            assert isinstance(column['type'], VectorType), column['name']
            assert column['type'].dimension == dimension, column['name']
            assert column['type'].precision == precision, column['name']

        # Non-vector type should remain unchanged
        assert columns[5]['type'] == 'INTEGER'

    def test_sql_generation_case_consistency(self):
        """CREATE TABLE DDL from the MatrixOne dialect uses lowercase vector types."""
        from sqlalchemy import MetaData
        from sqlalchemy.schema import CreateTable
        from matrixone.sqlalchemy_ext import VectorTableBuilder

        metadata = MetaData()

        # Create table with vector columns
        builder = VectorTableBuilder("case_test_table", metadata)
        builder.add_int_column("id", primary_key=True)
        builder.add_vecf32_column("embedding_32", dimension=128)
        builder.add_vecf64_column("embedding_64", dimension=256)

        table = builder.build()

        # Generate SQL using MatrixOne dialect (imported at module level)
        create_sql = str(CreateTable(table).compile(dialect=MatrixOneDialect(), compile_kwargs={"literal_binds": True}))

        # Check that vector types are in lowercase
        assert "vecf32(128)" in create_sql
        assert "vecf64(256)" in create_sql
        assert "VECF32" not in create_sql
        assert "VECF64" not in create_sql

    @pytest.mark.parametrize(
        ("type_str", "precision", "expected_dimension"),
        [
            ("vecf32(384)", "f32", 384),  # Standard lowercase
            ("VECF32(768)", "f32", 768),  # All uppercase (some databases)
            ("VecF32(128)", "f32", 128),  # Mixed case (some tools)
            ("vecF64(1024)", "f64", 1024),  # Mixed case with different pattern
            ("VECf64(512)", "f64", 512),  # Another mixed case pattern
        ],
    )
    def test_real_world_case_scenarios(self, type_str, precision, expected_dimension):
        """Type strings as they might come from introspection parse regardless of case."""
        vector_type = MatrixOneDialect()._create_vector_type(precision, type_str)

        assert isinstance(vector_type, VectorType)
        assert vector_type.precision == precision
        assert vector_type.dimension == expected_dimension
@@ -0,0 +1,287 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/usr/bin/env python3
16
+ """
17
+ Offline tests for connection hooks functionality
18
+ """
19
+
20
+ import pytest
21
+ from unittest.mock import Mock, call, patch, AsyncMock
22
+ from matrixone.connection_hooks import ConnectionAction, ConnectionHook, create_connection_hook
23
+
24
+
25
class TestConnectionAction:
    """Checks on the ConnectionAction enum."""

    def test_connection_action_values(self):
        """Every ConnectionAction member carries its expected string value."""
        expected_values = {
            ConnectionAction.ENABLE_IVF: "enable_ivf",
            ConnectionAction.ENABLE_HNSW: "enable_hnsw",
            ConnectionAction.ENABLE_FULLTEXT: "enable_fulltext",
            ConnectionAction.ENABLE_VECTOR: "enable_vector",
            ConnectionAction.ENABLE_ALL: "enable_all",
        }
        for member, value in expected_values.items():
            assert member.value == value
35
+
36
+
37
class TestConnectionHook:
    """Tests for the ConnectionHook class."""

    def setup_method(self):
        """Build a mock client whose engine hands out a mock DB-API connection/cursor."""
        self.mock_client = Mock()
        self.mock_client.logger = Mock()
        self.mock_client.vector_ops = Mock()
        self.mock_client.fulltext_index = Mock()

        # Mock engine for connection hook execution
        self.mock_engine = Mock()
        self.mock_connection = Mock()
        self.mock_cursor = Mock()

        # Sync path: engine.connect() is a context manager yielding mock_connection,
        # whose .connection (the DB-API connection) is mock_cursor.
        self.mock_engine.connect.return_value.__enter__ = Mock(return_value=self.mock_connection)
        self.mock_engine.connect.return_value.__exit__ = Mock(return_value=None)
        self.mock_connection.connection = self.mock_cursor

        # The same mock doubles as the DB-API connection and its cursor:
        # dbapi_connection.cursor() returns mock_cursor itself (sync and async paths).
        self.mock_cursor.cursor.return_value = self.mock_cursor

        # Async path. NOTE(review): AsyncMock(return_value=...) makes begin() awaitable,
        # not an async context manager; confirm against how execute_async uses the engine.
        self.mock_engine.begin = AsyncMock(return_value=self.mock_connection)
        self.mock_connection.get_raw_connection = AsyncMock(return_value=self.mock_cursor)

        self.mock_client._engine = self.mock_engine

    def test_connection_hook_init_with_actions(self):
        """Test ConnectionHook initialization with actions"""
        actions = [ConnectionAction.ENABLE_IVF, ConnectionAction.ENABLE_FULLTEXT]
        hook = ConnectionHook(actions=actions)

        assert hook.actions == actions
        assert hook.custom_hook is None

    def test_connection_hook_init_with_custom_hook(self):
        """Test ConnectionHook initialization with custom hook"""

        def custom_hook(client):
            pass

        hook = ConnectionHook(custom_hook=custom_hook)

        assert hook.actions == []
        assert hook.custom_hook == custom_hook

    def test_connection_hook_init_with_both(self):
        """Test ConnectionHook initialization with both actions and custom hook"""
        actions = [ConnectionAction.ENABLE_IVF]

        def custom_hook(client):
            pass

        hook = ConnectionHook(actions=actions, custom_hook=custom_hook)

        assert hook.actions == actions
        assert hook.custom_hook == custom_hook

    def test_execute_sync_with_ivf_action(self):
        """Test synchronous execution with IVF action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_ivf_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled IVF vector operations")

    def test_execute_sync_with_hnsw_action(self):
        """Test synchronous execution with HNSW action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_HNSW])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_hnsw_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled HNSW vector operations")

    def test_execute_sync_with_fulltext_action(self):
        """Test synchronous execution with fulltext action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_FULLTEXT])

        hook.execute_sync(self.mock_client)

        # Verify that cursor.execute was called with the correct SQL
        self.mock_cursor.execute.assert_called_with("SET experimental_fulltext_index = 1")
        self.mock_cursor.close.assert_called_once()
        self.mock_client.logger.info.assert_called_with("✓ Enabled fulltext search operations")

    def test_execute_sync_with_vector_action(self):
        """Test synchronous execution with vector action (enables both IVF and HNSW)"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_VECTOR])

        hook.execute_sync(self.mock_client)

        # Verify that both SQL statements were executed
        expected_calls = [call("SET experimental_ivf_index = 1"), call("SET experimental_hnsw_index = 1")]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()
        self.mock_client.logger.info.assert_any_call("✓ Enabled IVF vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled HNSW vector operations")

    def test_execute_sync_with_all_action(self):
        """Test synchronous execution with all action"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_ALL])

        hook.execute_sync(self.mock_client)

        # Verify that all SQL statements were executed
        expected_calls = [
            call("SET experimental_ivf_index = 1"),
            call("SET experimental_hnsw_index = 1"),
            call("SET experimental_fulltext_index = 1"),
        ]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()
        self.mock_client.logger.info.assert_any_call("✓ Enabled IVF vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled HNSW vector operations")
        self.mock_client.logger.info.assert_any_call("✓ Enabled fulltext search operations")

    def test_execute_sync_with_custom_hook(self):
        """Test synchronous execution with custom hook"""
        custom_hook_called = False

        def custom_hook(client):
            nonlocal custom_hook_called
            custom_hook_called = True
            assert client == self.mock_client

        hook = ConnectionHook(custom_hook=custom_hook)
        hook.execute_sync(self.mock_client)

        assert custom_hook_called

    def test_execute_sync_with_string_actions(self):
        """Test synchronous execution with string action names"""
        hook = ConnectionHook(actions=["enable_ivf", "enable_fulltext"])

        hook.execute_sync(self.mock_client)

        # Verify that both SQL statements were executed
        expected_calls = [call("SET experimental_ivf_index = 1"), call("SET experimental_fulltext_index = 1")]
        self.mock_cursor.execute.assert_has_calls(expected_calls, any_order=True)
        self.mock_cursor.close.assert_called()

    def test_execute_sync_with_unknown_action(self):
        """Test synchronous execution with unknown action"""
        hook = ConnectionHook(actions=["unknown_action"])

        hook.execute_sync(self.mock_client)

        # The warning message should contain information about the unknown action
        warning_calls = self.mock_client.logger.warning.call_args_list
        assert len(warning_calls) > 0
        assert "unknown_action" in str(warning_calls[-1])

    def test_execute_sync_handles_exceptions(self):
        """Test that execute_sync handles exceptions gracefully"""
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])

        # Make cursor.execute raise an exception
        self.mock_cursor.execute.side_effect = Exception("Test error")

        # Should not raise exception
        hook.execute_sync(self.mock_client)

        self.mock_client.logger.warning.assert_called_with("Failed to enable IVF: Test error")

    @pytest.mark.asyncio
    async def test_execute_async_with_ivf_action(self):
        """Test asynchronous execution with IVF action.

        NOTE(review): execute_async is patched out here, so this test only checks
        that the patched coroutine is awaited with the client. It does not exercise
        the IVF SQL path. TODO: mock the async engine to match how execute_async
        really uses it, then assert on the executed SQL as the sync tests do.
        """
        hook = ConnectionHook(actions=[ConnectionAction.ENABLE_IVF])

        with patch.object(hook, 'execute_async', new_callable=AsyncMock) as mock_execute:
            await hook.execute_async(self.mock_client)

            # Verify that execute_async was called
            mock_execute.assert_called_once_with(self.mock_client)

    @pytest.mark.asyncio
    async def test_execute_async_with_custom_async_hook(self):
        """execute_async runs a coroutine custom hook with the client.

        Mirrors test_execute_sync_with_custom_hook: no actions are configured, so
        only the custom hook path is exercised, through the real execute_async.
        """
        custom_hook_called = False

        async def async_custom_hook(client):
            nonlocal custom_hook_called
            custom_hook_called = True
            assert client == self.mock_client

        hook = ConnectionHook(custom_hook=async_custom_hook)

        # Call the real method; patching it would only assert that the mock we
        # just called was called.
        await hook.execute_async(self.mock_client)

        assert custom_hook_called
242
+
243
+
244
class TestCreateConnectionHook:
    """Checks on the ``create_connection_hook`` factory function."""

    def test_create_connection_hook_with_actions(self):
        """Enum actions are stored as given and no custom hook is attached."""
        requested = [ConnectionAction.ENABLE_IVF, ConnectionAction.ENABLE_FULLTEXT]

        created = create_connection_hook(actions=requested)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
        assert created.custom_hook is None

    def test_create_connection_hook_with_custom_hook(self):
        """A custom hook alone leaves the action list empty."""

        def noop_hook(client):
            return None

        created = create_connection_hook(custom_hook=noop_hook)

        assert isinstance(created, ConnectionHook)
        assert created.actions == []
        assert created.custom_hook == noop_hook

    def test_create_connection_hook_with_both(self):
        """Actions and a custom hook can be supplied together."""
        requested = [ConnectionAction.ENABLE_IVF]

        def noop_hook(client):
            return None

        created = create_connection_hook(actions=requested, custom_hook=noop_hook)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
        assert created.custom_hook == noop_hook

    def test_create_connection_hook_with_string_actions(self):
        """String action names are stored as given."""
        requested = ["enable_ivf", "enable_fulltext"]

        created = create_connection_hook(actions=requested)

        assert isinstance(created, ConnectionHook)
        assert created.actions == requested
+ assert hook.actions == actions