matrixone-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrixone/__init__.py +155 -0
- matrixone/account.py +723 -0
- matrixone/async_client.py +3913 -0
- matrixone/async_metadata_manager.py +311 -0
- matrixone/async_orm.py +123 -0
- matrixone/async_vector_index_manager.py +633 -0
- matrixone/base_client.py +208 -0
- matrixone/client.py +4672 -0
- matrixone/config.py +452 -0
- matrixone/connection_hooks.py +286 -0
- matrixone/exceptions.py +89 -0
- matrixone/logger.py +782 -0
- matrixone/metadata.py +820 -0
- matrixone/moctl.py +219 -0
- matrixone/orm.py +2277 -0
- matrixone/pitr.py +646 -0
- matrixone/pubsub.py +771 -0
- matrixone/restore.py +411 -0
- matrixone/search_vector_index.py +1176 -0
- matrixone/snapshot.py +550 -0
- matrixone/sql_builder.py +844 -0
- matrixone/sqlalchemy_ext/__init__.py +161 -0
- matrixone/sqlalchemy_ext/adapters.py +163 -0
- matrixone/sqlalchemy_ext/dialect.py +534 -0
- matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
- matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
- matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
- matrixone/sqlalchemy_ext/ivf_config.py +252 -0
- matrixone/sqlalchemy_ext/table_builder.py +351 -0
- matrixone/sqlalchemy_ext/vector_index.py +1721 -0
- matrixone/sqlalchemy_ext/vector_type.py +948 -0
- matrixone/version.py +580 -0
- matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
- matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
- matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
- matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
- matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +19 -0
- tests/offline/__init__.py +20 -0
- tests/offline/conftest.py +77 -0
- tests/offline/test_account.py +703 -0
- tests/offline/test_async_client_query_comprehensive.py +1218 -0
- tests/offline/test_basic.py +54 -0
- tests/offline/test_case_sensitivity.py +227 -0
- tests/offline/test_connection_hooks_offline.py +287 -0
- tests/offline/test_dialect_schema_handling.py +609 -0
- tests/offline/test_explain_methods.py +346 -0
- tests/offline/test_filter_logical_in.py +237 -0
- tests/offline/test_fulltext_search_comprehensive.py +795 -0
- tests/offline/test_ivf_config.py +249 -0
- tests/offline/test_join_methods.py +281 -0
- tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
- tests/offline/test_logical_in_method.py +237 -0
- tests/offline/test_matrixone_version_parsing.py +264 -0
- tests/offline/test_metadata_offline.py +557 -0
- tests/offline/test_moctl.py +300 -0
- tests/offline/test_moctl_simple.py +251 -0
- tests/offline/test_model_support_offline.py +359 -0
- tests/offline/test_model_support_simple.py +225 -0
- tests/offline/test_pinecone_filter_offline.py +377 -0
- tests/offline/test_pitr.py +585 -0
- tests/offline/test_pubsub.py +712 -0
- tests/offline/test_query_update.py +283 -0
- tests/offline/test_restore.py +445 -0
- tests/offline/test_snapshot_comprehensive.py +384 -0
- tests/offline/test_sql_escaping_edge_cases.py +551 -0
- tests/offline/test_sqlalchemy_integration.py +382 -0
- tests/offline/test_sqlalchemy_vector_integration.py +434 -0
- tests/offline/test_table_builder.py +198 -0
- tests/offline/test_unified_filter.py +398 -0
- tests/offline/test_unified_transaction.py +495 -0
- tests/offline/test_vector_index.py +238 -0
- tests/offline/test_vector_operations.py +688 -0
- tests/offline/test_vector_type.py +174 -0
- tests/offline/test_version_core.py +328 -0
- tests/offline/test_version_management.py +372 -0
- tests/offline/test_version_standalone.py +652 -0
- tests/online/__init__.py +20 -0
- tests/online/conftest.py +216 -0
- tests/online/test_account_management.py +194 -0
- tests/online/test_advanced_features.py +344 -0
- tests/online/test_async_client_interfaces.py +330 -0
- tests/online/test_async_client_online.py +285 -0
- tests/online/test_async_model_insert_online.py +293 -0
- tests/online/test_async_orm_online.py +300 -0
- tests/online/test_async_simple_query_online.py +802 -0
- tests/online/test_async_transaction_simple_query.py +300 -0
- tests/online/test_basic_connection.py +130 -0
- tests/online/test_client_online.py +238 -0
- tests/online/test_config.py +90 -0
- tests/online/test_config_validation.py +123 -0
- tests/online/test_connection_hooks_new_online.py +217 -0
- tests/online/test_dialect_schema_handling_online.py +331 -0
- tests/online/test_filter_logical_in_online.py +374 -0
- tests/online/test_fulltext_comprehensive.py +1773 -0
- tests/online/test_fulltext_label_online.py +433 -0
- tests/online/test_fulltext_search_online.py +842 -0
- tests/online/test_ivf_stats_online.py +506 -0
- tests/online/test_logger_integration.py +311 -0
- tests/online/test_matrixone_query_orm.py +540 -0
- tests/online/test_metadata_online.py +579 -0
- tests/online/test_model_insert_online.py +255 -0
- tests/online/test_mysql_driver_validation.py +213 -0
- tests/online/test_orm_advanced_features.py +2022 -0
- tests/online/test_orm_cte_integration.py +269 -0
- tests/online/test_orm_online.py +270 -0
- tests/online/test_pinecone_filter.py +708 -0
- tests/online/test_pubsub_operations.py +352 -0
- tests/online/test_query_methods.py +225 -0
- tests/online/test_query_update_online.py +433 -0
- tests/online/test_search_vector_index.py +557 -0
- tests/online/test_simple_fulltext_online.py +915 -0
- tests/online/test_snapshot_comprehensive.py +998 -0
- tests/online/test_sqlalchemy_engine_integration.py +336 -0
- tests/online/test_sqlalchemy_integration.py +425 -0
- tests/online/test_transaction_contexts.py +1219 -0
- tests/online/test_transaction_insert_methods.py +356 -0
- tests/online/test_transaction_query_methods.py +288 -0
- tests/online/test_unified_filter_online.py +529 -0
- tests/online/test_vector_comprehensive.py +706 -0
- tests/online/test_version_management.py +291 -0
@@ -0,0 +1,434 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Offline tests for SQLAlchemy integration with MatrixOne vector types.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import pytest
|
20
|
+
import sys
|
21
|
+
from unittest.mock import Mock
|
22
|
+
from sqlalchemy import MetaData, Table, Column, Integer, String, Index
|
23
|
+
from sqlalchemy.schema import CreateTable
|
24
|
+
from matrixone.sqlalchemy_ext import (
|
25
|
+
VectorType,
|
26
|
+
Vectorf32,
|
27
|
+
Vectorf64,
|
28
|
+
VectorTypeDecorator,
|
29
|
+
MatrixOneDialect,
|
30
|
+
VectorTableBuilder,
|
31
|
+
)
|
32
|
+
from sqlalchemy.orm import declarative_base
|
33
|
+
|
34
|
+
# Module-level pytest marker: applies the `vector` mark to every test in this module.
pytestmark = pytest.mark.vector
|
35
|
+
|
36
|
+
# No longer needed - global mocks have been fixed
|
37
|
+
|
38
|
+
|
39
|
+
class TestSQLAlchemyVectorIntegration:
    """Test SQLAlchemy vector integration functionality.

    The tests build schema objects in memory and compile DDL with
    ``MatrixOneDialect``; none of them creates an engine or executes SQL.
    """

    @staticmethod
    def _create_table_sql(table):
        """Return the CREATE TABLE statement for *table* compiled with MatrixOneDialect."""
        compiled = CreateTable(table).compile(
            dialect=MatrixOneDialect(),
            compile_kwargs={"literal_binds": True},
        )
        return str(compiled)

    def test_vector_type_in_sqlalchemy_table(self):
        """Test using VectorType in SQLAlchemy Table definition."""
        metadata = MetaData()

        # Create table with vector columns
        table = Table(
            'vector_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('name', String(100)),
            Column('embedding_32', Vectorf32(dimension=128)),
            Column('embedding_64', Vectorf64(dimension=256)),
        )

        # Verify table structure
        assert table.name == 'vector_test'
        assert len(table.columns) == 4

        # Check column types. Classes are compared by identity so that only
        # the exact type (not a subclass) satisfies each check.
        column_types = {col.name: type(col.type) for col in table.columns}
        assert column_types['id'] is Integer
        assert column_types['name'] is String
        assert column_types['embedding_32'] is Vectorf32
        assert column_types['embedding_64'] is Vectorf64

    def test_sql_generation_with_vector_types(self):
        """Test SQL generation with vector types."""
        metadata = MetaData()

        table = Table(
            'sql_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('vector', Vectorf32(dimension=128)),
        )

        # Generate CREATE TABLE SQL using MatrixOne dialect
        create_sql = self._create_table_sql(table)

        # Verify SQL contains vector type
        assert "CREATE TABLE sql_test" in create_sql
        assert "vector vecf32(128)" in create_sql
        assert "PRIMARY KEY (id)" in create_sql

    def test_vector_type_decorator_in_table(self):
        """Test VectorTypeDecorator in table definition."""
        metadata = MetaData()

        table = Table(
            'decorator_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('vector', VectorTypeDecorator(dimension=64, precision="f32")),
        )

        # Verify table structure
        assert len(table.columns) == 2
        vector_col = table.columns['vector']
        assert isinstance(vector_col.type, VectorTypeDecorator)
        assert vector_col.type.dimension == 64
        assert vector_col.type.precision == "f32"

    def test_matrixone_dialect_type_handling(self):
        """Test MatrixOneDialect type handling."""
        dialect = MatrixOneDialect()

        # Test vector type creation from strings: (type string, precision, dimension).
        # Both lower- and upper-case spellings of the type name are covered.
        test_cases = [
            ("vecf32(128)", "f32", 128),
            ("VECF32(256)", "f32", 256),
            ("vecf64(512)", "f64", 512),
            ("VECF64(1024)", "f64", 1024),
        ]

        for type_str, precision, expected_dim in test_cases:
            vector_type = dialect._create_vector_type(precision, type_str)
            assert vector_type.precision == precision
            assert vector_type.dimension == expected_dim

    def test_vector_table_builder_integration(self):
        """Test VectorTableBuilder integration with SQLAlchemy."""
        metadata = MetaData()

        # Create table using builder
        builder = VectorTableBuilder("builder_test", metadata)
        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_vecf32_column("embedding", dimension=128)
        builder.add_index("name")

        table = builder.build()

        # Verify table structure
        assert table.name == "builder_test"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

        # Check column types. The column is looked up by name so that a
        # missing column fails the test instead of silently skipping the
        # assertion, as a name-filtered loop over the columns would.
        assert str(table.columns["embedding"].type) == "vecf32(128)"

    def test_vector_type_serialization(self):
        """Test vector type data serialization."""
        # Test Vectorf32 serialization
        vec32 = Vectorf32(dimension=128)
        bind_processor = vec32.bind_processor(None)

        # Test list to string conversion
        vector_list = [1.0, 2.0, 3.0]
        result = bind_processor(vector_list)
        assert result == "[1.0,2.0,3.0]"

        # Test string passthrough
        vector_string = "[4.0,5.0,6.0]"
        result = bind_processor(vector_string)
        assert result == vector_string

    def test_vector_type_deserialization(self):
        """Test vector type data deserialization."""
        # Test Vectorf64 deserialization
        vec64 = Vectorf64(dimension=256)
        result_processor = vec64.result_processor(None, None)

        # Test string to list conversion
        vector_string = "[1.0,2.0,3.0]"
        result = result_processor(vector_string)
        assert result == [1.0, 2.0, 3.0]

        # Test None handling
        result = result_processor(None)
        assert result is None

    def test_complex_table_with_multiple_vector_columns(self):
        """Test complex table with multiple vector columns."""
        metadata = MetaData()

        table = Table(
            'complex_vector_table',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('title', String(200)),
            Column('content_embedding', Vectorf32(dimension=384)),
            Column('title_embedding', Vectorf32(dimension=128)),
            Column('metadata_embedding', Vectorf64(dimension=512)),
        )

        # Verify table structure
        assert table.name == 'complex_vector_table'
        assert len(table.columns) == 5

        # Check vector columns
        vector_type_classes = (VectorType, VectorTypeDecorator, Vectorf32, Vectorf64)
        vector_columns = [col for col in table.columns if isinstance(col.type, vector_type_classes)]
        assert len(vector_columns) == 3

        # Verify dimensions
        dimensions = {col.name: col.type.dimension for col in vector_columns}
        assert dimensions['content_embedding'] == 384
        assert dimensions['title_embedding'] == 128
        assert dimensions['metadata_embedding'] == 512

    def test_table_with_indexes_and_vector_columns(self):
        """Test table with indexes and vector columns."""
        metadata = MetaData()

        table = Table(
            'indexed_vector_table',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('category', String(50)),
            Column('embedding', Vectorf32(dimension=128)),
            Column('score', Integer),
        )

        # Add indexes. The Index objects are not kept: the assertion below
        # relies on them showing up in table.indexes after construction.
        Index('idx_category', table.c.category)
        Index('idx_score', table.c.score)

        # Verify table structure
        assert len(table.indexes) == 2

        # Generate SQL to verify structure using MatrixOne dialect
        create_sql = self._create_table_sql(table)

        # Check SQL contains expected elements
        assert "CREATE TABLE indexed_vector_table" in create_sql
        assert "embedding vecf32(128)" in create_sql
        # Note: Index creation is separate from table creation in SQLAlchemy
        # The indexes are defined but not included in the CREATE TABLE statement

    def test_vector_type_comparison(self):
        """Test vector type comparison."""
        vec1 = Vectorf32(dimension=128)
        vec2 = Vectorf32(dimension=128)
        vec3 = Vectorf32(dimension=256)
        vec4 = Vectorf64(dimension=128)

        # Test equality by comparing properties
        assert vec1.dimension == vec2.dimension
        assert vec1.precision == vec2.precision

        # Different dimension should not be equal
        assert vec1.dimension != vec3.dimension

        # Different precision should not be equal
        assert vec1.precision != vec4.precision

    def test_vector_type_edge_cases(self):
        """Test edge cases for vector types."""
        # Test very large dimensions
        large_vec = Vectorf32(dimension=65535)
        assert large_vec.dimension == 65535

        # Test dimension 1
        small_vec = Vectorf32(dimension=1)
        assert small_vec.dimension == 1

        # Test without dimension
        no_dim_vec = VectorType(precision="f32")
        assert no_dim_vec.dimension is None
        assert no_dim_vec.get_col_spec() == "vecf32"
