vastdb 2.0.2__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vastdb/table.py CHANGED
@@ -22,11 +22,12 @@ import ibis
  import pyarrow as pa
  import urllib3

- from vastdb._table_interface import ITable
+ from vastdb._table_interface import IbisPredicate, ITable
  from vastdb.table_metadata import TableMetadata, TableRef, TableStats, TableType

  from . import _internal, errors, util
  from ._ibis_support import validate_ibis_support_schema
+ from ._internal import VectorIndex
  from .config import ImportConfig, QueryConfig

  if TYPE_CHECKING:
@@ -213,6 +214,11 @@ class TableInTransaction(ITable):
  """Reload Sorted Columns."""
  self._metadata.load_sorted_columns(self._tx)

+ @property
+ def vector_index(self) -> Optional[VectorIndex]:
+ """Table's vector index, if one exists."""
+ return self._metadata._vector_index
+
  @property
  def path(self) -> str:
  """Return table's path."""
@@ -222,7 +228,7 @@ class TableInTransaction(ITable):
  def _internal_rowid_field(self) -> pa.Field:
  return INTERNAL_ROW_ID_SORTED_FIELD if self._is_sorted_table else INTERNAL_ROW_ID_FIELD

- def sorted_columns(self) -> list[str]:
+ def sorted_columns(self) -> list[pa.Field]:
  """Return sorted columns' metadata."""
  return self._metadata.sorted_columns

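Note that `sorted_columns()` now returns `pa.Field` objects rather than plain names. A one-line sketch of the adjustment callers may need, where `table` is an assumed `TableInTransaction` (this mirrors how `insert_in_column_batches` uses it below):

```python
# sorted_columns() yields pa.Field objects; take .name to get the column names
sorted_names = [field.name for field in table.sorted_columns()]
```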
@@ -631,35 +637,41 @@ class TableInTransaction(ITable):
  return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())

  def insert_in_column_batches(self, rows: pa.RecordBatch) -> pa.ChunkedArray:
- """Split the RecordBatch into max_columns that can be inserted in single RPC.
+ """Split the RecordBatch into an insert + updates.

+ This both supports rows that won't fit into a single RPC and improves performance for wide rows.
  Insert first MAX_COLUMN_IN_BATCH columns and get the row_ids. Then loop on the rest of the columns and
  update in groups of MAX_COLUMN_IN_BATCH.
  """
- column_record_batch = pa.RecordBatch.from_arrays([_combine_chunks(rows.column(i)) for i in range(0, MAX_COLUMN_IN_BATCH)],
- schema=pa.schema([rows.schema.field(i) for i in range(0, MAX_COLUMN_IN_BATCH)]))
- row_ids = self.insert(rows=column_record_batch) # type: ignore
-
  columns_names = [field.name for field in rows.schema]
- columns = list(rows.schema)
- arrays = [_combine_chunks(rows.column(i))
- for i in range(len(rows.schema))]
- for start in range(MAX_COLUMN_IN_BATCH, len(rows.schema), MAX_COLUMN_IN_BATCH):
+ # Sorted columns must be in the first insert as those can't be updated later.
+ if self._is_sorted_table:
+ sorted_columns_names = [field.name for field in self.sorted_columns()]
+ columns_names = sorted_columns_names + [column_name for column_name in columns_names if column_name not in sorted_columns_names]
+ columns = [rows.schema.field(column_name) for column_name in columns_names]
+
+ arrays = [_combine_chunks(rows.column(column_name)) for column_name in columns_names]
+ for start in range(0, len(rows.schema), MAX_COLUMN_IN_BATCH):
  end = start + MAX_COLUMN_IN_BATCH if start + \
  MAX_COLUMN_IN_BATCH < len(rows.schema) else len(rows.schema)
  columns_name_chunk = columns_names[start:end]
  columns_chunks = columns[start:end]
  arrays_chunks = arrays[start:end]
- columns_chunks.append(self._internal_rowid_field)
- arrays_chunks.append(row_ids.to_pylist())
- column_record_batch = pa.RecordBatch.from_arrays(
- arrays_chunks, schema=pa.schema(columns_chunks))
- self.update(rows=column_record_batch, columns=columns_name_chunk)
+ if start == 0:
+ column_record_batch = pa.RecordBatch.from_arrays(
+ arrays_chunks, schema=pa.schema(columns_chunks))
+ row_ids = self.insert(rows=column_record_batch, by_columns=False) # type: ignore
+ else:
+ columns_chunks.append(self._internal_rowid_field)
+ arrays_chunks.append(row_ids.to_pylist())
+ column_record_batch = pa.RecordBatch.from_arrays(
+ arrays_chunks, schema=pa.schema(columns_chunks))
+ self.update(rows=column_record_batch, columns=columns_name_chunk)
  return row_ids

  def insert(self,
  rows: Union[pa.RecordBatch, pa.Table],
- by_columns: bool = False) -> pa.ChunkedArray:
+ by_columns: bool = True) -> pa.ChunkedArray:
  """Insert a RecordBatch into this table."""
  self._assert_not_imports_table()

@@ -667,9 +679,14 @@ class TableInTransaction(ITable):
  log.debug("Ignoring empty insert into %s", self.ref)
  return pa.chunked_array([], type=self._internal_rowid_field.type)

- if by_columns:
- self._tx._rpc.features.check_return_row_ids()
- return self.insert_in_column_batches(rows)
+ # inserting by columns is faster, so default to doing that
+ # if the cluster supports it and there are actually columns in the rows
+ if by_columns and len(rows.schema):
+ try:
+ self._tx._rpc.features.check_return_row_ids()
+ return self.insert_in_column_batches(rows)
+ except errors.NotSupportedVersion:
+ pass

  try:
  row_ids = []
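With this change, `insert` prefers the column-batched strategy by default. A minimal usage sketch, assuming a reachable cluster; the endpoint, credentials, bucket/schema/table names, and column count are illustrative, not part of this diff:

```python
import pyarrow as pa
import vastdb

# Hypothetical connection details; replace with a real endpoint and credentials.
session = vastdb.connect(endpoint="http://vip-pool.example:80", access="ACCESS", secret="SECRET")

# A wide one-row table: assume more columns than a single RPC batch handles.
num_columns = 300
schema = pa.schema([pa.field(f"c{i}", pa.int64()) for i in range(num_columns)])
rows = pa.table(schema=schema, data=[[i] for i in range(num_columns)])

with session.transaction() as tx:
    table = tx.bucket("my-bucket").create_schema("s").create_table("wide", schema)
    # by_columns now defaults to True: the first MAX_COLUMN_IN_BATCH columns are
    # inserted, and the remaining columns are applied as updates keyed by the
    # returned row ids. Clusters without return-row-ids support fall back to the
    # row-oriented path.
    row_ids = table.insert(rows)
```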
@@ -802,6 +819,25 @@ class TableInTransaction(ITable):
  def _is_sorted_table(self) -> bool:
  return self._metadata.table_type is TableType.Elysium

+ def vector_search(
+ self,
+ vec: list[float],
+ columns: list[str],
+ limit: int,
+ predicate: Optional[IbisPredicate] = None,
+ ) -> pa.RecordBatchReader:
+ """Vector search over vector-indexed columns."""
+ assert self.vector_index is not None, "Table is not vector indexed (maybe try reloading the TableMetadata)."
+
+ return self._tx.adbc_conn.vector_search(
+ vec,
+ self.vector_index,
+ self.ref,
+ columns,
+ limit,
+ predicate=predicate,
+ )
+

  class Table(TableInTransaction):
  """Vast Interactive Table."""
vastdb/table_metadata.py CHANGED
@@ -4,17 +4,19 @@ import logging
  from copy import deepcopy
  from dataclasses import dataclass
  from enum import Enum
- from typing import TYPE_CHECKING, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional

  import ibis
  import pyarrow as pa

  from vastdb import errors
  from vastdb._ibis_support import validate_ibis_support_schema
+ from vastdb._internal import TableStats, VectorIndex

  if TYPE_CHECKING:
  from .transaction import Transaction

+
  log = logging.getLogger(__name__)


@@ -39,26 +41,16 @@ class TableRef:
  """Table full path."""
  return f"{self.bucket}/{self.schema}/{self.table}"

+ @property
+ def query_engine_full_path(self) -> str:
+ """Table full path for VastDB Query Engine."""
+ return f'"{self.bucket}/{self.schema}".{self.table}'
+
  def __str__(self) -> str:
  """Table full path."""
  return self.full_path


- @dataclass
- class TableStats:
- """Table-related information."""
-
- num_rows: int
- size_in_bytes: int
- sorting_score: int
- write_amplification: int
- acummulative_row_inserition_count: int
- is_external_rowid_alloc: bool = False
- sorting_key_enabled: bool = False
- sorting_done: bool = False
- endpoints: Tuple[str, ...] = ()
-
-
  class TableMetadata:
  """Table Metadata."""

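The new `query_engine_full_path` quotes the bucket/schema prefix so a table can be referenced from SQL, as the ADBC tests later in this diff do. A quick illustration with hypothetical names, based directly on the property shown above:

```python
from vastdb.table_metadata import TableRef

ref = TableRef("my-bucket", "my-schema", "my-table")  # hypothetical names
print(ref.full_path)               # my-bucket/my-schema/my-table
print(ref.query_engine_full_path)  # "my-bucket/my-schema".my-table
```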
@@ -67,25 +59,29 @@ class TableMetadata:
  _sorted_columns: Optional[list[str]]
  _ibis_table: ibis.Table
  _stats: Optional[TableStats]
-
- def __init__(self,
- ref: TableRef,
- arrow_schema: Optional[pa.Schema] = None,
- table_type: Optional[TableType] = None):
+ _vector_index: Optional[VectorIndex]
+
+ def __init__(
+ self,
+ ref: TableRef,
+ arrow_schema: Optional[pa.Schema] = None,
+ table_type: Optional[TableType] = None,
+ vector_index: Optional[VectorIndex] = None,
+ ):
  """Table Metadata."""
  self._ref = deepcopy(ref)
  self._table_type = table_type
  self.arrow_schema = deepcopy(arrow_schema)
  self._sorted_columns = None
  self._stats = None
+ self._vector_index = vector_index

  def __eq__(self, other: object) -> bool:
  """TableMetadata Equal."""
  if not isinstance(other, TableMetadata):
  return False

- return (self._ref == other._ref and
- self._table_type == other._table_type)
+ return self._ref == other._ref and self._table_type == other._table_type

  def rename_table(self, name: str) -> None:
  """Rename table metadata's table name."""
@@ -110,7 +106,8 @@ class TableMetadata:
  table=self.ref.table,
  next_key=next_key,
  txid=tx.active_txid,
- list_imports_table=self.is_imports_table)
+ list_imports_table=self.is_imports_table,
+ )
  fields.extend(cur_columns)
  if not is_truncated:
  break
@@ -123,9 +120,16 @@ class TableMetadata:
  try:
  next_key = 0
  while True:
- cur_columns, next_key, is_truncated, _count = tx._rpc.api.list_sorted_columns(
- bucket=self.ref.bucket, schema=self.ref.schema, table=self.ref.table,
- next_key=next_key, txid=tx.active_txid, list_imports_table=self.is_imports_table)
+ cur_columns, next_key, is_truncated, _count = (
+ tx._rpc.api.list_sorted_columns(
+ bucket=self.ref.bucket,
+ schema=self.ref.schema,
+ table=self.ref.table,
+ next_key=next_key,
+ txid=tx.active_txid,
+ list_imports_table=self.is_imports_table,
+ )
+ )
  fields.extend(cur_columns)
  if not is_truncated:
  break
@@ -133,7 +137,9 @@ class TableMetadata:
  raise
  except errors.InternalServerError as ise:
  log.warning(
- "Failed to get the sorted columns Elysium might not be supported: %s", ise)
+ "Failed to get the sorted columns Elysium might not be supported: %s",
+ ise,
+ )
  raise
  except errors.NotSupportedVersion:
  log.warning("Failed to get the sorted columns, Elysium not supported")
@@ -143,10 +149,13 @@ class TableMetadata:

  def load_stats(self, tx: "Transaction") -> None:
  """Load/Reload table stats."""
- stats_tuple = tx._rpc.api.get_table_stats(
- bucket=self.ref.bucket, schema=self.ref.schema, name=self.ref.table, txid=tx.active_txid,
- imports_table_stats=self.is_imports_table)
- self._stats = TableStats(**stats_tuple._asdict())
+ self._stats = tx._rpc.api.get_table_stats(
+ bucket=self.ref.bucket,
+ schema=self.ref.schema,
+ name=self.ref.table,
+ txid=tx.active_txid,
+ imports_table_stats=self.is_imports_table,
+ )

  is_elysium_table = self._stats.sorting_key_enabled

@@ -161,6 +170,18 @@ class TableMetadata:
  "Actual table is sorted (TableType.Elysium), was not inited as TableType.Elysium"
  )

+ self._parse_stats_vector_index()
+
+ def _parse_stats_vector_index(self):
+ vector_index_is_set = self._vector_index is not None
+
+ if vector_index_is_set and self._stats.vector_index != self._vector_index:
+ raise ValueError(
+ f"Table has index {self._stats.vector_index}, but was initialized as {self._vector_index}"
+ )
+ else:
+ self._vector_index = self._stats.vector_index
+
  def _set_sorted_table(self, tx: "Transaction"):
  self._table_type = TableType.Elysium
  tx._rpc.features.check_elysium()
@@ -184,7 +205,9 @@ class TableMetadata:
  if arrow_schema:
  validate_ibis_support_schema(arrow_schema)
  self._arrow_schema = arrow_schema
- self._ibis_table = ibis.table(ibis.Schema.from_pyarrow(arrow_schema), self._ref.full_path)
+ self._ibis_table = ibis.table(
+ ibis.Schema.from_pyarrow(arrow_schema), self._ref.full_path
+ )
  else:
  self._arrow_schema = None
  self._ibis_table = None
@@ -211,7 +234,8 @@ class TableMetadata:
  """Table's type."""
  if self._table_type is None:
  raise ValueError(
- "TableType was not loaded. load using TableMetadata.load_stats")
+ "TableType was not loaded. load using TableMetadata.load_stats"
+ )

  return self._table_type

@@ -0,0 +1,89 @@
+ import pyarrow as pa
+ import pytest
+
+ from vastdb.table_metadata import TableRef
+ from vastdb.transaction import NoAdbcConnectionError
+
+
+ def test_sanity(session_factory, clean_bucket_name: str):
+ session = session_factory(with_adbc=True)
+
+ arrow_schema = pa.schema([("n", pa.int32())])
+
+ ref = TableRef(clean_bucket_name, "s", "t")
+ data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+ with session.transaction() as tx:
+ table = (
+ tx.bucket(clean_bucket_name)
+ .create_schema("s")
+ .create_table("t", arrow_schema)
+ )
+ table.insert(data_table)
+
+ with session.transaction() as tx:
+ tx.adbc_conn.cursor.execute(f"SELECT * FROM {ref.query_engine_full_path}")
+ res = tx.adbc_conn.cursor.fetchall()
+
+ assert res == [(1,), (2,), (3,), (4,), (5,)]
+
+
+ def test_adbc_shares_tx(session_factory, clean_bucket_name: str):
+ session = session_factory(with_adbc=True)
+
+ arrow_schema = pa.schema([("n", pa.int32())])
+
+ data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+ with session.transaction() as tx:
+ table = (
+ tx.bucket(clean_bucket_name)
+ .create_schema("s")
+ .create_table("t", arrow_schema)
+ )
+ table.insert(data_table)
+
+ # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+ tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+ assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+
+ def test_adbc_conn_unreachable_tx_close(session_factory):
+ session = session_factory(with_adbc=True)
+
+ with session.transaction() as tx:
+ assert tx.adbc_conn is not None
+
+ # adbc conn should not be reachable after tx close
+ with pytest.raises(NoAdbcConnectionError):
+ tx.adbc_conn
+
+
+ def test_two_simulatnious_txs_with_adbc(session_factory, clean_bucket_name: str):
+ session = session_factory(with_adbc=True)
+
+ arrow_schema = pa.schema([("n", pa.int32())])
+
+ data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+ with session.transaction() as tx:
+ table = (
+ tx.bucket(clean_bucket_name)
+ .create_schema("s")
+ .create_table("t1", arrow_schema)
+ )
+ table.insert(data_table)
+
+ # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+ tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+ assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+ with session.transaction() as tx:
+ table = (
+ tx.bucket(clean_bucket_name).schema("s").create_table("t2", arrow_schema)
+ )
+ table.insert(data_table)
+
+ # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+ tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+ assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
@@ -1,5 +1,6 @@
  import datetime as dt
  import decimal
+ import itertools
  import logging
  import random
  import threading
@@ -17,7 +18,7 @@ from requests.exceptions import HTTPError

  from vastdb import errors
  from vastdb.session import Session
- from vastdb.table import INTERNAL_ROW_ID, QueryConfig
+ from vastdb.table import INTERNAL_ROW_ID, MAX_COLUMN_IN_BATCH, QueryConfig

  from .util import assert_row_ids_ascending_on_first_insertion_to_table, prepare_data

@@ -95,6 +96,39 @@ def test_insert_wide_row(session, clean_bucket_name):
  assert actual == expected


+ @pytest.mark.parametrize("num_columns,insert_by_columns", itertools.product([
+ MAX_COLUMN_IN_BATCH // 2,
+ MAX_COLUMN_IN_BATCH - 1,
+ MAX_COLUMN_IN_BATCH,
+ MAX_COLUMN_IN_BATCH + 1,
+ MAX_COLUMN_IN_BATCH * 2,
+ MAX_COLUMN_IN_BATCH * 10,
+ ],
+ [False, True]
+ )
+ )
+ def test_insert_by_columns_variations(session, clean_bucket_name, num_columns, insert_by_columns):
+ columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+ data = [[i] for i in range(num_columns)]
+ expected = pa.table(schema=columns, data=data)
+
+ with prepare_data(session, clean_bucket_name, 's', 't', expected, insert_by_columns=insert_by_columns) as t:
+ actual = t.select().read_all()
+ assert actual == expected
+
+
+ @pytest.mark.parametrize("sorting_key", [0, 40, 80, 120])
+ def test_insert_by_columns_sorted(session, clean_bucket_name, sorting_key):
+ num_columns = 160
+ columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+ data = [[i] for i in range(num_columns)]
+ expected = pa.table(schema=columns, data=data)
+
+ with prepare_data(session, clean_bucket_name, 's', 't', expected, sorting_key=[sorting_key], insert_by_columns=True) as t:
+ actual = t.select().read_all()
+ assert actual == expected
+
+
  def test_multi_batch_table(session, clean_bucket_name):
  columns = pa.schema([pa.field('s', pa.utf8())])
  expected = pa.Table.from_batches([
@@ -0,0 +1,162 @@
+ """Tests for vector index functionality."""
+
+ import logging
+ from typing import Optional
+
+ import pyarrow as pa
+ import pytest
+
+ from vastdb import errors
+ from vastdb._internal import VectorIndexSpec
+ from vastdb.session import Session
+
+ log = logging.getLogger(__name__)
+
+
+ @pytest.mark.parametrize("table_name,vector_index", [
+ # Test 1: Table without vector index
+ ("table_without_index", None),
+ # Test 2: Table with L2 vector index
+ ("table_with_l2_index", VectorIndexSpec("embedding", "l2sq")),
+ # Test 3: Table with inner product vector index
+ ("table_with_ip_index", VectorIndexSpec("embedding", "ip")),
+ ])
+ def test_create_table_with_vector_index_metadata(session: Session,
+ clean_bucket_name: str,
+ table_name: str,
+ vector_index: Optional[VectorIndexSpec]):
+ """Test that table creation and stats retrieval work correctly with vector index metadata."""
+ schema_name = "schema1"
+
+ with session.transaction() as tx:
+ log.info(f"Testing table '{table_name}' with {vector_index}")
+
+ # Create schema
+ bucket = tx.bucket(clean_bucket_name)
+ schema = bucket.create_schema(schema_name)
+
+ # Create the appropriate schema based on whether vector index is needed
+ if vector_index is None:
+ # Simple table without vector index
+ arrow_schema = pa.schema([
+ ('id', pa.int64()),
+ ('data', pa.string())
+ ])
+ else:
+ # Table with vector column
+ vector_dimension = 128  # Fixed-size vector dimension
+ vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+ arrow_schema = pa.schema([
+ ('id', pa.int64()),
+ ('embedding', vec_type)  # Fixed-size vector column
+ ])
+
+ # Create table with or without vector index
+ log.info(f"Creating table: {table_name}")
+ table = schema.create_table(
+ table_name=table_name,
+ columns=arrow_schema,
+ vector_index=vector_index
+ )
+
+ # Reload stats to ensure we get the vector index metadata
+ table.reload_stats()
+
+ # Get vector index metadata
+ result_vector_index = table._metadata._vector_index
+
+ log.info(f"Vector index metadata: {result_vector_index}")
+
+ # Assert expected values (should match input parameters)
+ result_vector_index_spec = (
+ None
+ if result_vector_index is None
+ else result_vector_index.to_vector_index_spec()
+ )
+ assert result_vector_index_spec == vector_index
+
+ log.info(f"✓ Test passed for table '{table_name}'")
+
+
+ @pytest.mark.parametrize("table_name,vector_index,expected_error", [
+ # Test 1: Invalid column name (column doesn't exist in schema)
+ ("table_invalid_column", VectorIndexSpec("nonexistent_column", "l2sq"), "invalid vector indexed column name nonexistent_column"),
+ # Test 2: Invalid distance metric
+ ("table_invalid_metric", VectorIndexSpec("embedding", "invalid_metric"), "invalid vector index distance metric invalid_metric, supported metrics: 'l2sq', 'ip'"),
+ ])
+ def test_create_table_with_invalid_vector_index(session: Session,
+ clean_bucket_name: str,
+ table_name: str,
+ vector_index: VectorIndexSpec,
+ expected_error: str):
+ """Test that table creation fails with appropriate error messages for invalid vector index parameters."""
+ schema_name = "schema1"
+
+ with session.transaction() as tx:
+ log.info(f"Testing invalid table '{table_name}' with vector_index={vector_index}, expected_error={expected_error}")
+
+ # Create schema
+ bucket = tx.bucket(clean_bucket_name)
+ schema = bucket.create_schema(schema_name)
+
+ # Table with vector column
+ vector_dimension = 128  # Fixed-size vector dimension
+ vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+ arrow_schema = pa.schema([
+ ('id', pa.int64()),
+ ('embedding', vec_type)  # Fixed-size vector column
+ ])
+
+ # Attempt to create table with invalid parameters - should raise an error
+ log.info(f"Attempting to create invalid table: {table_name}")
+ with pytest.raises((errors.BadRequest)) as exc_info:
+ schema.create_table(
+ table_name=table_name,
+ columns=arrow_schema,
+ vector_index=vector_index
+ )
+
+ # Verify the error message contains the expected error text
+ assert expected_error in str(exc_info.value), \
+ f"Expected error message to contain '{expected_error}', got '{str(exc_info.value)}'"
+
+ log.info(f"✓ Test passed for invalid table '{table_name}'")
+
+
+ def test_vector_index_metadata_from_stats(session: Session, clean_bucket_name: str):
+ """Test that vector index metadata is correctly retrieved from table stats."""
+ schema_name = "schema1"
+ table_name = "vector_table"
+
+ with session.transaction() as tx:
+ # Create schema
+ bucket = tx.bucket(clean_bucket_name)
+ schema = bucket.create_schema(schema_name)
+
+ # Create table with vector index
+ vector_dimension = 128
+ vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+ arrow_schema = pa.schema([
+ ('id', pa.int64()),
+ ('embedding', vec_type)
+ ])
+
+ table = schema.create_table(
+ table_name=table_name,
+ columns=arrow_schema,
+ vector_index=VectorIndexSpec("embedding", "l2sq")
+ )
+
+ # Check stats object directly
+ stats = table.stats
+ assert stats is not None
+ assert stats.vector_index is not None
+ assert stats.vector_index.column == "embedding"
+ assert stats.vector_index.distance_metric == "l2sq"
+
+ # Check via the table methods
+ assert table._metadata._vector_index is not None
+ assert table._metadata._vector_index.column == "embedding"
+ assert table._metadata._vector_index.distance_metric == "l2sq"
+
+ log.info("✓ Vector index metadata correctly retrieved from stats")