vastdb 2.0.2__py3-none-any.whl → 2.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vastdb/_adbc.py +205 -0
- vastdb/_internal.py +106 -17
- vastdb/_table_interface.py +20 -3
- vastdb/conftest.py +23 -1
- vastdb/errors.py +5 -0
- vastdb/schema.py +19 -2
- vastdb/session.py +14 -5
- vastdb/table.py +57 -22
- vastdb/table_metadata.py +58 -34
- vastdb/tests/test_adbc_integration.py +129 -0
- vastdb/tests/test_tables.py +35 -1
- vastdb/tests/test_vector_index.py +162 -0
- vastdb/tests/test_vector_search.py +211 -0
- vastdb/tests/util.py +3 -2
- vastdb/transaction.py +32 -0
- vastdb/vast_flatbuf/tabular/GetTableStatsResponse.py +51 -59
- vastdb/vast_flatbuf/tabular/ObjectDetails.py +36 -59
- vastdb/vast_flatbuf/tabular/VectorIndexMetadata.py +67 -0
- vastdb/vast_flatbuf/tabular/VipRange.py +19 -12
- {vastdb-2.0.2.dist-info → vastdb-2.0.5.dist-info}/METADATA +2 -1
- {vastdb-2.0.2.dist-info → vastdb-2.0.5.dist-info}/RECORD +24 -19
- {vastdb-2.0.2.dist-info → vastdb-2.0.5.dist-info}/WHEEL +0 -0
- {vastdb-2.0.2.dist-info → vastdb-2.0.5.dist-info}/licenses/LICENSE +0 -0
- {vastdb-2.0.2.dist-info → vastdb-2.0.5.dist-info}/top_level.txt +0 -0
vastdb/table.py
CHANGED
```diff
@@ -22,11 +22,12 @@ import ibis
 import pyarrow as pa
 import urllib3
 
-from vastdb._table_interface import ITable
+from vastdb._table_interface import IbisPredicate, ITable
 from vastdb.table_metadata import TableMetadata, TableRef, TableStats, TableType
 
 from . import _internal, errors, util
 from ._ibis_support import validate_ibis_support_schema
+from ._internal import VectorIndex
 from .config import ImportConfig, QueryConfig
 
 if TYPE_CHECKING:
@@ -213,6 +214,11 @@ class TableInTransaction(ITable):
         """Reload Sorted Columns."""
         self._metadata.load_sorted_columns(self._tx)
 
+    @property
+    def vector_index(self) -> Optional[VectorIndex]:
+        """Table's Vector Index if exists."""
+        return self._metadata._vector_index
+
     @property
     def path(self) -> str:
         """Return table's path."""
@@ -222,7 +228,7 @@ class TableInTransaction(ITable):
     def _internal_rowid_field(self) -> pa.Field:
         return INTERNAL_ROW_ID_SORTED_FIELD if self._is_sorted_table else INTERNAL_ROW_ID_FIELD
 
-    def sorted_columns(self) -> list[
+    def sorted_columns(self) -> list[pa.Field]:
         """Return sorted columns' metadata."""
         return self._metadata.sorted_columns
 
@@ -620,46 +626,51 @@ class TableInTransaction(ITable):
                 log.debug(
                     "one worker thread finished, remaining: %d", tasks_running)
 
-                # all host threads ended - wait for all futures to complete
-                propagate_first_exception(futures, block=True)
             finally:
                 stop_event.set()
                 while tasks_running > 0:
                     if record_batches_queue.get() is None:
                         tasks_running -= 1
+                propagate_first_exception(futures, block=True)
 
         return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())
 
     def insert_in_column_batches(self, rows: pa.RecordBatch) -> pa.ChunkedArray:
-        """Split the RecordBatch into
+        """Split the RecordBatch into an insert + updates.
 
+        This is both to support rows that won't fit into an RPC and for performance for wide rows.
         Insert first MAX_COLUMN_IN_BATCH columns and get the row_ids. Then loop on the rest of the columns and
         update in groups of MAX_COLUMN_IN_BATCH.
         """
-        column_record_batch = pa.RecordBatch.from_arrays([_combine_chunks(rows.column(i)) for i in range(0, MAX_COLUMN_IN_BATCH)],
-                                                          schema=pa.schema([rows.schema.field(i) for i in range(0, MAX_COLUMN_IN_BATCH)]))
-        row_ids = self.insert(rows=column_record_batch)  # type: ignore
-
         columns_names = [field.name for field in rows.schema]
-        columns
-
-
-
+        # Sorted columns must be in the first insert as those can't be updated later.
+        if self._is_sorted_table:
+            sorted_columns_names = [field.name for field in self.sorted_columns()]
+            columns_names = sorted_columns_names + [column_name for column_name in columns_names if column_name not in sorted_columns_names]
+        columns = [rows.schema.field(column_name) for column_name in columns_names]
+
+        arrays = [_combine_chunks(rows.column(column_name)) for column_name in columns_names]
+        for start in range(0, len(rows.schema), MAX_COLUMN_IN_BATCH):
             end = start + MAX_COLUMN_IN_BATCH if start + \
                 MAX_COLUMN_IN_BATCH < len(rows.schema) else len(rows.schema)
             columns_name_chunk = columns_names[start:end]
             columns_chunks = columns[start:end]
             arrays_chunks = arrays[start:end]
-
-
-
-
-
+            if start == 0:
+                column_record_batch = pa.RecordBatch.from_arrays(
+                    arrays_chunks, schema=pa.schema(columns_chunks))
+                row_ids = self.insert(rows=column_record_batch, by_columns=False)  # type: ignore
+            else:
+                columns_chunks.append(self._internal_rowid_field)
+                arrays_chunks.append(row_ids.to_pylist())
+                column_record_batch = pa.RecordBatch.from_arrays(
+                    arrays_chunks, schema=pa.schema(columns_chunks))
+                self.update(rows=column_record_batch, columns=columns_name_chunk)
         return row_ids
 
     def insert(self,
                rows: Union[pa.RecordBatch, pa.Table],
-               by_columns: bool =
+               by_columns: bool = True) -> pa.ChunkedArray:
         """Insert a RecordBatch into this table."""
         self._assert_not_imports_table()
 
@@ -667,9 +678,14 @@ class TableInTransaction(ITable):
             log.debug("Ignoring empty insert into %s", self.ref)
             return pa.chunked_array([], type=self._internal_rowid_field.type)
 
-
-
-
+        # inserting by columns is faster, so default to doing that
+        # if the cluster supports it and there are actually columns in the rows
+        if by_columns and len(rows.schema):
+            try:
+                self._tx._rpc.features.check_return_row_ids()
+                return self.insert_in_column_batches(rows)
+            except errors.NotSupportedVersion:
+                pass
 
         try:
             row_ids = []
@@ -802,6 +818,25 @@ class TableInTransaction(ITable):
     def _is_sorted_table(self) -> bool:
         return self._metadata.table_type is TableType.Elysium
 
+    def vector_search(
+        self,
+        vec: list[float],
+        columns: list[str],
+        limit: int,
+        predicate: Optional[IbisPredicate] = None,
+    ) -> pa.RecordBatchReader:
+        """Vector Search over vector indexed columns."""
+        assert self.vector_index is not None, "Table is either not vector indexed. (maybe try reloading the TableMetadata)"
+
+        return self._tx.adbc_conn.vector_search(
+            vec,
+            self.vector_index,
+            self.ref,
+            columns,
+            limit,
+            predicate=predicate,
+        )
+
 
 class Table(TableInTransaction):
     """Vast Interactive Table."""
```
vastdb/table_metadata.py
CHANGED
```diff
@@ -4,17 +4,19 @@ import logging
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional
 
 import ibis
 import pyarrow as pa
 
 from vastdb import errors
 from vastdb._ibis_support import validate_ibis_support_schema
+from vastdb._internal import TableStats, VectorIndex
 
 if TYPE_CHECKING:
     from .transaction import Transaction
 
+
 log = logging.getLogger(__name__)
 
 
@@ -39,26 +41,16 @@ class TableRef:
         """Table full path."""
         return f"{self.bucket}/{self.schema}/{self.table}"
 
+    @property
+    def query_engine_full_path(self) -> str:
+        """Table full path for VastDB Query Engine."""
+        return f'"{self.bucket}/{self.schema}".{self.table}'
+
     def __str__(self) -> str:
         """Table full path."""
         return self.full_path
 
 
-@dataclass
-class TableStats:
-    """Table-related information."""
-
-    num_rows: int
-    size_in_bytes: int
-    sorting_score: int
-    write_amplification: int
-    acummulative_row_inserition_count: int
-    is_external_rowid_alloc: bool = False
-    sorting_key_enabled: bool = False
-    sorting_done: bool = False
-    endpoints: Tuple[str, ...] = ()
-
-
 class TableMetadata:
     """Table Metadata."""
 
@@ -67,25 +59,29 @@ class TableMetadata:
     _sorted_columns: Optional[list[str]]
     _ibis_table: ibis.Table
     _stats: Optional[TableStats]
-
-
-
-
-
+    _vector_index: Optional[VectorIndex]
+
+    def __init__(
+        self,
+        ref: TableRef,
+        arrow_schema: Optional[pa.Schema] = None,
+        table_type: Optional[TableType] = None,
+        vector_index: Optional[VectorIndex] = None,
+    ):
         """Table Metadata."""
         self._ref = deepcopy(ref)
         self._table_type = table_type
         self.arrow_schema = deepcopy(arrow_schema)
         self._sorted_columns = None
         self._stats = None
+        self._vector_index = vector_index
 
     def __eq__(self, other: object) -> bool:
         """TableMetadata Equal."""
         if not isinstance(other, TableMetadata):
             return False
 
-        return
-            self._table_type == other._table_type)
+        return self._ref == other._ref and self._table_type == other._table_type
 
     def rename_table(self, name: str) -> None:
         """Rename table metadata's table name."""
@@ -110,7 +106,8 @@ class TableMetadata:
                 table=self.ref.table,
                 next_key=next_key,
                 txid=tx.active_txid,
-                list_imports_table=self.is_imports_table
+                list_imports_table=self.is_imports_table,
+            )
             fields.extend(cur_columns)
             if not is_truncated:
                 break
@@ -123,9 +120,16 @@ class TableMetadata:
         try:
             next_key = 0
             while True:
-                cur_columns, next_key, is_truncated, _count =
-
-
+                cur_columns, next_key, is_truncated, _count = (
+                    tx._rpc.api.list_sorted_columns(
+                        bucket=self.ref.bucket,
+                        schema=self.ref.schema,
+                        table=self.ref.table,
+                        next_key=next_key,
+                        txid=tx.active_txid,
+                        list_imports_table=self.is_imports_table,
+                    )
+                )
                 fields.extend(cur_columns)
                 if not is_truncated:
                     break
@@ -133,7 +137,9 @@ class TableMetadata:
             raise
         except errors.InternalServerError as ise:
             log.warning(
-                "Failed to get the sorted columns Elysium might not be supported: %s",
+                "Failed to get the sorted columns Elysium might not be supported: %s",
+                ise,
+            )
             raise
         except errors.NotSupportedVersion:
             log.warning("Failed to get the sorted columns, Elysium not supported")
@@ -143,10 +149,13 @@ class TableMetadata:
 
     def load_stats(self, tx: "Transaction") -> None:
         """Load/Reload table stats."""
-
-            bucket=self.ref.bucket,
-
-
+        self._stats = tx._rpc.api.get_table_stats(
+            bucket=self.ref.bucket,
+            schema=self.ref.schema,
+            name=self.ref.table,
+            txid=tx.active_txid,
+            imports_table_stats=self.is_imports_table,
+        )
 
         is_elysium_table = self._stats.sorting_key_enabled
 
@@ -161,6 +170,18 @@ class TableMetadata:
                 "Actual table is sorted (TableType.Elysium), was not inited as TableType.Elysium"
             )
 
+        self._parse_stats_vector_index()
+
+    def _parse_stats_vector_index(self):
+        vector_index_is_set = self._vector_index is not None
+
+        if vector_index_is_set and self._stats.vector_index != self._vector_index:
+            raise ValueError(
+                f"Table has index {self._stats.vector_index}, but was initialized as {self._vector_index}"
+            )
+        else:
+            self._vector_index = self._stats.vector_index
+
     def _set_sorted_table(self, tx: "Transaction"):
         self._table_type = TableType.Elysium
         tx._rpc.features.check_elysium()
@@ -184,7 +205,9 @@ class TableMetadata:
         if arrow_schema:
             validate_ibis_support_schema(arrow_schema)
             self._arrow_schema = arrow_schema
-            self._ibis_table = ibis.table(
+            self._ibis_table = ibis.table(
+                ibis.Schema.from_pyarrow(arrow_schema), self._ref.full_path
+            )
         else:
             self._arrow_schema = None
             self._ibis_table = None
@@ -211,7 +234,8 @@ class TableMetadata:
         """Table's type."""
        if self._table_type is None:
             raise ValueError(
-                "TableType was not loaded. load using TableMetadata.load_stats"
+                "TableType was not loaded. load using TableMetadata.load_stats"
+            )
 
         return self._table_type
 
```
vastdb/tests/test_adbc_integration.py
ADDED
```diff
@@ -0,0 +1,129 @@
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import pyarrow as pa
+import pytest
+
+from vastdb._adbc import END_USER_PROPERTY, AdbcDriver
+from vastdb.session import Session
+from vastdb.table_metadata import TableRef
+from vastdb.transaction import NoAdbcConnectionError
+
+
+def test_sanity(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    ref = TableRef(clean_bucket_name, "s", "t")
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t", arrow_schema)
+        )
+        table.insert(data_table)
+
+    with session.transaction() as tx:
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {ref.query_engine_full_path}")
+        res = tx.adbc_conn.cursor.fetchall()
+
+        assert res == [(1,), (2,), (3,), (4,), (5,)]
+
+
+def test_adbc_shares_tx(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+
+def test_adbc_conn_unreachable_tx_close(session_factory):
+    session = session_factory(with_adbc=True)
+
+    with session.transaction() as tx:
+        assert tx.adbc_conn is not None
+
+    # adbc conn should not be reachable after tx close
+    with pytest.raises(NoAdbcConnectionError):
+        tx.adbc_conn
+
+
+def test_two_simulatnious_txs_with_adbc(session_factory, clean_bucket_name: str):
+    session = session_factory(with_adbc=True)
+
+    arrow_schema = pa.schema([("n", pa.int32())])
+
+    data_table = pa.table(schema=arrow_schema, data=[[1, 2, 3, 4, 5]])
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name)
+            .create_schema("s")
+            .create_table("t1", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+    with session.transaction() as tx:
+        table = (
+            tx.bucket(clean_bucket_name).schema("s").create_table("t2", arrow_schema)
+        )
+        table.insert(data_table)
+
+        # expecting adbc execute to "see" table if it shares the transaction with the pysdk
+        tx.adbc_conn.cursor.execute(f"SELECT * FROM {table.ref.query_engine_full_path}")
+        assert tx.adbc_conn.cursor.fetchall() == [(1,), (2,), (3,), (4,), (5,)]
+
+
+@pytest.mark.parametrize("end_user", [("mock-end-user",), (None,)])
+def test_end_user_passed_to_adbc_connect(end_user: Optional[str]):
+    mock_driver = MagicMock(spec=AdbcDriver)
+    mock_driver.local_path = "/mock/driver/path"
+
+    with (
+        patch("vastdb._adbc.connect") as mock_connect,
+        patch("vastdb._internal.VastdbApi") as MockVastdbApi,
+    ):
+        mock_api_instance = MockVastdbApi.return_value
+        mock_api_instance.begin_transaction.return_value.headers = {
+            "tabular-txid": "12345"
+        }
+        # A version that supports everything needed.
+        mock_api_instance.vast_version = (5, 4, 0, 0)
+
+        session = Session(
+            access="test_access",
+            secret="test_secret",
+            endpoint="http://localhost:9090",
+            adbc_driver=mock_driver,
+            end_user=end_user,
+        )
+
+        with session.transaction():
+            # The ADBC connection is established when the transaction starts
+            pass
+
+        mock_connect.assert_called_once()
+        call_kwargs = mock_connect.call_args.kwargs
+        conn_kwargs = call_kwargs.get("conn_kwargs", {})
+        assert (end_user is None) ^ (END_USER_PROPERTY in conn_kwargs)
+        assert (end_user is None) ^ (conn_kwargs.get(END_USER_PROPERTY) == end_user)
```
vastdb/tests/test_tables.py
CHANGED
```diff
@@ -1,5 +1,6 @@
 import datetime as dt
 import decimal
+import itertools
 import logging
 import random
 import threading
@@ -17,7 +18,7 @@ from requests.exceptions import HTTPError
 
 from vastdb import errors
 from vastdb.session import Session
-from vastdb.table import INTERNAL_ROW_ID, QueryConfig
+from vastdb.table import INTERNAL_ROW_ID, MAX_COLUMN_IN_BATCH, QueryConfig
 
 from .util import assert_row_ids_ascending_on_first_insertion_to_table, prepare_data
 
@@ -95,6 +96,39 @@ def test_insert_wide_row(session, clean_bucket_name):
         assert actual == expected
 
 
+@pytest.mark.parametrize("num_columns,insert_by_columns", itertools.product([
+    MAX_COLUMN_IN_BATCH // 2,
+    MAX_COLUMN_IN_BATCH - 1,
+    MAX_COLUMN_IN_BATCH,
+    MAX_COLUMN_IN_BATCH + 1,
+    MAX_COLUMN_IN_BATCH * 2,
+    MAX_COLUMN_IN_BATCH * 10,
+],
+    [False, True]
+)
+)
+def test_insert_by_columns_variations(session, clean_bucket_name, num_columns, insert_by_columns):
+    columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+    data = [[i] for i in range(num_columns)]
+    expected = pa.table(schema=columns, data=data)
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected, insert_by_columns=insert_by_columns) as t:
+        actual = t.select().read_all()
+        assert actual == expected
+
+
+@pytest.mark.parametrize("sorting_key", [0, 40, 80, 120])
+def test_insert_by_columns_sorted(session, clean_bucket_name, sorting_key):
+    num_columns = 160
+    columns = pa.schema([pa.field(f'i{i}', pa.int64()) for i in range(num_columns)])
+    data = [[i] for i in range(num_columns)]
+    expected = pa.table(schema=columns, data=data)
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected, sorting_key=[sorting_key], insert_by_columns=True) as t:
+        actual = t.select().read_all()
+        assert actual == expected
+
+
 def test_multi_batch_table(session, clean_bucket_name):
     columns = pa.schema([pa.field('s', pa.utf8())])
     expected = pa.Table.from_batches([
```
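
The new tests exercise column counts on both sides of `MAX_COLUMN_IN_BATCH`. The chunk boundaries they implicitly cover follow the loop in `insert_in_column_batches`; the helper below is a hypothetical standalone sketch of that arithmetic, not part of the package:

```python
from vastdb.table import MAX_COLUMN_IN_BATCH


def column_chunks(num_columns: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) column ranges of each insert/update RPC."""
    chunks = []
    for start in range(0, num_columns, MAX_COLUMN_IN_BATCH):
        end = min(start + MAX_COLUMN_IN_BATCH, num_columns)
        chunks.append((start, end))
    return chunks


# The first chunk is inserted (yielding row ids); each later chunk is applied
# as an update keyed by those row ids, with sorted columns forced into the
# first chunk on Elysium (sorted) tables.
```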
vastdb/tests/test_vector_index.py
ADDED
```diff
@@ -0,0 +1,162 @@
+"""Tests for vector index functionality."""
+
+import logging
+from typing import Optional
+
+import pyarrow as pa
+import pytest
+
+from vastdb import errors
+from vastdb._internal import VectorIndexSpec
+from vastdb.session import Session
+
+log = logging.getLogger(__name__)
+
+
+@pytest.mark.parametrize("table_name,vector_index", [
+    # Test 1: Table without vector index
+    ("table_without_index", None),
+    # Test 2: Table with L2 vector index
+    ("table_with_l2_index", VectorIndexSpec("embedding", "l2sq")),
+    # Test 3: Table with inner product vector index
+    ("table_with_ip_index", VectorIndexSpec("embedding", "ip")),
+])
+def test_create_table_with_vector_index_metadata(session: Session,
+                                                 clean_bucket_name: str,
+                                                 table_name: str,
+                                                 vector_index: Optional[VectorIndexSpec]):
+    """Test that table creation and stats retrieval work correctly with vector index metadata."""
+    schema_name = "schema1"
+
+    with session.transaction() as tx:
+        log.info(f"Testing table '{table_name}' with {vector_index}")
+
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Create the appropriate schema based on whether vector index is needed
+        if vector_index is None:
+            # Simple table without vector index
+            arrow_schema = pa.schema([
+                ('id', pa.int64()),
+                ('data', pa.string())
+            ])
+        else:
+            # Table with vector column
+            vector_dimension = 128  # Fixed-size vector dimension
+            vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+            arrow_schema = pa.schema([
+                ('id', pa.int64()),
+                ('embedding', vec_type)  # Fixed-size vector column
+            ])
+
+        # Create table with or without vector index
+        log.info(f"Creating table: {table_name}")
+        table = schema.create_table(
+            table_name=table_name,
+            columns=arrow_schema,
+            vector_index=vector_index
+        )
+
+        # Reload stats to ensure we get the vector index metadata
+        table.reload_stats()
+
+        # Get vector index metadata
+        result_vector_index = table._metadata._vector_index
+
+        log.info(f"Vector index metadata: {result_vector_index}")
+
+        # Assert expected values (should match input parameters)
+        result_vector_index_spec = (
+            None
+            if result_vector_index is None
+            else result_vector_index.to_vector_index_spec()
+        )
+        assert result_vector_index_spec == vector_index
+
+        log.info(f"✓ Test passed for table '{table_name}'")
+
+
+@pytest.mark.parametrize("table_name,vector_index,expected_error", [
+    # Test 1: Invalid column name (column doesn't exist in schema)
+    ("table_invalid_column", VectorIndexSpec("nonexistent_column", "l2sq"), "invalid vector indexed column name nonexistent_column"),
+    # Test 2: Invalid distance metric
+    ("table_invalid_metric", VectorIndexSpec("embedding", "invalid_metric"), "invalid vector index distance metric invalid_metric, supported metrics: 'l2sq', 'ip'"),
+])
+def test_create_table_with_invalid_vector_index(session: Session,
+                                                clean_bucket_name: str,
+                                                table_name: str,
+                                                vector_index: VectorIndexSpec,
+                                                expected_error: str):
+    """Test that table creation fails with appropriate error messages for invalid vector index parameters."""
+    schema_name = "schema1"
+
+    with session.transaction() as tx:
+        log.info(f"Testing invalid table '{table_name}' with vector_index={vector_index}, expected_error={expected_error}")
+
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Table with vector column
+        vector_dimension = 128  # Fixed-size vector dimension
+        vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+        arrow_schema = pa.schema([
+            ('id', pa.int64()),
+            ('embedding', vec_type)  # Fixed-size vector column
+        ])
+
+        # Attempt to create table with invalid parameters - should raise an error
+        log.info(f"Attempting to create invalid table: {table_name}")
+        with pytest.raises((errors.BadRequest)) as exc_info:
+            schema.create_table(
+                table_name=table_name,
+                columns=arrow_schema,
+                vector_index=vector_index
+            )
+
+        # Verify the error message contains the expected error text
+        assert expected_error in str(exc_info.value), \
+            f"Expected error message to contain '{expected_error}', got '{str(exc_info.value)}'"
+
+        log.info(f"✓ Test passed for invalid table '{table_name}'")
+
+
+def test_vector_index_metadata_from_stats(session: Session, clean_bucket_name: str):
+    """Test that vector index metadata is correctly retrieved from table stats."""
+    schema_name = "schema1"
+    table_name = "vector_table"
+
+    with session.transaction() as tx:
+        # Create schema
+        bucket = tx.bucket(clean_bucket_name)
+        schema = bucket.create_schema(schema_name)
+
+        # Create table with vector index
+        vector_dimension = 128
+        vec_type = pa.list_(pa.field('', pa.float32(), False), vector_dimension)
+        arrow_schema = pa.schema([
+            ('id', pa.int64()),
+            ('embedding', vec_type)
+        ])
+
+        table = schema.create_table(
+            table_name=table_name,
+            columns=arrow_schema,
+            vector_index=VectorIndexSpec("embedding", "l2sq")
+        )
+
+        # Check stats object directly
+        stats = table.stats
+        assert stats is not None
+        assert stats.vector_index is not None
+        assert stats.vector_index.column == "embedding"
+        assert stats.vector_index.distance_metric == "l2sq"
+
+        # Check via the table methods
+        assert table._metadata._vector_index is not None
+        assert table._metadata._vector_index.column == "embedding"
+        assert table._metadata._vector_index.distance_metric == "l2sq"
+
+        log.info("✓ Vector index metadata correctly retrieved from stats")
```