|
269
|
+
|
270
|
+
|
271
|
+
class TestVectorDistanceFunctionBugFixes:
    """Test cases to catch vector distance function bugs found during validation.

    Most tests declare a throwaway declarative model, build a distance
    expression on its vector column and inspect the SQL string obtained by
    compiling that expression with ``literal_binds``.
    """

    def test_vector_distance_with_decimal_preservation(self):
        """
        Test that vector distance expressions preserve decimal points in filter conditions.

        Bug: orm.py's table prefix regex was removing decimal points from vectors.
        Example: [0.374, 0.950] became [0374, 0950]
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class TestDoc(Base):
            __tablename__ = 'test_doc'
            id = Column(Integer, primary_key=True)
            title = Column(String(200))
            embedding = create_vector_column(8, precision='f32')

        # Create a filter condition with vector distance
        query_vector = [0.374, 0.950, 0.731, 0.598, 0.156, 0.155, 0.058, 0.866]

        # Compile the condition
        condition = TestDoc.embedding.l2_distance(query_vector) < 0.5
        compiled = condition.compile(compile_kwargs={"literal_binds": True})
        sql_str = str(compiled)

        # Verify decimals are preserved. A single substring check suffices:
        # any longer rendering such as "0.3745" already contains "0.374".
        assert "0.374" in sql_str, f"Decimal points should be preserved in vector: {sql_str}"
        assert "0374" not in sql_str, f"Decimal points should not be removed: {sql_str}"

        # Verify the vector is properly quoted
        assert "[" in sql_str and "]" in sql_str, f"Vector array brackets should be present: {sql_str}"

    def test_sqlalchemy_reserved_field_names(self):
        """
        Test that SQLAlchemy reserved field names are avoided.

        Bug: Using 'metadata' as column name causes:
        "Attribute name 'metadata' is reserved when using the Declarative API"
        """
        from sqlalchemy.exc import InvalidRequestError

        Base = declarative_base()

        # This should NOT raise an error
        class GoodDoc(Base):
            __tablename__ = 'good_doc'
            id = Column(Integer, primary_key=True)
            doc_metadata = Column(String(500))  # Use doc_metadata instead

        assert hasattr(GoodDoc, 'doc_metadata')

        # This SHOULD raise an error (reserved name); the error is expected
        # while the class body is being turned into a mapped class.
        with pytest.raises(InvalidRequestError, match="reserved"):

            class BadDoc(Base):
                __tablename__ = 'bad_doc'
                id = Column(Integer, primary_key=True)
                metadata = Column(String(500))  # Reserved name!

    def test_vector_distance_in_filter_with_literal_binds(self):
        """
        Test that vector distance in filter() preserves vector format with literal_binds.

        Bug: When using literal_binds=True, vector parameters lose proper formatting.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class TestVec(Base):
            __tablename__ = 'test_vec'
            id = Column(Integer, primary_key=True)
            name = Column(String(100))
            embedding = create_vector_column(4, precision='f32')

        # Test vector with decimal values
        query_vector = [0.1, 0.2, 0.3, 0.4]

        # Create filter expression
        filter_expr = TestVec.embedding.l2_distance(query_vector) < 1.0

        # Compile with literal_binds (this is what orm.py does)
        compiled = filter_expr.compile(compile_kwargs={"literal_binds": True})
        sql_str = str(compiled)

        # Verify vector format is correct. "[0.1" also matches "[0.1," and
        # "< 1" also matches "< 1.0", so one substring per check is enough.
        assert "[0.1" in sql_str, f"Vector should contain decimal values: {sql_str}"
        assert "l2_distance" in sql_str, f"Function name should be present: {sql_str}"
        assert "< 1" in sql_str, f"Comparison should be present: {sql_str}"

    def test_multiple_vector_distance_calls_in_query(self):
        """
        Test that multiple vector distance calls in same query work correctly.

        Bug: When vector distance is used in both SELECT and WHERE,
        parameters might be duplicated or incorrectly formatted.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class MultiVec(Base):
            __tablename__ = 'multi_vec'
            id = Column(Integer, primary_key=True)
            embedding = create_vector_column(4, precision='f32')

        query_vector = [0.5, 0.6, 0.7, 0.8]

        # Use distance in both select and filter. l2_distance() is called
        # twice on purpose: the scenario under test is two separate calls
        # with the same vector.
        dist_expr = MultiVec.embedding.l2_distance(query_vector)
        filter_expr = MultiVec.embedding.l2_distance(query_vector) < 2.0

        # Compile both expressions
        select_compiled = dist_expr.compile(compile_kwargs={"literal_binds": True})
        filter_compiled = filter_expr.compile(compile_kwargs={"literal_binds": True})

        select_sql = str(select_compiled)
        filter_sql = str(filter_compiled)

        # Both should have valid vector format
        for sql_str in (select_sql, filter_sql):
            assert "[0.5" in sql_str, f"Vector should contain decimal values: {sql_str}"
            assert "l2_distance" in sql_str, f"Function should be present: {sql_str}"

    def test_vector_distance_with_different_types(self):
        """
        Test vector distance functions with different distance types.

        Ensures all distance functions (l2, cosine, inner_product) properly
        handle vector parameters.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class DistTest(Base):
            __tablename__ = 'dist_test'
            id = Column(Integer, primary_key=True)
            vec = create_vector_column(3, precision='f32')

        query_vec = [0.1, 0.2, 0.3]

        # Test all distance functions
        distances = {
            'l2': DistTest.vec.l2_distance(query_vec),
            'cosine': DistTest.vec.cosine_distance(query_vec),
            'inner': DistTest.vec.inner_product(query_vec),
        }

        for dist_name, dist_expr in distances.items():
            compiled = dist_expr.compile(compile_kwargs={"literal_binds": True})
            sql_str = str(compiled)

            # Verify vector format
            assert "[0.1" in sql_str, f"{dist_name}: Vector should contain decimals: {sql_str}"
            # Check every remaining element rather than only the second one,
            # so the assertion actually covers what its message states.
            for element in query_vec[1:]:
                assert str(element) in sql_str, f"{dist_name}: All vector elements should be present: {sql_str}"
