vastdb 2.0.2__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vastdb/_adbc.py +194 -0
- vastdb/_internal.py +101 -12
- vastdb/_table_interface.py +20 -3
- vastdb/conftest.py +23 -1
- vastdb/errors.py +5 -0
- vastdb/schema.py +17 -2
- vastdb/session.py +12 -5
- vastdb/table.py +56 -20
- vastdb/table_metadata.py +58 -34
- vastdb/tests/test_adbc_integration.py +89 -0
- vastdb/tests/test_tables.py +35 -1
- vastdb/tests/test_vector_index.py +162 -0
- vastdb/tests/test_vector_search.py +210 -0
- vastdb/tests/util.py +3 -2
- vastdb/transaction.py +30 -0
- vastdb/vast_flatbuf/tabular/GetTableStatsResponse.py +51 -59
- vastdb/vast_flatbuf/tabular/ObjectDetails.py +36 -59
- vastdb/vast_flatbuf/tabular/VectorIndexMetadata.py +67 -0
- vastdb/vast_flatbuf/tabular/VipRange.py +19 -12
- {vastdb-2.0.2.dist-info → vastdb-2.0.3.dist-info}/METADATA +2 -1
- {vastdb-2.0.2.dist-info → vastdb-2.0.3.dist-info}/RECORD +24 -19
- {vastdb-2.0.2.dist-info → vastdb-2.0.3.dist-info}/WHEEL +0 -0
- {vastdb-2.0.2.dist-info → vastdb-2.0.3.dist-info}/licenses/LICENSE +0 -0
- {vastdb-2.0.2.dist-info → vastdb-2.0.3.dist-info}/top_level.txt +0 -0
vastdb/table.py
CHANGED
@@ -22,11 +22,12 @@ import ibis
 import pyarrow as pa
 import urllib3
 
-from vastdb._table_interface import ITable
+from vastdb._table_interface import IbisPredicate, ITable
 from vastdb.table_metadata import TableMetadata, TableRef, TableStats, TableType
 
 from . import _internal, errors, util
 from ._ibis_support import validate_ibis_support_schema
+from ._internal import VectorIndex
 from .config import ImportConfig, QueryConfig
 
 if TYPE_CHECKING:

@@ -213,6 +214,11 @@ class TableInTransaction(ITable):
         """Reload Sorted Columns."""
         self._metadata.load_sorted_columns(self._tx)
 
+    @property
+    def vector_index(self) -> Optional[VectorIndex]:
+        """Table's Vector Index if exists."""
+        return self._metadata._vector_index
+
     @property
     def path(self) -> str:
         """Return table's path."""

@@ -222,7 +228,7 @@ class TableInTransaction(ITable):
     def _internal_rowid_field(self) -> pa.Field:
         return INTERNAL_ROW_ID_SORTED_FIELD if self._is_sorted_table else INTERNAL_ROW_ID_FIELD
 
-    def sorted_columns(self) -> list[
+    def sorted_columns(self) -> list[pa.Field]:
         """Return sorted columns' metadata."""
         return self._metadata.sorted_columns
 

@@ -631,35 +637,41 @@ class TableInTransaction(ITable):
         return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())
 
     def insert_in_column_batches(self, rows: pa.RecordBatch) -> pa.ChunkedArray:
-        """Split the RecordBatch into
+        """Split the RecordBatch into an insert + updates.
 
+        This is both to support rows that won't fit into an RPC and for performance for wide rows.
         Insert first MAX_COLUMN_IN_BATCH columns and get the row_ids. Then loop on the rest of the columns and
         update in groups of MAX_COLUMN_IN_BATCH.
         """
-        column_record_batch = pa.RecordBatch.from_arrays([_combine_chunks(rows.column(i)) for i in range(0, MAX_COLUMN_IN_BATCH)],
-                                                         schema=pa.schema([rows.schema.field(i) for i in range(0, MAX_COLUMN_IN_BATCH)]))
-        row_ids = self.insert(rows=column_record_batch)  # type: ignore
-
         columns_names = [field.name for field in rows.schema]
-        columns
-
-
-
+        # Sorted columns must be in the first insert as those can't be updated later.
+        if self._is_sorted_table:
+            sorted_columns_names = [field.name for field in self.sorted_columns()]
+            columns_names = sorted_columns_names + [column_name for column_name in columns_names if column_name not in sorted_columns_names]
+        columns = [rows.schema.field(column_name) for column_name in columns_names]
+
+        arrays = [_combine_chunks(rows.column(column_name)) for column_name in columns_names]
+        for start in range(0, len(rows.schema), MAX_COLUMN_IN_BATCH):
             end = start + MAX_COLUMN_IN_BATCH if start + \
                 MAX_COLUMN_IN_BATCH < len(rows.schema) else len(rows.schema)
             columns_name_chunk = columns_names[start:end]
             columns_chunks = columns[start:end]
             arrays_chunks = arrays[start:end]
-
-
-
-
-
+            if start == 0:
+                column_record_batch = pa.RecordBatch.from_arrays(
+                    arrays_chunks, schema=pa.schema(columns_chunks))
+                row_ids = self.insert(rows=column_record_batch, by_columns=False)  # type: ignore
+            else:
+                columns_chunks.append(self._internal_rowid_field)
+                arrays_chunks.append(row_ids.to_pylist())
+                column_record_batch = pa.RecordBatch.from_arrays(
+                    arrays_chunks, schema=pa.schema(columns_chunks))
+                self.update(rows=column_record_batch, columns=columns_name_chunk)
         return row_ids
 
     def insert(self,
                rows: Union[pa.RecordBatch, pa.Table],
-               by_columns: bool =
+               by_columns: bool = True) -> pa.ChunkedArray:
         """Insert a RecordBatch into this table."""
         self._assert_not_imports_table()
 

@@ -667,9 +679,14 @@ class TableInTransaction(ITable):
             log.debug("Ignoring empty insert into %s", self.ref)
             return pa.chunked_array([], type=self._internal_rowid_field.type)
 
-
-
-
+        # inserting by columns is faster, so default to doing that
+        # if the cluster supports it and there are actually columns in the rows
+        if by_columns and len(rows.schema):
+            try:
+                self._tx._rpc.features.check_return_row_ids()
+                return self.insert_in_column_batches(rows)
+            except errors.NotSupportedVersion:
+                pass
 
         try:
             row_ids = []

@@ -802,6 +819,25 @@ class TableInTransaction(ITable):
     def _is_sorted_table(self) -> bool:
        return self._metadata.table_type is TableType.Elysium
 
+    def vector_search(
+        self,
+        vec: list[float],
+        columns: list[str],
+        limit: int,
+        predicate: Optional[IbisPredicate] = None,
+    ) -> pa.RecordBatchReader:
+        """Vector Search over vector indexed columns."""
+        assert self.vector_index is not None, "Table is either not vector indexed. (maybe try reloading the TableMetadata)"
+
+        return self._tx.adbc_conn.vector_search(
+            vec,
+            self.vector_index,
+            self.ref,
+            columns,
+            limit,
+            predicate=predicate,
+        )
+
 
 class Table(TableInTransaction):
     """Vast Interactive Table."""
vastdb/table_metadata.py
CHANGED
@@ -4,17 +4,19 @@ import logging
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
 
 import ibis
 import pyarrow as pa
 
 from vastdb import errors
 from vastdb._ibis_support import validate_ibis_support_schema
+from vastdb._internal import TableStats, VectorIndex
 
 if TYPE_CHECKING:
     from .transaction import Transaction
 
+
 log = logging.getLogger(__name__)
 
 

@@ -39,26 +41,16 @@ class TableRef:
         """Table full path."""
         return f"{self.bucket}/{self.schema}/{self.table}"
 
+    @property
+    def query_engine_full_path(self) -> str:
+        """Table full path for VastDB Query Engine."""
+        return f'"{self.bucket}/{self.schema}".{self.table}'
+
     def __str__(self) -> str:
         """Table full path."""
         return self.full_path
 
 
-@dataclass
-class TableStats:
-    """Table-related information."""
-
-    num_rows: int
-    size_in_bytes: int
-    sorting_score: int
-    write_amplification: int
-    acummulative_row_inserition_count: int
-    is_external_rowid_alloc: bool = False
-    sorting_key_enabled: bool = False
-    sorting_done: bool = False
-    endpoints: Tuple[str, ...] = ()
-
-
 class TableMetadata:
     """Table Metadata."""
 

@@ -67,25 +59,29 @@ class TableMetadata:
     _sorted_columns: Optional[list[str]]
     _ibis_table: ibis.Table
     _stats: Optional[TableStats]
-
-
-
-
-
+    _vector_index: Optional[VectorIndex]
+
+    def __init__(
+        self,
+        ref: TableRef,
+        arrow_schema: Optional[pa.Schema] = None,
+        table_type: Optional[TableType] = None,
+        vector_index: Optional[VectorIndex] = None,
+    ):
         """Table Metadata."""
         self._ref = deepcopy(ref)
         self._table_type = table_type
         self.arrow_schema = deepcopy(arrow_schema)
         self._sorted_columns = None
         self._stats = None
+        self._vector_index = vector_index
 
     def __eq__(self, other: object) -> bool:
         """TableMetadata Equal."""
         if not isinstance(other, TableMetadata):
             return False
 
-        return
-                self._table_type == other._table_type)
+        return self._ref == other._ref and self._table_type == other._table_type
 
     def rename_table(self, name: str) -> None:
         """Rename table metadata's table name."""

@@ -110,7 +106,8 @@ class TableMetadata:
                 table=self.ref.table,
                 next_key=next_key,
                 txid=tx.active_txid,
-                list_imports_table=self.is_imports_table
+                list_imports_table=self.is_imports_table,
+            )
             fields.extend(cur_columns)
             if not is_truncated:
                 break

@@ -123,9 +120,16 @@ class TableMetadata:
         try:
             next_key = 0
             while True:
-                cur_columns, next_key, is_truncated, _count =
-
-
+                cur_columns, next_key, is_truncated, _count = (
+                    tx._rpc.api.list_sorted_columns(
+                        bucket=self.ref.bucket,
+                        schema=self.ref.schema,
+                        table=self.ref.table,
+                        next_key=next_key,
+                        txid=tx.active_txid,
+                        list_imports_table=self.is_imports_table,
+                    )
+                )
                 fields.extend(cur_columns)
                 if not is_truncated:
                     break

@@ -133,7 +137,9 @@ class TableMetadata:
             raise
         except errors.InternalServerError as ise:
             log.warning(
-                "Failed to get the sorted columns Elysium might not be supported: %s",
+                "Failed to get the sorted columns Elysium might not be supported: %s",
+                ise,
+            )
             raise
         except errors.NotSupportedVersion:
             log.warning("Failed to get the sorted columns, Elysium not supported")

@@ -143,10 +149,13 @@ class TableMetadata:
 
     def load_stats(self, tx: "Transaction") -> None:
         """Load/Reload table stats."""
-
-            bucket=self.ref.bucket,
-
-
+        self._stats = tx._rpc.api.get_table_stats(
+            bucket=self.ref.bucket,
+            schema=self.ref.schema,
+            name=self.ref.table,
+            txid=tx.active_txid,
+            imports_table_stats=self.is_imports_table,
+        )
 
         is_elysium_table = self._stats.sorting_key_enabled
 

@@ -161,6 +170,18 @@ class TableMetadata:
                 "Actual table is sorted (TableType.Elysium), was not inited as TableType.Elysium"
             )
 
+        self._parse_stats_vector_index()
+
+    def _parse_stats_vector_index(self):
+        vector_index_is_set = self._vector_index is not None
+
+        if vector_index_is_set and self._stats.vector_index != self._vector_index:
+            raise ValueError(
+                f"Table has index {self._stats.vector_index}, but was initialized as {self._vector_index}"
+            )
+        else:
+            self._vector_index = self._stats.vector_index
+
     def _set_sorted_table(self, tx: "Transaction"):
         self._table_type = TableType.Elysium
         tx._rpc.features.check_elysium()

@@ -184,7 +205,9 @@ class TableMetadata:
         if arrow_schema:
             validate_ibis_support_schema(arrow_schema)
             self._arrow_schema = arrow_schema
-            self._ibis_table = ibis.table(
+            self._ibis_table = ibis.table(
+                ibis.Schema.from_pyarrow(arrow_schema), self._ref.full_path
+            )
         else:
             self._arrow_schema = None
             self._ibis_table = None

@@ -211,7 +234,8 @@ class TableMetadata:
         """Table's type."""
         if self._table_type is None:
             raise ValueError(
-                "TableType was not loaded. load using TableMetadata.load_stats"
+                "TableType was not loaded. load using TableMetadata.load_stats"
+            )
 
         return self._table_type
 
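The new TableRef.query_engine_full_path quotes the bucket/schema pair so a table can be addressed from ADBC SQL. A quick illustration (the names are made up):

from vastdb.table_metadata import TableRef

ref = TableRef("mybucket", "myschema", "mytable")
assert ref.full_path == "mybucket/myschema/mytable"
# Quoted form understood by the VastDB Query Engine:
assert ref.query_engine_full_path == '"mybucket/myschema".mytable'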
vastdb/tests/test_adbc_integration.py
ADDED

@@ -0,0 +1,89 @@
+import pyarrow as pa
+import pytest
+
+from vastdb.table_metadata import TableRef
+from vastdb.transaction import NoAdbcConnectionError
+
+
+def test_sanity(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    ref = TableRef(clean_bucket_name, "s", "t")
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t", arrow_schema)
+        )
+        table.insert(data_table)
+
+    with session.transaction() as tx:
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {ref.query_engine_full_path}")
+        res = tx.adbc_conn.cursor.fetchall()
+
+        assert res == [(1,), (2,), (3,), (4,), (5,)]
+
+
+def test_adbc_shares_tx(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+
+def test_adbc_conn_unreachable_tx_close(session_factory):
+    session = session_factory(with_adbc=True)
+
+    with session.transaction() as tx:
+        assert tx.adbc_conn is not None
+
+    # adbc conn should not be reachable after tx close
+    with pytest.raises(NoAdbcConnectionError):
+        tx.adbc_conn
+
+
+def test_two_simulatnious_txs_with_adbc(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t1", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name).schema("s").create_table("t2", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
vastdb/tests/test_tables.py
CHANGED
@@ -1,5 +1,6 @@
 import datetime as dt
 import decimal
+import itertools
 import logging
 import random
 import threading

@@ -17,7 +18,7 @@ from requests.exceptions import HTTPError
 
 from vastdb import errors
 from vastdb.session import Session
-from vastdb.table import INTERNAL_ROW_ID, QueryConfig
+from vastdb.table import INTERNAL_ROW_ID, MAX_COLUMN_IN_BATCH, QueryConfig
 
 from .util import assert_row_ids_ascending_on_first_insertion_to_table, prepare_data
 

@@ -95,6 +96,39 @@ def test_insert_wide_row(session, clean_bucket_name):
     assert actual == expected
 
 
+@pytest.mark.parametrize("num_columns,insert_by_columns", itertools.product([
+    MAX_COLUMN_IN_BATCH // 2,
+    MAX_COLUMN_IN_BATCH - 1,
+    MAX_COLUMN_IN_BATCH,
+    MAX_COLUMN_IN_BATCH + 1,
+    MAX_COLUMN_IN_BATCH * 2,
+    MAX_COLUMN_IN_BATCH * 10,
+],
+    [False, True]
+)
+)
+def test_insert_by_columns_variations(session, clean_bucket_name, num_columns, insert_by_columns):
+    columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+    data = [[i] for i in range(num_columns)]
+    expected = pa.table(schema=columns, data=data)
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected, insert_by_columns=insert_by_columns) as t:
+        actual = t.select().read_all()
+        assert actual == expected
+
+
+@pytest.mark.parametrize("sorting_key", [0, 40, 80, 120])
+def test_insert_by_columns_sorted(session, clean_bucket_name, sorting_key):
+    num_columns = 160
+    columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+    data = [[i] for i in range(num_columns)]
+    expected = pa.table(schema=columns, data=data)
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected, sorting_key=[sorting_key], insert_by_columns=True) as t:
+        actual = t.select().read_all()
+        assert actual == expected
+
+
 def test_multi_batch_table(session, clean_bucket_name):
     columns = pa.schema([pa.field('s', pa.utf8())])
     expected = pa.Table.from_batches([
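The parametrized column counts above target the chunk boundaries of insert_in_column_batches. A standalone sketch of that arithmetic, with a hypothetical MAX_COLUMN_IN_BATCH value (the real constant is defined in vastdb/table.py):

MAX_COLUMN_IN_BATCH = 4  # hypothetical value, for illustration only

def column_chunks(num_columns: int) -> list[tuple[int, int]]:
    """Mirror the loop in insert_in_column_batches: the first range is the insert, the rest become updates."""
    chunks = []
    for start in range(0, num_columns, MAX_COLUMN_IN_BATCH):
        end = start + MAX_COLUMN_IN_BATCH if start + MAX_COLUMN_IN_BATCH < num_columns else num_columns
        chunks.append((start, end))
    return chunks

assert column_chunks(4) == [(0, 4)]                    # exactly MAX: one insert, no updates
assert column_chunks(5) == [(0, 4), (4, 5)]            # MAX + 1: one insert, one update RPC
assert column_chunks(9) == [(0, 4), (4, 8), (8, 9)]    # one insert, two update RPCs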
vastdb/tests/test_vector_index.py
ADDED

@@ -0,0 +1,162 @@
+"""Tests for vector index functionality."""
+
+import logging
+from typing import Optional
+
+import pyarrow as pa
+import pytest
+
+from vastdb import errors
+from vastdb._internal import VectorIndexSpec
+from vastdb.session import Session
+
+log = logging.getLogger(__name__)
+
+
+@pytest.mark.parametrize("table_name,vector_index", [
+    # Test 1: Table without vector index
+    ("table_without_index", None),
+    # Test 2: Table with L2 vector index
+    ("table_with_l2_index", VectorIndexSpec("embedding", "l2sq")),
+    # Test 3: Table with inner product vector index
+    ("table_with_ip_index", VectorIndexSpec("embedding", "ip")),
+])
+def test_create_table_with_vector_index_metadata(session: Session,
+                                                 clean_bucket_name: str,
+                                                 table_name: str,
+                                                 vector_index: Optional[VectorIndexSpec]):
+    """Test that table creation and stats retrieval work correctly with vector index metadata."""
+    schema_name = "schema1"
+
+    with session.transaction() as tx:
+        log.info(f"Testing table '{table_name}' with {vector_index}")
+
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Create the appropriate schema based on whether vector index is needed
+        if vector_index is None:
+            # Simple table without vector index
+            arrow_schema = pa.schema([
+                ('id', pa.int64()),
+                ('data', pa.string())
+            ])
+        else:
+            # Table with vector column
+            vector_dimension = 128  # Fixed-size vector dimension
+            vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+            arrow_schema = pa.schema([
+                ('id', pa.int64()),
+                ('embedding', vec_type)  # Fixed-size vector column
+            ])
+
+        # Create table with or without vector index
+        log.info(f"Creating table: {table_name}")
+        table = schema.create_table(
+            table_name=table_name,
+            columns=arrow_schema,
+            vector_index=vector_index
+        )
+
+        # Reload stats to ensure we get the vector index metadata
+        table.reload_stats()
+
+        # Get vector index metadata
+        result_vector_index = table._metadata._vector_index
+
+        log.info(f"Vector index metadata: {result_vector_index}")
+
+        # Assert expected values (should match input parameters)
+        result_vector_index_spec = (
+            None
+            if result_vector_index is None
+            else result_vector_index.to_vector_index_spec()
+        )
+        assert result_vector_index_spec == vector_index
+
+        log.info(f"✓ Test passed for table '{table_name}'")
+
+
+@pytest.mark.parametrize("table_name,vector_index,expected_error", [
+    # Test 1: Invalid column name (column doesn't exist in schema)
+    ("table_invalid_column", VectorIndexSpec("nonexistent_column", "l2sq"), "invalid vector indexed column name nonexistent_column"),
+    # Test 2: Invalid distance metric
+    ("table_invalid_metric", VectorIndexSpec("embedding", "invalid_metric"), "invalid vector index distance metric invalid_metric, supported metrics: 'l2sq', 'ip'"),
+])
+def test_create_table_with_invalid_vector_index(session: Session,
+                                                clean_bucket_name: str,
+                                                table_name: str,
+                                                vector_index: VectorIndexSpec,
+                                                expected_error: str):
+    """Test that table creation fails with appropriate error messages for invalid vector index parameters."""
+    schema_name = "schema1"
+
+    with session.transaction() as tx:
+        log.info(f"Testing invalid table '{table_name}' with vector_index={vector_index}, expected_error={expected_error}")
+
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Table with vector column
+        vector_dimension = 128  # Fixed-size vector dimension
+        vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+        arrow_schema = pa.schema([
+            ('id', pa.int64()),
+            ('embedding', vec_type)  # Fixed-size vector column
+        ])
+
+        # Attempt to create table with invalid parameters - should raise an error
+        log.info(f"Attempting to create invalid table: {table_name}")
+        with pytest.raises((errors.BadRequest)) as exc_info:
+            schema.create_table(
+                table_name=table_name,
+                columns=arrow_schema,
+                vector_index=vector_index
+            )
+
+        # Verify the error message contains the expected error text
+        assert expected_error in str(exc_info.value), \
+            f"Expected error message to contain '{expected_error}', got '{str(exc_info.value)}'"
+
+        log.info(f"✓ Test passed for invalid table '{table_name}'")
+
+
+def test_vector_index_metadata_from_stats(session: Session, clean_bucket_name: str):
+    """Test that vector index metadata is correctly retrieved from table stats."""
+    schema_name = "schema1"
+    table_name = "vector_table"
+
+    with session.transaction() as tx:
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Create table with vector index
+        vector_dimension = 128
+        vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+        arrow_schema = pa.schema([
+            ('id', pa.int64()),
+            ('embedding', vec_type)
+        ])
+
+        table = schema.create_table(
+            table_name=table_name,
+            columns=arrow_schema,
+            vector_index=VectorIndexSpec("embedding", "l2sq")
+        )
+
+        # Check stats object directly
+        stats = table.stats
+        assert stats is not None
+        assert stats.vector_index is not None
+        assert stats.vector_index.column == "embedding"
+        assert stats.vector_index.distance_metric == "l2sq"
+
+        # Check via the table methods
+        assert table._metadata._vector_index is not None
+        assert table._metadata._vector_index.column == "embedding"
+        assert table._metadata._vector_index.distance_metric == "l2sq"
+
+        log.info("✓ Vector index metadata correctly retrieved from stats")
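Putting the pieces together, a hedged end-to-end sketch of the vector-index flow these tests exercise. The endpoint, credentials, and object names are illustrative, and whether vector_search sees rows inserted in the same transaction depends on the cluster:

import pyarrow as pa

from vastdb._internal import VectorIndexSpec
from vastdb.session import Session

session = Session(endpoint="http://vip-pool:9090", access="...", secret="...")  # illustrative

dim = 128
vec_type = pa.list_(pa.field('', pa.float32(), False), dim)
arrow_schema = pa.schema([('id', pa.int64()), ('embedding', vec_type)])

with session.transaction() as tx:
    schema = tx.bucket("bucket").create_schema("s")
    table = schema.create_table("t", arrow_schema, vector_index=VectorIndexSpec("embedding", "l2sq"))

    table.insert(pa.table(schema=arrow_schema, data=[[1, 2], [[0.0] * dim, [1.0] * dim]]))

    table.reload_stats()                     # stats carry the vector index metadata
    assert table.vector_index is not None    # the new property added in table.py

    # ANN query over the indexed column; returns a pa.RecordBatchReader.
    reader = table.vector_search(vec=[0.5] * dim, columns=["id"], limit=1)
    print(reader.read_all())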