matrixone_python_sdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. matrixone/__init__.py +155 -0
  2. matrixone/account.py +723 -0
  3. matrixone/async_client.py +3913 -0
  4. matrixone/async_metadata_manager.py +311 -0
  5. matrixone/async_orm.py +123 -0
  6. matrixone/async_vector_index_manager.py +633 -0
  7. matrixone/base_client.py +208 -0
  8. matrixone/client.py +4672 -0
  9. matrixone/config.py +452 -0
  10. matrixone/connection_hooks.py +286 -0
  11. matrixone/exceptions.py +89 -0
  12. matrixone/logger.py +782 -0
  13. matrixone/metadata.py +820 -0
  14. matrixone/moctl.py +219 -0
  15. matrixone/orm.py +2277 -0
  16. matrixone/pitr.py +646 -0
  17. matrixone/pubsub.py +771 -0
  18. matrixone/restore.py +411 -0
  19. matrixone/search_vector_index.py +1176 -0
  20. matrixone/snapshot.py +550 -0
  21. matrixone/sql_builder.py +844 -0
  22. matrixone/sqlalchemy_ext/__init__.py +161 -0
  23. matrixone/sqlalchemy_ext/adapters.py +163 -0
  24. matrixone/sqlalchemy_ext/dialect.py +534 -0
  25. matrixone/sqlalchemy_ext/fulltext_index.py +895 -0
  26. matrixone/sqlalchemy_ext/fulltext_search.py +1686 -0
  27. matrixone/sqlalchemy_ext/hnsw_config.py +194 -0
  28. matrixone/sqlalchemy_ext/ivf_config.py +252 -0
  29. matrixone/sqlalchemy_ext/table_builder.py +351 -0
  30. matrixone/sqlalchemy_ext/vector_index.py +1721 -0
  31. matrixone/sqlalchemy_ext/vector_type.py +948 -0
  32. matrixone/version.py +580 -0
  33. matrixone_python_sdk-0.1.0.dist-info/METADATA +706 -0
  34. matrixone_python_sdk-0.1.0.dist-info/RECORD +122 -0
  35. matrixone_python_sdk-0.1.0.dist-info/WHEEL +5 -0
  36. matrixone_python_sdk-0.1.0.dist-info/entry_points.txt +5 -0
  37. matrixone_python_sdk-0.1.0.dist-info/licenses/LICENSE +200 -0
  38. matrixone_python_sdk-0.1.0.dist-info/top_level.txt +2 -0
  39. tests/__init__.py +19 -0
  40. tests/offline/__init__.py +20 -0
  41. tests/offline/conftest.py +77 -0
  42. tests/offline/test_account.py +703 -0
  43. tests/offline/test_async_client_query_comprehensive.py +1218 -0
  44. tests/offline/test_basic.py +54 -0
  45. tests/offline/test_case_sensitivity.py +227 -0
  46. tests/offline/test_connection_hooks_offline.py +287 -0
  47. tests/offline/test_dialect_schema_handling.py +609 -0
  48. tests/offline/test_explain_methods.py +346 -0
  49. tests/offline/test_filter_logical_in.py +237 -0
  50. tests/offline/test_fulltext_search_comprehensive.py +795 -0
  51. tests/offline/test_ivf_config.py +249 -0
  52. tests/offline/test_join_methods.py +281 -0
  53. tests/offline/test_join_sqlalchemy_compatibility.py +276 -0
  54. tests/offline/test_logical_in_method.py +237 -0
  55. tests/offline/test_matrixone_version_parsing.py +264 -0
  56. tests/offline/test_metadata_offline.py +557 -0
  57. tests/offline/test_moctl.py +300 -0
  58. tests/offline/test_moctl_simple.py +251 -0
  59. tests/offline/test_model_support_offline.py +359 -0
  60. tests/offline/test_model_support_simple.py +225 -0
  61. tests/offline/test_pinecone_filter_offline.py +377 -0
  62. tests/offline/test_pitr.py +585 -0
  63. tests/offline/test_pubsub.py +712 -0
  64. tests/offline/test_query_update.py +283 -0
  65. tests/offline/test_restore.py +445 -0
  66. tests/offline/test_snapshot_comprehensive.py +384 -0
  67. tests/offline/test_sql_escaping_edge_cases.py +551 -0
  68. tests/offline/test_sqlalchemy_integration.py +382 -0
  69. tests/offline/test_sqlalchemy_vector_integration.py +434 -0
  70. tests/offline/test_table_builder.py +198 -0
  71. tests/offline/test_unified_filter.py +398 -0
  72. tests/offline/test_unified_transaction.py +495 -0
  73. tests/offline/test_vector_index.py +238 -0
  74. tests/offline/test_vector_operations.py +688 -0
  75. tests/offline/test_vector_type.py +174 -0
  76. tests/offline/test_version_core.py +328 -0
  77. tests/offline/test_version_management.py +372 -0
  78. tests/offline/test_version_standalone.py +652 -0
  79. tests/online/__init__.py +20 -0
  80. tests/online/conftest.py +216 -0
  81. tests/online/test_account_management.py +194 -0
  82. tests/online/test_advanced_features.py +344 -0
  83. tests/online/test_async_client_interfaces.py +330 -0
  84. tests/online/test_async_client_online.py +285 -0
  85. tests/online/test_async_model_insert_online.py +293 -0
  86. tests/online/test_async_orm_online.py +300 -0
  87. tests/online/test_async_simple_query_online.py +802 -0
  88. tests/online/test_async_transaction_simple_query.py +300 -0
  89. tests/online/test_basic_connection.py +130 -0
  90. tests/online/test_client_online.py +238 -0
  91. tests/online/test_config.py +90 -0
  92. tests/online/test_config_validation.py +123 -0
  93. tests/online/test_connection_hooks_new_online.py +217 -0
  94. tests/online/test_dialect_schema_handling_online.py +331 -0
  95. tests/online/test_filter_logical_in_online.py +374 -0
  96. tests/online/test_fulltext_comprehensive.py +1773 -0
  97. tests/online/test_fulltext_label_online.py +433 -0
  98. tests/online/test_fulltext_search_online.py +842 -0
  99. tests/online/test_ivf_stats_online.py +506 -0
  100. tests/online/test_logger_integration.py +311 -0
  101. tests/online/test_matrixone_query_orm.py +540 -0
  102. tests/online/test_metadata_online.py +579 -0
  103. tests/online/test_model_insert_online.py +255 -0
  104. tests/online/test_mysql_driver_validation.py +213 -0
  105. tests/online/test_orm_advanced_features.py +2022 -0
  106. tests/online/test_orm_cte_integration.py +269 -0
  107. tests/online/test_orm_online.py +270 -0
  108. tests/online/test_pinecone_filter.py +708 -0
  109. tests/online/test_pubsub_operations.py +352 -0
  110. tests/online/test_query_methods.py +225 -0
  111. tests/online/test_query_update_online.py +433 -0
  112. tests/online/test_search_vector_index.py +557 -0
  113. tests/online/test_simple_fulltext_online.py +915 -0
  114. tests/online/test_snapshot_comprehensive.py +998 -0
  115. tests/online/test_sqlalchemy_engine_integration.py +336 -0
  116. tests/online/test_sqlalchemy_integration.py +425 -0
  117. tests/online/test_transaction_contexts.py +1219 -0
  118. tests/online/test_transaction_insert_methods.py +356 -0
  119. tests/online/test_transaction_query_methods.py +288 -0
  120. tests/online/test_unified_filter_online.py +529 -0
  121. tests/online/test_vector_comprehensive.py +706 -0
  122. tests/online/test_version_management.py +291 -0