|
@@ -0,0 +1,198 @@
|
|
1
|
+
# Copyright 2021 - 2022 Matrix Origin
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Offline tests for VectorTableBuilder functionality.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import pytest
|
20
|
+
import sys
|
21
|
+
from unittest.mock import Mock
|
22
|
+
from sqlalchemy import MetaData
|
23
|
+
from sqlalchemy.schema import CreateTable
|
24
|
+
from matrixone.sqlalchemy_ext import (
|
25
|
+
VectorTableBuilder,
|
26
|
+
create_vector_table,
|
27
|
+
create_vector_index_table,
|
28
|
+
Vectorf32,
|
29
|
+
Vectorf64,
|
30
|
+
)
|
31
|
+
|
32
|
+
# Module-level pytest marker: applies the `vector` mark to every test in this module.
pytestmark = pytest.mark.vector
|
33
|
+
|
34
|
+
# No longer needed - global mocks have been fixed
|
35
|
+
|
36
|
+
|
37
|
+
class TestVectorTableBuilder:
    """Test VectorTableBuilder functionality.

    Column checks look columns up by name (``table.columns["x"]``) instead of
    scanning for a matching name, so a column the builder failed to create
    makes the test fail rather than skipping its assertions.
    """

    def test_create_vector_index_table(self):
        """Test creating the exact table from the example."""
        metadata = MetaData()
        builder = create_vector_index_table("vector_index_07", metadata)
        table = builder.build()

        assert table.name == "vector_index_07"
        assert len(table.columns) == 3

        # Check columns
        assert {col.name for col in table.columns} == {"a", "b", "c"}

        # Check column types
        columns = table.columns
        assert columns["a"].primary_key
        assert str(columns["a"].type) == "INTEGER"
        assert str(columns["b"].type) == "vecf32(128)"
        assert str(columns["c"].type) == "INTEGER"

        # Check indexes
        assert len(table.indexes) == 1
        index = next(iter(table.indexes))
        assert index.name == "c_k"
        assert [col.name for col in index.columns] == ["c"]

    def test_manual_table_builder(self):
        """Test manual table building."""
        metadata = MetaData()
        builder = VectorTableBuilder("test_table", metadata)

        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_vecf32_column("embedding", dimension=128)
        builder.add_index("name")

        table = builder.build()

        assert table.name == "test_table"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

        # Check column types
        columns = table.columns
        assert columns["id"].primary_key
        assert str(columns["id"].type) == "INTEGER"
        assert str(columns["name"].type) == "VARCHAR(100)"
        assert str(columns["embedding"].type) == "vecf32(128)"

    def test_vector_column_types(self):
        """Test different vector column types."""
        metadata = MetaData()
        builder = VectorTableBuilder("vector_types_test", metadata)

        builder.add_vecf32_column("vec32", dimension=128)
        builder.add_vecf64_column("vec64", dimension=256)
        builder.add_int_column("id", primary_key=True)

        table = builder.build()

        # Check vector column types
        assert str(table.columns["vec32"].type) == "vecf32(128)"
        assert str(table.columns["vec64"].type) == "vecf64(256)"

    def test_multiple_indexes(self):
        """Test table with multiple indexes."""
        metadata = MetaData()
        builder = VectorTableBuilder("multi_index_table", metadata)

        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_string_column("category", length=50)
        builder.add_int_column("version")
        builder.add_vecf32_column("embedding", dimension=128)

        # Add multiple indexes: two single-column and one composite
        builder.add_index("name")
        builder.add_index("category")
        builder.add_index(["category", "version"])

        table = builder.build()

        assert len(table.indexes) == 3

        # Only substrings of the generated index names are checked; the exact
        # naming scheme is left to the builder.
        index_names = [idx.name for idx in table.indexes]
        assert any("name" in name for name in index_names)
        assert any("category" in name for name in index_names)
        assert any("category" in name and "version" in name for name in index_names)

    def test_sql_generation(self):
        """Test SQL generation for vector tables."""
        metadata = MetaData()
        builder = create_vector_index_table("sql_test_table", metadata)
        table = builder.build()

        # Generate CREATE TABLE SQL using MatrixOne dialect.
        # Imported here because the module-level import list does not
        # include MatrixOneDialect.
        from matrixone.sqlalchemy_ext import MatrixOneDialect

        compiled = CreateTable(table).compile(
            dialect=MatrixOneDialect(),
            compile_kwargs={"literal_binds": True},
        )
        create_sql = str(compiled)

        # Check that SQL contains expected elements
        assert "CREATE TABLE sql_test_table" in create_sql
        assert "a INTEGER NOT NULL" in create_sql
        assert "b vecf32(128)" in create_sql
        assert "c INTEGER" in create_sql
        assert "PRIMARY KEY (a)" in create_sql

    def test_empty_table_builder(self):
        """Test creating an empty table builder."""
        metadata = MetaData()
        builder = VectorTableBuilder("empty_table", metadata)
        table = builder.build()

        assert table.name == "empty_table"
        assert len(table.columns) == 0
        assert len(table.indexes) == 0

    def test_table_builder_fluent_interface(self):
        """Test the fluent interface of table builder."""
        metadata = MetaData()

        # Chain method calls
        table = (
            VectorTableBuilder("fluent_table", metadata)
            .add_int_column("id", primary_key=True)
            .add_string_column("name", length=100)
            .add_vecf32_column("embedding", dimension=128)
            .add_index("name")
            .build()
        )

        assert table.name == "fluent_table"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

    def test_create_vector_table_convenience(self):
        """Test create_vector_table convenience function."""
        metadata = MetaData()
        builder = create_vector_table("convenience_table", metadata)

        # Should return a builder, not a table
        assert isinstance(builder, VectorTableBuilder)
        assert builder.table_name == "convenience_table"

        # Build the table
        table = builder.add_int_column("id", primary_key=True).build()
        assert table.name == "convenience_table"
|