matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,688 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Offline tests for vector operations (create, search, order by, limit).
|
17
|
+
Uses mocks to simulate database interactions without requiring a real connection.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import pytest
|
21
|
+
from unittest.mock import Mock, patch, MagicMock
|
22
|
+
from sqlalchemy import create_engine, text, func
|
23
|
+
from sqlalchemy.orm import sessionmaker, declarative_base
|
24
|
+
from sqlalchemy.exc import SQLAlchemyError
|
25
|
+
|
26
|
+
# Import the vector types and functions we want to test
|
27
|
+
from matrixone.sqlalchemy_ext import (
|
28
|
+
VectorType,
|
29
|
+
Vectorf32,
|
30
|
+
Vectorf64,
|
31
|
+
VectorColumn,
|
32
|
+
create_vector_column,
|
33
|
+
vector_distance_functions,
|
34
|
+
)
|
35
|
+
from matrixone.sqlalchemy_ext.dialect import MatrixOneDialect
|
36
|
+
from matrixone.sql_builder import (
|
37
|
+
MatrixOneSQLBuilder,
|
38
|
+
DistanceFunction,
|
39
|
+
build_vector_similarity_query,
|
40
|
+
build_select_query,
|
41
|
+
build_insert_query,
|
42
|
+
build_update_query,
|
43
|
+
build_delete_query,
|
44
|
+
)
|
45
|
+
|
46
|
+
|
47
|
+
class TestVectorOperationsOffline:
    """Offline tests for vector columns, distance expressions and generated SQL.

    Session/engine interactions are simulated with ``unittest.mock`` objects,
    and SQL text is produced by compiling statements against
    ``MatrixOneDialect``, so these tests do not need a database connection.
    """

    @staticmethod
    def _assert_contains_all(sql, fragments):
        """Assert that every string in *fragments* occurs somewhere in *sql*."""
        for fragment in fragments:
            assert fragment in sql

    def setup_method(self):
        """Build fresh mock engine, session, connection and metadata objects.

        No SQLAlchemy factory is patched here: the tests talk to these mocks
        directly, and a ``patch()`` opened inside this method would already be
        reverted by the time a test body runs.
        """
        self.mock_engine = Mock()
        self.mock_session = Mock()
        self.mock_connection = Mock()

        # engine.connect() and "with engine.begin() as conn" both yield the
        # same mock connection.
        self.mock_engine.connect.return_value = self.mock_connection
        begin_context = self.mock_engine.begin.return_value
        begin_context.__enter__ = Mock(return_value=self.mock_connection)
        begin_context.__exit__ = Mock(return_value=None)

        # Stand-ins for a declarative base and its metadata.
        self.mock_base = Mock()
        self.mock_metadata = Mock()
        self.mock_base.metadata = self.mock_metadata
        self.mock_metadata.create_all.return_value = None
        self.mock_metadata.drop_all.return_value = None

    def test_vector_column_creation(self):
        """Test creating vector columns with different dimensions and precisions."""
        # 32-bit float vector column
        vec32_col = create_vector_column(128, "f32")
        assert vec32_col.type.precision == "f32"
        assert vec32_col.type.dimension == 128

        # 64-bit float vector column
        vec64_col = create_vector_column(256, "f64")
        assert vec64_col.type.precision == "f64"
        assert vec64_col.type.dimension == 256

        # Extra keyword options are expected to reach the column itself
        vec_col_with_options = create_vector_column(512, "f32", nullable=False, index=True)
        assert vec_col_with_options.nullable is False
        assert vec_col_with_options.index is True

    def test_vector_distance_functions_availability(self):
        """Test that all expected vector distance functions are available."""
        functions = vector_distance_functions()

        # "function_name" rather than "func": the latter would shadow the
        # sqlalchemy.func import used by this module.
        for function_name in ("l2_distance", "l2_distance_sq", "cosine_distance"):
            assert function_name in functions

    def test_vector_column_distance_methods(self):
        """Test that each distance method returns a compilable expression."""
        vec_col = create_vector_column(10, "f32")

        for method_name in ("l2_distance", "l2_distance_sq", "cosine_distance"):
            result = getattr(vec_col, method_name)([1.0, 2.0, 3.0])
            assert result is not None
            assert hasattr(result, 'compile')

    def test_vector_table_creation_mock(self):
        """Test vector table creation using mocks."""
        self.mock_metadata.create_all(self.mock_engine)

        self.mock_metadata.create_all.assert_called_once_with(self.mock_engine)

    def test_vector_table_drop_mock(self):
        """Test vector table dropping using mocks."""
        self.mock_metadata.drop_all(self.mock_engine)

        self.mock_metadata.drop_all.assert_called_once_with(self.mock_engine)

    def test_vector_search_query_construction(self):
        """Test constructing a filter / order_by / limit vector search chain."""
        # Every chained call returns the same mock so the whole chain can be
        # inspected through one object.
        mock_query = Mock()
        self.mock_session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query

        # Mock model carrying a real vector column
        mock_model = Mock()
        mock_model.embedding = create_vector_column(10, "f32")

        query = self.mock_session.query(mock_model)
        filtered_query = query.filter(mock_model.embedding.l2_distance([1.0, 2.0, 3.0]) < 0.5)
        ordered_query = filtered_query.order_by(mock_model.embedding.l2_distance([1.0, 2.0, 3.0]))
        ordered_query.limit(10)

        self.mock_session.query.assert_called_once_with(mock_model)
        mock_query.filter.assert_called_once()
        mock_query.order_by.assert_called_once()
        mock_query.limit.assert_called_once_with(10)

    def test_vector_insertion_mock(self):
        """Test vector data insertion using mocks."""
        self.mock_session.add.return_value = None
        self.mock_session.commit.return_value = None

        mock_data = Mock()
        mock_data.embedding = [1.0, 2.0, 3.0, 4.0, 5.0]
        mock_data.description = "Test document"

        self.mock_session.add(mock_data)
        self.mock_session.commit()

        self.mock_session.add.assert_called_once_with(mock_data)
        self.mock_session.commit.assert_called_once()

    def test_vector_batch_insertion_mock(self):
        """Test batch vector data insertion using mocks."""
        self.mock_session.bulk_insert_mappings.return_value = None
        self.mock_session.commit.return_value = None

        batch_data = [
            {"embedding": [1.0, 2.0, 3.0], "description": "Doc 1"},
            {"embedding": [4.0, 5.0, 6.0], "description": "Doc 2"},
            {"embedding": [7.0, 8.0, 9.0], "description": "Doc 3"},
        ]

        # The Mock class itself stands in for the mapped model argument.
        self.mock_session.bulk_insert_mappings(Mock, batch_data)
        self.mock_session.commit()

        self.mock_session.bulk_insert_mappings.assert_called_once_with(Mock, batch_data)
        self.mock_session.commit.assert_called_once()

    def test_vector_search_with_multiple_criteria(self):
        """Test complex vector search with multiple criteria."""
        mock_query = Mock()
        self.mock_session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.limit.return_value = mock_query
        mock_query.all.return_value = []

        mock_model = Mock()
        mock_model.embedding = create_vector_column(10, "f32")

        query = self.mock_session.query(mock_model)
        query = query.filter(mock_model.embedding.l2_distance([1.0, 2.0, 3.0]) < 0.5)
        query = query.filter(mock_model.embedding.cosine_distance([1.0, 2.0, 3.0]) < 0.3)
        query = query.order_by(mock_model.embedding.l2_distance([1.0, 2.0, 3.0]))
        query = query.limit(5)
        results = query.all()

        # The configured return value must come back through the chain.
        assert results == []
        assert mock_query.filter.call_count == 2  # Two filter calls
        mock_query.order_by.assert_called_once()
        mock_query.limit.assert_called_once_with(5)
        mock_query.all.assert_called_once()

    def test_vector_type_compilation(self):
        """Test the column specification strings of the vector types."""
        vec32_type = Vectorf32(dimension=128)
        assert vec32_type.get_col_spec() == "vecf32(128)"

        vec64_type = Vectorf64(dimension=256)
        assert vec64_type.get_col_spec() == "vecf64(256)"

    def test_vector_column_string_representation(self):
        """Test vector column string representation."""
        vec_col = create_vector_column(64, "f32")

        # The type renders as its column specification
        assert str(vec_col.type) == "vecf32(64)"

        # The column name can be assigned after creation
        vec_col.name = "embedding"
        assert vec_col.name == "embedding"

    def test_vector_distance_function_parameters(self):
        """Test vector distance functions with different parameter types."""
        vec_col = create_vector_column(5, "f32")

        operands = (
            [1.0, 2.0, 3.0, 4.0, 5.0],  # list of floats
            "[1.0,2.0,3.0,4.0,5.0]",  # vector literal as a string
            Mock(),  # stand-in for another column
        )
        for operand in operands:
            result = vec_col.l2_distance(operand)
            assert result is not None
            assert hasattr(result, 'compile')

    def test_vector_sql_generation_comparison(self):
        """Test that an L2 distance expression compiles to the expected SQL."""
        vec_col = create_vector_column(10, "f32")
        vec_col.name = "embedding"

        result = vec_col.l2_distance([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])

        assert result is not None
        assert hasattr(result, 'compile')

        sql = str(result.compile(dialect=MatrixOneDialect()))
        assert "l2_distance" in sql
        assert "embedding" in sql

    def test_vector_table_creation_sql_comparison(self):
        """Test table creation SQL generation against expected SQL."""
        from sqlalchemy import Table, Column, Integer, String, MetaData
        from sqlalchemy.schema import CreateTable

        metadata = MetaData()
        table = Table(
            'vector_docs',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('embedding', Vectorf32(dimension=128)),
            Column('description', String(200)),
        )

        create_sql = str(CreateTable(table).compile(dialect=MatrixOneDialect()))

        self._assert_contains_all(
            create_sql,
            [
                "CREATE TABLE vector_docs",
                "id INTEGER NOT NULL",
                "embedding vecf32(128)",
                "description VARCHAR(200)",
                "PRIMARY KEY (id)",
            ],
        )

    def test_vector_search_query_sql_comparison(self):
        """Test search query SQL generation against expected SQL."""
        from sqlalchemy import Column, Integer, String, select

        Base = declarative_base()

        class VectorDoc(Base):
            __tablename__ = 'vector_docs'
            id = Column(Integer, primary_key=True)
            embedding = create_vector_column(10, "f32")
            description = Column(String(200))

        query = (
            select(VectorDoc)
            .where(VectorDoc.embedding.l2_distance([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]) < 0.5)
            .order_by(VectorDoc.embedding.l2_distance([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))
            .limit(10)
        )

        sql = str(query.compile(dialect=MatrixOneDialect()))

        # Bound values are rendered as %s placeholders by this dialect.
        self._assert_contains_all(
            sql,
            [
                "SELECT vector_docs.id",
                "vector_docs.embedding",
                "vector_docs.description",
                "FROM vector_docs",
                "WHERE l2_distance(vector_docs.embedding",
                "ORDER BY l2_distance(vector_docs.embedding",
                "LIMIT %s",
            ],
        )

    def test_vector_complex_query_sql_comparison(self):
        """Test complex vector query SQL generation against expected SQL."""
        from sqlalchemy import Column, Integer, String, select, and_

        Base = declarative_base()

        class VectorDoc(Base):
            __tablename__ = 'vector_docs'
            id = Column(Integer, primary_key=True)
            embedding = create_vector_column(5, "f32")
            description = Column(String(200))
            category = Column(String(50))

        query = (
            select(VectorDoc)
            .where(
                and_(
                    VectorDoc.embedding.l2_distance([1.0, 2.0, 3.0, 4.0, 5.0]) < 0.5,
                    VectorDoc.embedding.cosine_distance([1.0, 2.0, 3.0, 4.0, 5.0]) < 0.3,
                    VectorDoc.category == "science",
                )
            )
            .order_by(VectorDoc.embedding.l2_distance([1.0, 2.0, 3.0, 4.0, 5.0]))
            .limit(5)
        )

        sql = str(query.compile(dialect=MatrixOneDialect()))

        # Bound values are rendered as %s placeholders by this dialect.
        self._assert_contains_all(
            sql,
            [
                "SELECT vector_docs.id",
                "vector_docs.embedding",
                "vector_docs.description",
                "vector_docs.category",
                "FROM vector_docs",
                "WHERE l2_distance(vector_docs.embedding",
                "cosine_distance(vector_docs.embedding",
                "vector_docs.category = %s",
                "ORDER BY l2_distance(vector_docs.embedding",
                "LIMIT %s",
            ],
        )

    def test_vector_batch_insert_sql_comparison(self):
        """Test batch insert SQL generation against expected SQL."""
        from sqlalchemy import Column, Integer, String, insert

        Base = declarative_base()

        class VectorDoc(Base):
            __tablename__ = 'vector_docs'
            id = Column(Integer, primary_key=True)
            embedding = create_vector_column(3, "f32")
            description = Column(String(200))

        rows = [
            {"embedding": [1.0, 2.0, 3.0], "description": "Doc 1"},
            {"embedding": [4.0, 5.0, 6.0], "description": "Doc 2"},
            {"embedding": [7.0, 8.0, 9.0], "description": "Doc 3"},
        ]
        stmt = insert(VectorDoc).values(rows)

        sql = str(stmt.compile(dialect=MatrixOneDialect()))

        self._assert_contains_all(
            sql,
            [
                "INSERT INTO vector_docs",
                "embedding",
                "description",
                "VALUES",
            ],
        )
        # One "(%s, %s)" placeholder group per inserted row. Counting them
        # checks all three rows; a plain "in" test passes with only one.
        assert sql.count("(%s, %s)") == len(rows)
|
464
|
+
|
465
|
+
|
466
|
+
class TestVectorOperationsIntegrationOffline:
    """Workflow-level tests for vector operations driven entirely by mocks."""

    def test_complete_vector_workflow_mock(self):
        """Test complete vector workflow: create table, insert data, search."""
        with patch('sqlalchemy.create_engine') as mock_create_engine, patch(
            'sqlalchemy.orm.sessionmaker'
        ) as mock_sessionmaker, patch('sqlalchemy.orm.declarative_base') as mock_declarative_base:

            # Wire the patched factories to hand-made mocks
            mock_engine = Mock()
            mock_create_engine.return_value = mock_engine

            mock_session = Mock()
            mock_sessionmaker.return_value = mock_session

            mock_base = Mock()
            mock_declarative_base.return_value = mock_base
            mock_base.metadata = Mock()

            # 1. Create table
            mock_base.metadata.create_all(mock_engine)

            # 2. Insert data
            mock_data = Mock()
            mock_session.add(mock_data)
            mock_session.commit()

            # 3. Search data: every chained call returns the same query mock
            mock_query = Mock()
            mock_session.query.return_value = mock_query
            mock_query.filter.return_value = mock_query
            mock_query.order_by.return_value = mock_query
            mock_query.limit.return_value = mock_query
            mock_query.all.return_value = []

            results = mock_query.all()

            # Verify workflow
            assert results == []
            mock_base.metadata.create_all.assert_called_once_with(mock_engine)
            mock_session.add.assert_called_once_with(mock_data)
            mock_session.commit.assert_called_once()
            mock_query.all.assert_called_once()

    def test_vector_error_handling_mock(self):
        """Test that an error raised by ``session.add`` propagates to the caller."""
        mock_session = Mock()
        mock_session.add.side_effect = SQLAlchemyError("Mock database error")
        mock_data = Mock()

        # Only the raising call sits inside the raises block; any statement
        # after it in the block would never execute.
        with pytest.raises(SQLAlchemyError):
            mock_session.add(mock_data)

        # add() was attempted exactly once with the data; commit() never ran.
        mock_session.add.assert_called_once_with(mock_data)
        mock_session.commit.assert_not_called()
|
528
|
+
|
529
|
+
|
530
|
+
class TestUnifiedSQLBuilderOffline:
    """Offline checks of the SQL text and parameter lists built by the unified SQL builder."""

    def test_basic_sql_construction(self):
        """A two-column SELECT builds the exact statement with no parameters."""
        statement, bound = MatrixOneSQLBuilder().select('id', 'name').from_table('users').build()

        assert statement == "SELECT id, name FROM users"
        assert bound == []

    def test_select_with_where(self):
        """A WHERE condition keeps its placeholder and its value is returned separately."""
        query = MatrixOneSQLBuilder().select('*').from_table('users').where('age > ?', 18)
        statement, bound = query.build()

        assert "SELECT * FROM users WHERE age > ?" in statement
        assert bound == [18]

    def test_vector_similarity_query(self):
        """The vector similarity helper returns SQL with the WHERE value inlined."""
        options = {
            'table_name': 'documents',
            'vector_column': 'embedding',
            'query_vector': [0.1, 0.2, 0.3],
            'distance_func': DistanceFunction.L2_SQ,
            'limit': 10,
            'select_columns': ['id', 'title'],
            'where_conditions': ['category = ?'],
            'where_params': ['news'],
        }
        statement = build_vector_similarity_query(**options)

        for fragment in (
            "l2_distance_sq",
            "WHERE category = 'news'",
            "ORDER BY distance",
            "LIMIT 10",
        ):
            assert fragment in statement

    def test_cte_query(self):
        """A CTE definition is emitted and can be joined like a table."""
        cte_body = 'SELECT department_id, COUNT(*) as emp_count FROM employees GROUP BY department_id'

        query = MatrixOneSQLBuilder().with_cte('dept_stats', cte_body)
        query = query.select('d.name', 'ds.emp_count')
        query = query.from_table('departments d')
        query = query.join('dept_stats ds', 'd.id = ds.department_id', 'INNER')
        query = query.where('ds.emp_count > ?', 5)
        statement, bound = query.build()

        assert "WITH dept_stats AS" in statement
        assert "INNER JOIN dept_stats ds" in statement
        assert bound == [5]

    def test_insert_query(self):
        """INSERT lists the columns in dict order and returns the values as parameters."""
        row = {'name': 'John Doe', 'email': 'john@example.com', 'age': 30}
        statement, bound = build_insert_query(table_name="users", values=row)

        assert "INSERT INTO users" in statement
        assert "name, email, age" in statement
        assert bound == ['John Doe', 'john@example.com', 30]

    def test_update_query(self):
        """UPDATE returns the SET values followed by the WHERE values."""
        statement, bound = build_update_query(
            table_name="users",
            set_values={'age': 31, 'last_login': '2024-01-01'},
            where_conditions=['id = ?'],
            where_params=[123],
        )

        for fragment in ("UPDATE users SET", "WHERE id = ?"):
            assert fragment in statement
        assert bound == [31, '2024-01-01', 123]

    def test_delete_query(self):
        """DELETE joins several WHERE conditions with AND."""
        conditions = ['status = ?', 'last_login < ?']
        values = ['inactive', '2023-01-01']
        statement, bound = build_delete_query(
            table_name="users",
            where_conditions=conditions,
            where_params=values,
        )

        assert "DELETE FROM users" in statement
        assert "WHERE status = ? AND last_login < ?" in statement
        assert bound == ['inactive', '2023-01-01']

    def test_convenience_functions(self):
        """``build_select_query`` returns SQL with the WHERE values inlined."""
        statement = build_select_query(
            table_name="products",
            select_columns=["id", "name", "price"],
            where_conditions=["category = ?", "price < ?"],
            where_params=["electronics", 1000],
            order_by=["price"],
            limit=5,
        )

        for fragment in (
            "SELECT id, name, price FROM products",
            "WHERE category = 'electronics' AND price < 1000",
            "ORDER BY price",
            "LIMIT 5",
        ):
            assert fragment in statement

    def test_parameter_substitution(self):
        """Parameter substitution replaces every ``?`` with its value."""
        query = MatrixOneSQLBuilder().select('*').from_table('users')
        query = query.where('age > ?', 18).where('name = ?', "John")
        statement = query.build_with_parameter_substitution()

        assert "WHERE age > 18 AND name = 'John'" in statement
        # No placeholder may remain after substitution
        assert "?" not in statement

    def test_snapshot_syntax(self):
        """A snapshot name given to ``from_table`` is rendered after the table name."""
        query = MatrixOneSQLBuilder().select('*').from_table('users', 'user_snapshot')
        statement, bound = query.build()

        assert "FROM users{snapshot = 'user_snapshot'}" in statement
        assert bound == []

    def test_complex_query_construction(self):
        """Joins, filters, grouping, ordering and paging combine in one statement."""
        query = MatrixOneSQLBuilder().select(
            'u.id', 'u.name', 'd.name as dept_name', 'COUNT(p.id) as project_count'
        )
        query = query.from_table('users u')
        query = query.left_join('departments d', 'u.department_id = d.id')
        query = query.left_join('projects p', 'u.id = p.owner_id')
        query = query.where('u.status = ?', 'active')
        query = query.where('u.created_at > ?', '2023-01-01')
        query = query.group_by('u.id', 'u.name', 'd.name')
        query = query.having('COUNT(p.id) > ?', 0)
        query = query.order_by('project_count DESC', 'u.name')
        query = query.limit(10)
        query = query.offset(20)
        statement, bound = query.build()

        for fragment in (
            "SELECT u.id, u.name, d.name as dept_name, COUNT(p.id) as project_count",
            "FROM users u",
            "LEFT JOIN departments d",
            "LEFT JOIN projects p",
            "WHERE u.status = ? AND u.created_at > ?",
            "GROUP BY u.id, u.name, d.name",
            "HAVING COUNT(p.id) > ?",
            "ORDER BY project_count DESC, u.name",
            "LIMIT 10",
            "OFFSET 20",
        ):
            assert fragment in statement
        # WHERE values come first, then the HAVING value
        assert bound == ['active', '2023-01-01', 0]
|