@@ -0,0 +1,434 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Offline tests for SQLAlchemy integration with MatrixOne vector types.
17
+ """
18
+
19
+ import pytest
20
+ import sys
21
+ from unittest.mock import Mock
22
+ from sqlalchemy import MetaData, Table, Column, Integer, String, Index
23
+ from sqlalchemy.schema import CreateTable
24
+ from matrixone.sqlalchemy_ext import (
25
+ VectorType,
26
+ Vectorf32,
27
+ Vectorf64,
28
+ VectorTypeDecorator,
29
+ MatrixOneDialect,
30
+ VectorTableBuilder,
31
+ )
32
+ from sqlalchemy.orm import declarative_base
33
+
34
# Apply the ``vector`` marker to every test in this module, so the suite can
# be selected or skipped with ``pytest -m vector`` / ``pytest -m "not vector"``.
pytestmark = pytest.mark.vector
37
+
38
+
39
class TestSQLAlchemyVectorIntegration:
    """Test SQLAlchemy vector integration functionality.

    Covers MatrixOne vector column types on plain SQLAlchemy ``Table``
    objects, ``CREATE TABLE`` compilation with ``MatrixOneDialect``,
    ``VectorTableBuilder`` output, and the vector types' bind/result
    processors.
    """

    @staticmethod
    def _create_table_sql(table):
        """Return the ``CREATE TABLE`` DDL for *table* compiled with ``MatrixOneDialect``.

        Relies on the module-level ``MatrixOneDialect`` import, so individual
        tests do not need to re-import it.
        """
        return str(CreateTable(table).compile(dialect=MatrixOneDialect(), compile_kwargs={"literal_binds": True}))

    def test_vector_type_in_sqlalchemy_table(self):
        """Test using VectorType in SQLAlchemy Table definition."""
        metadata = MetaData()

        # Create table with vector columns
        table = Table(
            'vector_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('name', String(100)),
            Column('embedding_32', Vectorf32(dimension=128)),
            Column('embedding_64', Vectorf64(dimension=256)),
        )

        # Verify table structure
        assert table.name == 'vector_test'
        assert len(table.columns) == 4

        # Check the exact class of each column type
        column_types = {col.name: type(col.type) for col in table.columns}
        assert column_types['id'] == Integer
        assert column_types['name'] == String
        assert column_types['embedding_32'] == Vectorf32
        assert column_types['embedding_64'] == Vectorf64

    def test_sql_generation_with_vector_types(self):
        """Test SQL generation with vector types."""
        metadata = MetaData()

        table = Table(
            'sql_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('vector', Vectorf32(dimension=128)),
        )

        create_sql = self._create_table_sql(table)

        # Verify SQL contains vector type
        assert "CREATE TABLE sql_test" in create_sql
        assert "vector vecf32(128)" in create_sql
        assert "PRIMARY KEY (id)" in create_sql

    def test_vector_type_decorator_in_table(self):
        """Test VectorTypeDecorator in table definition."""
        metadata = MetaData()

        table = Table(
            'decorator_test',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('vector', VectorTypeDecorator(dimension=64, precision="f32")),
        )

        # Verify table structure
        assert len(table.columns) == 2
        vector_col = table.columns['vector']
        assert isinstance(vector_col.type, VectorTypeDecorator)
        assert vector_col.type.dimension == 64
        assert vector_col.type.precision == "f32"

    def test_matrixone_dialect_type_handling(self):
        """Test MatrixOneDialect type handling."""
        dialect = MatrixOneDialect()

        # Vector type strings (mixed case) -> expected precision and dimension
        test_cases = [
            ("vecf32(128)", "f32", 128),
            ("VECF32(256)", "f32", 256),
            ("vecf64(512)", "f64", 512),
            ("VECF64(1024)", "f64", 1024),
        ]

        for type_str, precision, expected_dim in test_cases:
            vector_type = dialect._create_vector_type(precision, type_str)
            assert vector_type.precision == precision, f"{type_str}: wrong precision"
            assert vector_type.dimension == expected_dim, f"{type_str}: wrong dimension"

    def test_vector_table_builder_integration(self):
        """Test VectorTableBuilder integration with SQLAlchemy."""
        metadata = MetaData()

        # Create table using builder
        builder = VectorTableBuilder("builder_test", metadata)
        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_vecf32_column("embedding", dimension=128)
        builder.add_index("name")

        table = builder.build()

        # Verify table structure
        assert table.name == "builder_test"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

        # Look the column up directly: a loop guarded by
        # ``if col.name == "embedding"`` would pass silently if it were missing.
        assert str(table.c["embedding"].type) == "vecf32(128)"

    def test_vector_type_serialization(self):
        """Test vector type data serialization."""
        # Test Vectorf32 serialization; no dialect is supplied in this offline test
        vec32 = Vectorf32(dimension=128)
        bind_processor = vec32.bind_processor(None)

        # Test list to string conversion
        vector_list = [1.0, 2.0, 3.0]
        result = bind_processor(vector_list)
        assert result == "[1.0,2.0,3.0]"

        # Test string passthrough
        vector_string = "[4.0,5.0,6.0]"
        result = bind_processor(vector_string)
        assert result == vector_string

    def test_vector_type_deserialization(self):
        """Test vector type data deserialization."""
        # Test Vectorf64 deserialization
        vec64 = Vectorf64(dimension=256)
        result_processor = vec64.result_processor(None, None)

        # Test string to list conversion
        vector_string = "[1.0,2.0,3.0]"
        result = result_processor(vector_string)
        assert result == [1.0, 2.0, 3.0]

        # Test None handling
        result = result_processor(None)
        assert result is None

    def test_complex_table_with_multiple_vector_columns(self):
        """Test complex table with multiple vector columns."""
        metadata = MetaData()

        table = Table(
            'complex_vector_table',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('title', String(200)),
            Column('content_embedding', Vectorf32(dimension=384)),
            Column('title_embedding', Vectorf32(dimension=128)),
            Column('metadata_embedding', Vectorf64(dimension=512)),
        )

        # Verify table structure
        assert table.name == 'complex_vector_table'
        assert len(table.columns) == 5

        # Check vector columns
        vector_type_classes = (VectorType, VectorTypeDecorator, Vectorf32, Vectorf64)
        vector_columns = [col for col in table.columns if isinstance(col.type, vector_type_classes)]
        assert len(vector_columns) == 3

        # Verify dimensions
        dimensions = {col.name: col.type.dimension for col in vector_columns}
        assert dimensions['content_embedding'] == 384
        assert dimensions['title_embedding'] == 128
        assert dimensions['metadata_embedding'] == 512

    def test_table_with_indexes_and_vector_columns(self):
        """Test table with indexes and vector columns."""
        metadata = MetaData()

        table = Table(
            'indexed_vector_table',
            metadata,
            Column('id', Integer, primary_key=True),
            Column('category', String(50)),
            Column('embedding', Vectorf32(dimension=128)),
            Column('score', Integer),
        )

        # Add indexes (constructing an Index on table columns attaches it to the table)
        Index('idx_category', table.c.category)
        Index('idx_score', table.c.score)

        # Verify table structure
        assert len(table.indexes) == 2

        create_sql = self._create_table_sql(table)

        # Check SQL contains expected elements
        assert "CREATE TABLE indexed_vector_table" in create_sql
        assert "embedding vecf32(128)" in create_sql
        # Note: Index creation is separate from table creation in SQLAlchemy
        # The indexes are defined but not included in the CREATE TABLE statement

    def test_vector_type_comparison(self):
        """Test vector type comparison."""
        vec1 = Vectorf32(dimension=128)
        vec2 = Vectorf32(dimension=128)
        vec3 = Vectorf32(dimension=256)
        vec4 = Vectorf64(dimension=128)

        # Test equality by comparing properties
        assert vec1.dimension == vec2.dimension
        assert vec1.precision == vec2.precision

        # Different dimension should not be equal
        assert vec1.dimension != vec3.dimension

        # Different precision should not be equal
        assert vec1.precision != vec4.precision

    def test_vector_type_edge_cases(self):
        """Test edge cases for vector types."""
        # Test very large dimensions
        large_vec = Vectorf32(dimension=65535)
        assert large_vec.dimension == 65535

        # Test dimension 1
        small_vec = Vectorf32(dimension=1)
        assert small_vec.dimension == 1

        # Test without dimension: the column spec has no "(N)" suffix
        no_dim_vec = VectorType(precision="f32")
        assert no_dim_vec.dimension is None
        assert no_dim_vec.get_col_spec() == "vecf32"
269
+
270
+
271
class TestVectorDistanceFunctionBugFixes:
    """Test cases to catch vector distance function bugs found during validation.

    Each distance test compiles an expression with ``literal_binds=True`` and
    inspects the rendered SQL text for the vector literal.
    """

    @staticmethod
    def _literal_sql(expr):
        """Render *expr* as SQL text with bound parameters inlined as literals."""
        return str(expr.compile(compile_kwargs={"literal_binds": True}))

    def test_vector_distance_with_decimal_preservation(self):
        """
        Test that vector distance expressions preserve decimal points in filter conditions.

        Bug: orm.py's table prefix regex was removing decimal points from vectors.
        Example: [0.374, 0.950] became [0374, 0950]
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class TestDoc(Base):
            __tablename__ = 'test_doc'
            id = Column(Integer, primary_key=True)
            title = Column(String(200))
            embedding = create_vector_column(8, precision='f32')

        # Create a filter condition with vector distance
        query_vector = [0.374, 0.950, 0.731, 0.598, 0.156, 0.155, 0.058, 0.866]
        sql_str = self._literal_sql(TestDoc.embedding.l2_distance(query_vector) < 0.5)

        # "0.374" is also a substring of longer renderings such as "0.3745",
        # so a single check covers every acceptable spelling.
        assert "0.374" in sql_str, f"Decimal points should be preserved in vector: {sql_str}"
        assert "0374" not in sql_str, f"Decimal points should not be removed: {sql_str}"

        # Verify the vector is properly quoted
        assert "[" in sql_str and "]" in sql_str, f"Vector array brackets should be present: {sql_str}"

    def test_sqlalchemy_reserved_field_names(self):
        """
        Test that SQLAlchemy reserved field names are avoided.

        Bug: Using 'metadata' as column name causes:
        "Attribute name 'metadata' is reserved when using the Declarative API"
        """
        from sqlalchemy.exc import InvalidRequestError

        Base = declarative_base()

        # This should NOT raise an error
        class GoodDoc(Base):
            __tablename__ = 'good_doc'
            id = Column(Integer, primary_key=True)
            doc_metadata = Column(String(500))  # Use doc_metadata instead

        assert hasattr(GoodDoc, 'doc_metadata')

        # This SHOULD raise an error (reserved name)
        with pytest.raises(InvalidRequestError, match="reserved"):

            class BadDoc(Base):
                __tablename__ = 'bad_doc'
                id = Column(Integer, primary_key=True)
                metadata = Column(String(500))  # Reserved name!

    def test_vector_distance_in_filter_with_literal_binds(self):
        """
        Test that vector distance in filter() preserves vector format with literal_binds.

        Bug: When using literal_binds=True, vector parameters lose proper formatting.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class TestVec(Base):
            __tablename__ = 'test_vec'
            id = Column(Integer, primary_key=True)
            name = Column(String(100))
            embedding = create_vector_column(4, precision='f32')

        # Test vector with decimal values
        query_vector = [0.1, 0.2, 0.3, 0.4]

        # Compile with literal_binds (this is what orm.py does)
        sql_str = self._literal_sql(TestVec.embedding.l2_distance(query_vector) < 1.0)

        # Verify vector format is correct
        assert "[0.1" in sql_str, f"Vector should contain decimal values: {sql_str}"
        assert "l2_distance" in sql_str, f"Function name should be present: {sql_str}"
        # "< 1" matches both "< 1" and "< 1.0" renderings of the threshold
        assert "< 1" in sql_str, f"Comparison should be present: {sql_str}"

    def test_multiple_vector_distance_calls_in_query(self):
        """
        Test that multiple vector distance calls in same query work correctly.

        Bug: When vector distance is used in both SELECT and WHERE,
        parameters might be duplicated or incorrectly formatted.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class MultiVec(Base):
            __tablename__ = 'multi_vec'
            id = Column(Integer, primary_key=True)
            embedding = create_vector_column(4, precision='f32')

        query_vector = [0.5, 0.6, 0.7, 0.8]

        # Use distance in both select and filter
        select_sql = self._literal_sql(MultiVec.embedding.l2_distance(query_vector))
        filter_sql = self._literal_sql(MultiVec.embedding.l2_distance(query_vector) < 2.0)

        # Both should have valid vector format
        for sql_str in (select_sql, filter_sql):
            assert "[0.5" in sql_str, f"Vector should contain decimal values: {sql_str}"
            assert "l2_distance" in sql_str, f"Function should be present: {sql_str}"

    def test_vector_distance_with_different_types(self):
        """
        Test vector distance functions with different distance types.

        Ensures all distance functions (l2, cosine, inner_product) properly
        handle vector parameters.
        """
        from matrixone.sqlalchemy_ext import create_vector_column

        Base = declarative_base()

        class DistTest(Base):
            __tablename__ = 'dist_test'
            id = Column(Integer, primary_key=True)
            vec = create_vector_column(3, precision='f32')

        query_vec = [0.1, 0.2, 0.3]

        # Test all distance functions
        distances = {
            'l2': DistTest.vec.l2_distance(query_vec),
            'cosine': DistTest.vec.cosine_distance(query_vec),
            'inner': DistTest.vec.inner_product(query_vec),
        }

        for dist_name, dist_expr in distances.items():
            sql_str = self._literal_sql(dist_expr)

            # Verify vector format
            assert "[0.1" in sql_str, f"{dist_name}: Vector should contain decimals: {sql_str}"
            # Check every remaining element, so a truncated vector is caught
            for element in ("0.2", "0.3"):
                assert element in sql_str, f"{dist_name}: All vector elements should be present: {sql_str}"
@@ -0,0 +1,198 @@
1
+ # Copyright 2021 - 2022 Matrix Origin
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Offline tests for VectorTableBuilder functionality.
17
+ """
18
+
19
+ import pytest
20
+ import sys
21
+ from unittest.mock import Mock
22
+ from sqlalchemy import MetaData
23
+ from sqlalchemy.schema import CreateTable
24
+ from matrixone.sqlalchemy_ext import (
25
+ VectorTableBuilder,
26
+ create_vector_table,
27
+ create_vector_index_table,
28
+ Vectorf32,
29
+ Vectorf64,
30
+ )
31
+
32
# Apply the ``vector`` marker to every test in this module, so the suite can
# be selected or skipped with ``pytest -m vector`` / ``pytest -m "not vector"``.
pytestmark = pytest.mark.vector
35
+
36
+
37
class TestVectorTableBuilder:
    """Test VectorTableBuilder functionality.

    Column assertions look columns up by name (``table.c[name]``) instead of
    looping with ``if col.name == ...`` guards, so a missing column fails the
    test rather than being skipped silently.
    """

    def test_create_vector_index_table(self):
        """Test creating the exact table from the example."""
        metadata = MetaData()
        builder = create_vector_index_table("vector_index_07", metadata)
        table = builder.build()

        assert table.name == "vector_index_07"
        assert len(table.columns) == 3

        # Check columns
        assert {col.name for col in table.columns} == {"a", "b", "c"}

        # Check column types
        assert table.c["a"].primary_key
        assert str(table.c["a"].type) == "INTEGER"
        assert str(table.c["b"].type) == "vecf32(128)"
        assert str(table.c["c"].type) == "INTEGER"

        # Check indexes: exactly one, named "c_k", covering only column "c"
        assert len(table.indexes) == 1
        (index,) = table.indexes
        assert index.name == "c_k"
        assert [col.name for col in index.columns] == ["c"]

    def test_manual_table_builder(self):
        """Test manual table building."""
        metadata = MetaData()
        builder = VectorTableBuilder("test_table", metadata)

        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_vecf32_column("embedding", dimension=128)
        builder.add_index("name")

        table = builder.build()

        assert table.name == "test_table"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

        # Check column types
        assert table.c["id"].primary_key
        assert str(table.c["id"].type) == "INTEGER"
        assert str(table.c["name"].type) == "VARCHAR(100)"
        assert str(table.c["embedding"].type) == "vecf32(128)"

    def test_vector_column_types(self):
        """Test different vector column types."""
        metadata = MetaData()
        builder = VectorTableBuilder("vector_types_test", metadata)

        builder.add_vecf32_column("vec32", dimension=128)
        builder.add_vecf64_column("vec64", dimension=256)
        builder.add_int_column("id", primary_key=True)

        table = builder.build()

        # All three added columns must exist
        assert {col.name for col in table.columns} == {"vec32", "vec64", "id"}

        # Check vector column types
        assert str(table.c["vec32"].type) == "vecf32(128)"
        assert str(table.c["vec64"].type) == "vecf64(256)"

    def test_multiple_indexes(self):
        """Test table with multiple indexes."""
        metadata = MetaData()
        builder = VectorTableBuilder("multi_index_table", metadata)

        builder.add_int_column("id", primary_key=True)
        builder.add_string_column("name", length=100)
        builder.add_string_column("category", length=50)
        builder.add_int_column("version")
        builder.add_vecf32_column("embedding", dimension=128)

        # Add multiple indexes
        builder.add_index("name")
        builder.add_index("category")
        builder.add_index(["category", "version"])

        table = builder.build()

        assert len(table.indexes) == 3

        # Each index covers exactly the columns it was created with
        index_columns = sorted(tuple(col.name for col in idx.columns) for idx in table.indexes)
        assert index_columns == [("category",), ("category", "version"), ("name",)]

        index_names = [idx.name for idx in table.indexes]
        assert any("name" in name for name in index_names)
        assert any("category" in name for name in index_names)
        assert any("category" in name and "version" in name for name in index_names)

    def test_sql_generation(self):
        """Test SQL generation for vector tables."""
        metadata = MetaData()
        builder = create_vector_index_table("sql_test_table", metadata)
        table = builder.build()

        # Generate CREATE TABLE SQL using MatrixOne dialect
        from matrixone.sqlalchemy_ext import MatrixOneDialect

        create_sql = str(CreateTable(table).compile(dialect=MatrixOneDialect(), compile_kwargs={"literal_binds": True}))

        # Check that SQL contains expected elements
        assert "CREATE TABLE sql_test_table" in create_sql
        assert "a INTEGER NOT NULL" in create_sql
        assert "b vecf32(128)" in create_sql
        assert "c INTEGER" in create_sql
        assert "PRIMARY KEY (a)" in create_sql

    def test_empty_table_builder(self):
        """Test creating an empty table builder."""
        metadata = MetaData()
        builder = VectorTableBuilder("empty_table", metadata)
        table = builder.build()

        assert table.name == "empty_table"
        assert len(table.columns) == 0
        assert len(table.indexes) == 0

    def test_table_builder_fluent_interface(self):
        """Test the fluent interface of table builder."""
        metadata = MetaData()

        # Chain method calls: each add_* returns the builder itself
        table = (
            VectorTableBuilder("fluent_table", metadata)
            .add_int_column("id", primary_key=True)
            .add_string_column("name", length=100)
            .add_vecf32_column("embedding", dimension=128)
            .add_index("name")
            .build()
        )

        assert table.name == "fluent_table"
        assert len(table.columns) == 3
        assert len(table.indexes) == 1

    def test_create_vector_table_convenience(self):
        """Test create_vector_table convenience function."""
        metadata = MetaData()
        builder = create_vector_table("convenience_table", metadata)

        # Should return a builder, not a table
        assert isinstance(builder, VectorTableBuilder)
        assert builder.table_name == "convenience_table"

        # Build the table
        table = builder.add_int_column("id", primary_key=True).build()
        assert table.name == "convenience_table"