vastdb 2.0.1__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vastdb/_adbc.py +194 -0
- vastdb/_internal.py +101 -12
- vastdb/_table_interface.py +20 -3
- vastdb/conftest.py +23 -1
- vastdb/errors.py +5 -0
- vastdb/schema.py +17 -2
- vastdb/session.py +12 -5
- vastdb/table.py +60 -24
- vastdb/table_metadata.py +58 -34
- vastdb/tests/test_adbc_integration.py +89 -0
- vastdb/tests/test_projections.py +49 -1
- vastdb/tests/test_tables.py +35 -1
- vastdb/tests/test_vector_index.py +162 -0
- vastdb/tests/test_vector_search.py +210 -0
- vastdb/tests/util.py +3 -2
- vastdb/transaction.py +30 -0
- vastdb/vast_flatbuf/tabular/GetTableStatsResponse.py +51 -59
- vastdb/vast_flatbuf/tabular/ObjectDetails.py +36 -59
- vastdb/vast_flatbuf/tabular/VectorIndexMetadata.py +67 -0
- vastdb/vast_flatbuf/tabular/VipRange.py +19 -12
- {vastdb-2.0.1.dist-info → vastdb-2.0.3.dist-info}/METADATA +2 -1
- {vastdb-2.0.1.dist-info → vastdb-2.0.3.dist-info}/RECORD +25 -20
- {vastdb-2.0.1.dist-info → vastdb-2.0.3.dist-info}/WHEEL +0 -0
- {vastdb-2.0.1.dist-info → vastdb-2.0.3.dist-info}/licenses/LICENSE +0 -0
- {vastdb-2.0.1.dist-info → vastdb-2.0.3.dist-info}/top_level.txt +0 -0
vastdb/_adbc.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import urllib.request
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
import sqlglot
|
|
9
|
+
from adbc_driver_manager.dbapi import Connection, Cursor, connect
|
|
10
|
+
from sqlglot import exp
|
|
11
|
+
|
|
12
|
+
from vastdb._internal import VectorIndex
|
|
13
|
+
from vastdb._table_interface import IbisPredicate
|
|
14
|
+
from vastdb.table_metadata import TableRef
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)


# ADBC connection option (conn_kwargs key) carrying the VAST transaction id;
# set by _get_adbc_connection.
TXID_OVERRIDE_PROPERTY: str = "vast.db.external_txid"
# Column alias of the computed distance in generated vector-search SQL;
# the generated query orders by it.
VAST_DIST_ALIAS = "vast_pysdk_vector_dist"
# Root of the local ADBC driver cache.
# NOTE(review): the leading "~" is a literal character here. It must be passed
# through os.path.expanduser before use, or the cache ends up under a directory
# literally named "~" relative to the current working directory.
DEFAULT_ADBC_DRIVER_CACHE_DIR: str = "~/.vast/adbc_drivers_cache"
# Directory for drivers downloaded by URL; each file is named by the hex
# sha256 of its URL (see AdbcDriver._url_to_local_path).
DEFAULT_ADBC_DRIVER_CACHE_BY_URL_DIR: str = f"{DEFAULT_ADBC_DRIVER_CACHE_DIR}/by_url"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LocalAdbcDriverNotFound(Exception):
    """Raised when a local ADBC driver shared-library path does not exist.

    The missing path is carried as the exception's single argument.
    """
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RemoteAdbcDriverDownloadFailed(Exception):
    """Raised when downloading an ADBC driver shared library from a URL fails.

    The message names the URL and the underlying error; the original exception
    is chained as ``__cause__``.
    """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AdbcDriver:
    """Handle to a local ADBC driver shared library.

    Create instances with :meth:`from_local_path` (an existing file) or
    :meth:`from_url` (downloaded once into a per-user cache, then reused).
    """

    _local_path: str

    def __init__(self, local_path: str):
        self._local_path = local_path

    @staticmethod
    def from_local_path(local_path: str) -> "AdbcDriver":
        """AdbcDriver from a local_path to shared-library.

        Raises:
            LocalAdbcDriverNotFound: if ``local_path`` does not exist.
        """
        if not os.path.exists(local_path):
            raise LocalAdbcDriverNotFound(local_path)

        return AdbcDriver(local_path)

    @staticmethod
    def from_url(url: str) -> "AdbcDriver":
        """AdbcDriver to be downloaded by url to shared-library (uses cache if exists).

        An already-cached file for ``url`` is reused as-is, without re-validation.

        Raises:
            RemoteAdbcDriverDownloadFailed: if the driver is not cached and the
                download fails.
        """
        expected_local_path = AdbcDriver._url_to_local_path(url)

        if not os.path.exists(expected_local_path):
            AdbcDriver._download_driver(url, expected_local_path)

        return AdbcDriver(expected_local_path)

    @staticmethod
    def _url_to_local_path(url: str) -> str:
        """Return the cache path for ``url``: its hex sha256 under the by-URL cache dir."""
        url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
        # The configured cache dir starts with "~"; expand it so the cache lives in
        # the user's home rather than in a literal "./~" directory under the CWD.
        cache_dir = os.path.expanduser(DEFAULT_ADBC_DRIVER_CACHE_BY_URL_DIR)
        return os.path.join(cache_dir, url_hash)

    @staticmethod
    def _download_driver(url: str, target_path: str):
        """Download ``url`` to ``target_path`` atomically.

        The file is fetched into a uniquely named sibling and then renamed into
        place. An interrupted or failed download therefore never leaves a
        truncated file at ``target_path``. This matters because ``from_url``
        treats any existing file there as a valid cached driver.

        Raises:
            RemoteAdbcDriverDownloadFailed: wrapping any error raised while
                downloading or renaming.
        """
        import contextlib
        import uuid

        target_dir = os.path.dirname(target_path) or "."
        os.makedirs(target_dir, exist_ok=True)

        tmp_path = os.path.join(
            target_dir, f".{os.path.basename(target_path)}.{uuid.uuid4().hex}.part"
        )
        try:
            log.info("Downloading ADBC driver from %s to %s...", url, target_path)
            urllib.request.urlretrieve(url, tmp_path)
            # Atomic on POSIX and Windows when source and target share a filesystem.
            os.replace(tmp_path, target_path)
            log.info("Successfully downloaded driver to %s.", target_path)
        except Exception as e:
            # Best-effort cleanup of a partial download; the original error is what matters.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise RemoteAdbcDriverDownloadFailed(
                f"Failed to download ADBC driver from {url}: {e}"
            ) from e

    @property
    def local_path(self) -> str:
        """Filesystem path of the driver shared library."""
        return self._local_path
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _get_adbc_connection(
    adbc_driver_path: str, endpoint: str, access_key: str, secret_key: str, txid: int
) -> Connection:
    """Open an ADBC connection through the driver at ``adbc_driver_path``, in transaction ``txid``.

    Endpoint and credentials are database options. The transaction id is a
    connection option (``TXID_OVERRIDE_PROPERTY``), stringified, so that the
    connection presumably runs inside the caller's existing VAST transaction.
    """
    database_options = {
        "vast.db.endpoint": endpoint,
        "vast.db.access_key": access_key,
        "vast.db.secret_key": secret_key,
    }
    connection_options = {TXID_OVERRIDE_PROPERTY: str(txid)}
    return connect(
        driver=adbc_driver_path,
        db_kwargs=database_options,
        conn_kwargs=connection_options,
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _remove_table_qualification_from_columns(expression: exp.Expression):
    """Strip the table qualifier from every column reference in ``expression``.

    ibis qualifies the columns of unbound tables with its default alias ``t0``.
    This helper clears the qualifier of *every* ``exp.Column`` node, whatever
    table it names. It mutates ``expression`` in place and also returns it.

    Note: use only if one table is involved. With several tables, the
    unqualified columns might become ambiguous.
    """
    for column_node in expression.find_all(exp.Column):
        # Drop the "table" argument, e.g. t0.x -> x.
        column_node.set("table", None)
    return expression
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _ibis_to_qe_predicates(predicate: IbisPredicate) -> str:
    """Render an ibis predicate as a SQL boolean expression for a WHERE clause.

    Steps:
    1. Compile the predicate with ``to_sql()`` and parse it back with sqlglot.
    2. Unwrap the first projected expression of the resulting statement (``.this``).
    3. Strip its table qualifiers and re-serialize it.

    NOTE(review): assumes ``to_sql()`` yields a SELECT whose first projection is an
    alias wrapping the predicate. ``IbisPredicate`` also admits a ``Deferred``,
    which presumably must be resolved against a table before reaching here;
    confirm with callers.
    """
    statement = sqlglot.parse_one(predicate.to_sql())

    # There is currently a single table, so dropping the table qualifier
    # (ibis' "t0.") from every column cannot make a reference ambiguous.
    condition = statement.expressions[0].this
    return _remove_table_qualification_from_columns(condition).sql()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _vector_search_sql(
    query_vector: list[float],
    vector_index: VectorIndex,
    table_ref: TableRef,
    columns: list[str],
    limit: int,
    predicate: Optional[IbisPredicate] = None,
) -> str:
    """Build the query-engine SQL for a top-``limit`` vector search.

    The statement:
    - projects ``columns`` plus the distance between the indexed column and
      ``query_vector``, aliased as ``VAST_DIST_ALIAS``;
    - optionally filters by ``predicate``;
    - orders by that distance and keeps ``limit`` rows.

    Args:
        query_vector: the query vector; its length fixes the ``FLOAT[n]`` casts.
        vector_index: index metadata (indexed column and SQL distance function).
        table_ref: table to search.
        columns: column names returned alongside the distance (any sequence).
        limit: maximum number of rows to return.
        predicate: optional ibis boolean expression rendered into ``WHERE``.

    Raises:
        ValueError: if ``query_vector`` is empty, has non-finite (NaN/inf) or
            unparsable components, or if ``limit`` is negative.
        TypeError: if a component is not convertible to float, or ``limit``
            is not an integer.
    """
    import math
    import operator

    query_vector_dim = len(query_vector)
    if query_vector_dim == 0:
        raise ValueError("query_vector must not be empty")

    # Format every component explicitly as a Python float. str(list) would emit
    # e.g. "np.float64(1.0)" for NumPy scalars (NumPy >= 2), which is not valid
    # SQL. float() also guarantees that only numbers are spliced into the query.
    components = [float(v) for v in query_vector]
    if not all(math.isfinite(v) for v in components):
        raise ValueError("query_vector must contain only finite values")
    vector_text = "[" + ", ".join(repr(v) for v in components) + "]"

    # Only a genuine non-negative integer may be interpolated into LIMIT.
    limit = operator.index(limit)
    if limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")

    query_vector_literal = f"{vector_text}::FLOAT[{query_vector_dim}]"
    dist_func = f"{vector_index.sql_distance_function}({vector_index.column}::FLOAT[{query_vector_dim}], {query_vector_literal})"
    dist_alias = f"{dist_func} as {VAST_DIST_ALIAS}"

    projection_str = ",".join([*columns, dist_alias])

    if predicate is not None:
        where = f"WHERE {_ibis_to_qe_predicates(predicate)}"
    else:
        where = ""

    return f"""
SELECT {projection_str}
FROM {table_ref.query_engine_full_path}
{where}
ORDER BY {VAST_DIST_ALIAS}
LIMIT {limit}"""
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class AdbcConnection:
    """ADBC connection opened in a VAST transaction, plus one reusable cursor.

    May be used as a context manager; leaving the ``with`` block calls :meth:`close`.
    """

    def __init__(
        self,
        adbc_driver: AdbcDriver,
        endpoint: str,
        access_key: str,
        secret_key: str,
        txid: int,
    ):
        self._adbc_conn = _get_adbc_connection(
            adbc_driver.local_path, endpoint, access_key, secret_key, txid
        )

        try:
            self._cursor = self._adbc_conn.cursor()
        except Exception:
            # Don't leak the freshly opened connection if no cursor can be made.
            self._adbc_conn.close()
            raise

    @property
    def cursor(self) -> Cursor:
        """The connection's single cursor, reused by :meth:`vector_search`."""
        return self._cursor

    def close(self):
        """Close the cursor, then the underlying ADBC connection.

        Previously only the cursor was closed, leaking the native connection.
        NOTE(review): confirm that closing the connection does not affect the
        externally owned VAST transaction (commit/rollback happen via the SDK).
        """
        try:
            self._cursor.close()
        finally:
            self._adbc_conn.close()

    def __enter__(self) -> "AdbcConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def vector_search(
        self,
        query_vector: list[float],
        vector_index: VectorIndex,
        table_ref: TableRef,
        columns: list[str],
        limit: int,
        predicate: Optional[IbisPredicate] = None,
    ) -> pa.RecordBatchReader:
        """Top-n on vector-column.

        Runs the SQL built by ``_vector_search_sql`` on this connection's cursor
        and returns the result as a record-batch stream.
        """
        sql = _vector_search_sql(
            query_vector=query_vector,
            vector_index=vector_index,
            table_ref=table_ref,
            columns=columns,
            limit=limit,
            predicate=predicate,
        )

        self._cursor.execute(sql)
        return self._cursor.fetch_record_batch()
|
vastdb/_internal.py
CHANGED
|
@@ -6,6 +6,7 @@ import struct
|
|
|
6
6
|
import time
|
|
7
7
|
import urllib.parse
|
|
8
8
|
from collections import defaultdict, namedtuple
|
|
9
|
+
from dataclasses import dataclass
|
|
9
10
|
from enum import Enum
|
|
10
11
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
|
|
11
12
|
|
|
@@ -116,6 +117,7 @@ from vastdb.vast_flatbuf.tabular.ListSchemasResponse import (
|
|
|
116
117
|
from vastdb.vast_flatbuf.tabular.ListTablesResponse import (
|
|
117
118
|
ListTablesResponse as list_tables,
|
|
118
119
|
)
|
|
120
|
+
from vastdb.vast_flatbuf.tabular.VectorIndexMetadata import VectorIndexMetadata
|
|
119
121
|
|
|
120
122
|
from . import errors, util
|
|
121
123
|
from .config import BackoffConfig
|
|
@@ -803,10 +805,40 @@ def _parse_table_info(obj, parse_properties):
|
|
|
803
805
|
sorting_score, write_amplification, acummulative_row_insertion_count, sorting_done)
|
|
804
806
|
|
|
805
807
|
|
|
806
|
-
|
|
808
|
+
@dataclass
class VectorIndexSpec:
    """
    Vector Index Specification when creating a table.

    Sent by ``_create_table_internal`` as the ``tabular-vector-index-column`` and
    ``tabular-vector-index-distance-metric`` request headers.
    """
    column: str  # name of the column to index
    distance_metric: str  # distance metric name, forwarded verbatim to the server
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
@dataclass
class VectorIndex:
    """Vector index of an existing table.

    Built by ``get_table_stats`` from the stats response's ``VectorIndexMetadata``.
    """
    column: str  # indexed column name
    distance_metric: str  # distance metric name
    sql_distance_function: str  # query-engine SQL function computing the distance

    def to_vector_index_spec(self) -> VectorIndexSpec:
        """Return the creation-time spec (column and metric), dropping the SQL function name."""
        return VectorIndexSpec(self.column,
                               self.distance_metric)
|
|
808
826
|
|
|
809
|
-
|
|
827
|
+
|
|
828
|
+
@dataclass
class TableStats:
    """Table-related information.

    Returned by ``VastdbApi.get_table_stats`` (and by ``get_topic_stats``).
    """

    num_rows: int
    size_in_bytes: int
    sorting_score: int  # raw server sorting score with bit 63 masked off
    write_amplification: int
    # NOTE: the misspelled name is used as a keyword by get_table_stats and read by
    # callers as an attribute; renaming it would break both.
    acummulative_row_inserition_count: int
    is_external_rowid_alloc: bool = False
    sorting_key_enabled: bool = False
    sorting_done: bool = False  # bit 63 of the raw server sorting score
    endpoints: Tuple[str, ...] = ()  # get_table_stats currently reports only the API URL
    vector_index: Optional[VectorIndex] = None  # None when the table has no vector index
|
|
810
842
|
|
|
811
843
|
|
|
812
844
|
_RETRIABLE_EXCEPTIONS = (
|
|
@@ -1128,11 +1160,31 @@ class VastdbApi:
|
|
|
1128
1160
|
def create_table(self, bucket, schema, name, arrow_schema=None,
                 txid=0, client_tags=[], expected_retvals=[],
                 create_imports_table=False, use_external_row_ids_allocation=False, table_props=None,
                 sorting_key=[], vector_index: Optional[VectorIndexSpec] = None):
    """
    Create a table in the specified bucket and schema.

    Every argument is forwarded unchanged to ``_create_table_internal``.

    Args:
        bucket: Name of the bucket
        schema: Name of the schema
        name: Name of the table
        arrow_schema: PyArrow schema defining the table columns
        txid: Transaction ID
        client_tags: Client tags for the request
        expected_retvals: Expected return values
        create_imports_table: Whether this is an imports table
        use_external_row_ids_allocation: Whether to use external row ID allocation
        table_props: Table properties
        sorting_key: Sorting-key columns (for Elysium tables).
            NOTE(review): Schema.create_table documents these as column names;
            confirm whether names or indices are expected here.
        vector_index: Optional vector index to create together with the table
    """
    self._create_table_internal(
        bucket=bucket,
        schema=schema,
        name=name,
        arrow_schema=arrow_schema,
        txid=txid,
        client_tags=client_tags,
        expected_retvals=expected_retvals,
        create_imports_table=create_imports_table,
        use_external_row_ids_allocation=use_external_row_ids_allocation,
        table_props=table_props,
        sorting_key=sorting_key,
        vector_index=vector_index,
    )
|
|
1136
1188
|
|
|
1137
1189
|
def create_topic(self, bucket, name, topic_partitions, expected_retvals=[],
|
|
1138
1190
|
message_timestamp_type=None, retention_ms=None, message_timestamp_after_max_ms=None,
|
|
@@ -1149,8 +1201,9 @@ class VastdbApi:
|
|
|
1149
1201
|
|
|
1150
1202
|
def _create_table_internal(self, bucket, schema, name, arrow_schema=None,
|
|
1151
1203
|
txid=0, client_tags=[], expected_retvals=[], topic_partitions=0,
|
|
1152
|
-
create_imports_table=False, use_external_row_ids_allocation=False,
|
|
1153
|
-
sorting_key=[]
|
|
1204
|
+
create_imports_table=False, use_external_row_ids_allocation=False,
|
|
1205
|
+
table_props=None, sorting_key=[],
|
|
1206
|
+
vector_index: Optional[VectorIndexSpec] = None):
|
|
1154
1207
|
"""
|
|
1155
1208
|
Create a table, use the following request
|
|
1156
1209
|
POST /bucket/schema/table?table HTTP/1.1
|
|
@@ -1176,6 +1229,10 @@ class VastdbApi:
|
|
|
1176
1229
|
if use_external_row_ids_allocation:
|
|
1177
1230
|
headers['use-external-row-ids-alloc'] = str(use_external_row_ids_allocation)
|
|
1178
1231
|
|
|
1232
|
+
if vector_index is not None:
|
|
1233
|
+
headers['tabular-vector-index-column'] = vector_index.column
|
|
1234
|
+
headers['tabular-vector-index-distance-metric'] = vector_index.distance_metric
|
|
1235
|
+
|
|
1179
1236
|
url_params = {'topic_partitions': str(topic_partitions)} if topic_partitions else {}
|
|
1180
1237
|
if create_imports_table:
|
|
1181
1238
|
url_params['sub-table'] = IMPORTED_OBJECTS_TABLE_NAME
|
|
@@ -1188,10 +1245,10 @@ class VastdbApi:
|
|
|
1188
1245
|
url=self._url(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
|
|
1189
1246
|
data=serialized_schema, headers=headers)
|
|
1190
1247
|
|
|
1191
|
-
def get_topic_stats(self, bucket, name, expected_retvals=[]):
|
|
1248
|
+
def get_topic_stats(self, bucket, name, expected_retvals=[]) -> TableStats:
    """Return the stats of Kafka topic ``name``, read as a table under ``KAFKA_TOPICS_SCHEMA_NAME``."""
    return self.get_table_stats(
        bucket=bucket,
        schema=KAFKA_TOPICS_SCHEMA_NAME,
        name=name,
        expected_retvals=expected_retvals,
    )
|
|
1193
1250
|
|
|
1194
|
-
def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[], imports_table_stats=False):
|
|
1251
|
+
def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[], imports_table_stats=False) -> TableStats:
|
|
1195
1252
|
"""
|
|
1196
1253
|
GET /mybucket/myschema/mytable?stats HTTP/1.1
|
|
1197
1254
|
tabular-txid: TransactionId
|
|
@@ -1218,8 +1275,40 @@ class VastdbApi:
|
|
|
1218
1275
|
sorting_score = sorting_score_raw & ((1 << 63) - 1)
|
|
1219
1276
|
sorting_done = bool(sorting_score_raw >> 63)
|
|
1220
1277
|
|
|
1278
|
+
vector_index_metadata: Optional[VectorIndexMetadata] = stats.VectorIndexMetadata()
|
|
1279
|
+
|
|
1280
|
+
if vector_index_metadata is not None:
|
|
1281
|
+
column_name = vector_index_metadata.ColumnName()
|
|
1282
|
+
distance_metric = vector_index_metadata.DistanceMetric()
|
|
1283
|
+
sql_distance_function = vector_index_metadata.SqlFunctionName()
|
|
1284
|
+
|
|
1285
|
+
if (column_name is None or
|
|
1286
|
+
distance_metric is None or
|
|
1287
|
+
sql_distance_function is None):
|
|
1288
|
+
raise errors.ApiResponseError(
|
|
1289
|
+
"VectorIndexMetadata properties (column_name, distance_metric, sql_function_name) must all be set."
|
|
1290
|
+
)
|
|
1291
|
+
|
|
1292
|
+
vector_index = VectorIndex(
|
|
1293
|
+
column=column_name.decode('utf-8'),
|
|
1294
|
+
distance_metric=distance_metric.decode('utf-8'),
|
|
1295
|
+
sql_distance_function=sql_distance_function.decode('utf-8'))
|
|
1296
|
+
else:
|
|
1297
|
+
vector_index = None
|
|
1298
|
+
|
|
1221
1299
|
endpoints = [self.url] # we cannot replace the host by a VIP address in HTTPS-based URLs
|
|
1222
|
-
|
|
1300
|
+
|
|
1301
|
+
return TableStats(
|
|
1302
|
+
num_rows=num_rows,
|
|
1303
|
+
size_in_bytes=size_in_bytes,
|
|
1304
|
+
sorting_score=sorting_score,
|
|
1305
|
+
write_amplification=write_amplification,
|
|
1306
|
+
acummulative_row_inserition_count=acummulative_row_inserition_count,
|
|
1307
|
+
is_external_rowid_alloc=is_external_rowid_alloc,
|
|
1308
|
+
sorting_key_enabled=sorting_key_enabled,
|
|
1309
|
+
sorting_done=sorting_done,
|
|
1310
|
+
endpoints=tuple(endpoints),
|
|
1311
|
+
vector_index=vector_index)
|
|
1223
1312
|
|
|
1224
1313
|
def alter_topic(self, bucket, name,
|
|
1225
1314
|
new_name="", expected_retvals=[],
|
|
@@ -2339,7 +2428,7 @@ def build_field(builder: flatbuffers.Builder, f: pa.Field, include_name=True):
|
|
|
2339
2428
|
class QueryDataRequest:
|
|
2340
2429
|
def __init__(self, serialized, response_schema, response_parser):
|
|
2341
2430
|
self.serialized = serialized
|
|
2342
|
-
self.response_schema = response_schema
|
|
2431
|
+
self.response_schema: pa.Schema = response_schema
|
|
2343
2432
|
self.response_parser = response_parser
|
|
2344
2433
|
|
|
2345
2434
|
|
vastdb/_table_interface.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import TYPE_CHECKING, Iterable, Optional, Union
|
|
2
|
+
from typing import TYPE_CHECKING, Iterable, Optional, TypeAlias, Union
|
|
3
3
|
|
|
4
4
|
import ibis
|
|
5
5
|
import pyarrow as pa
|
|
6
6
|
|
|
7
|
+
from ._internal import VectorIndex
|
|
7
8
|
from .config import ImportConfig, QueryConfig
|
|
8
9
|
from .table_metadata import TableRef
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
11
12
|
from .table import Projection
|
|
12
13
|
|
|
14
|
+
# Predicate accepted by ITable.select() / ITable.vector_search(): either an ibis
# boolean column expression or an ibis Deferred expression.
# NOTE(review): typing.TypeAlias requires Python >= 3.10; confirm this matches the
# package's minimum supported Python version.
IbisPredicate: TypeAlias = Union[ibis.expr.types.BooleanColumn, ibis.common.deferred.Deferred]
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
class ITable(ABC):
|
|
15
18
|
"""Interface for VAST Table operations."""
|
|
@@ -71,8 +74,7 @@ class ITable(ABC):
|
|
|
71
74
|
@abstractmethod
|
|
72
75
|
def select(self,
|
|
73
76
|
columns: Optional[list[str]] = None,
|
|
74
|
-
predicate:
|
|
75
|
-
ibis.common.deferred.Deferred] = None,
|
|
77
|
+
predicate: Optional[IbisPredicate] = None,
|
|
76
78
|
config: Optional[QueryConfig] = None,
|
|
77
79
|
*,
|
|
78
80
|
internal_row_id: bool = False,
|
|
@@ -134,3 +136,18 @@ class ITable(ABC):
|
|
|
134
136
|
It is useful for constructing expressions for predicate pushdown in `ITable.select()` method.
|
|
135
137
|
"""
|
|
136
138
|
pass
|
|
139
|
+
|
|
140
|
+
@property
@abstractmethod
def vector_index(self) -> Optional[VectorIndex]:
    """Table's Vector Index if exists.

    Returns ``None`` when the table has no vector index.
    """
    pass
|
|
145
|
+
|
|
146
|
+
@abstractmethod
def vector_search(self,
                  vec: list[float],
                  columns: list[str],
                  limit: int,
                  predicate: Optional[IbisPredicate] = None) -> pa.RecordBatchReader:
    """Top-n on vector-column.

    Args:
        vec: query vector, compared against the table's vector-index column.
        columns: columns to return for each matching row.
        limit: maximum number of rows (the n in top-n).
        predicate: optional filter applied before ranking.

    Returns:
        A record-batch stream of the nearest rows. NOTE(review): presumably
        ordered by ascending distance, as the ADBC implementation's SQL does;
        confirm for every implementation.
    """
    pass
|
vastdb/conftest.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import sqlite3
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Iterable
|
|
4
|
+
from typing import Iterable, Protocol
|
|
5
5
|
|
|
6
6
|
import boto3
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
9
|
import vastdb
|
|
10
10
|
import vastdb.errors
|
|
11
|
+
from vastdb._adbc import AdbcDriver
|
|
11
12
|
from vastdb.schema import Schema
|
|
12
13
|
from vastdb.session import Session
|
|
13
14
|
|
|
@@ -45,6 +46,10 @@ def pytest_addoption(parser):
|
|
|
45
46
|
parser.addoption("--num-workers", help="Number of concurrent workers", default=1)
|
|
46
47
|
|
|
47
48
|
|
|
49
|
+
def _get_adbc_driver_url(pipeline: str) -> str:
|
|
50
|
+
return f"https://artifactory.vastdata.com/artifactory/files/vastdb-native-client/{pipeline}/libadbc_driver_vastdb.so"
|
|
51
|
+
|
|
52
|
+
|
|
48
53
|
@pytest.fixture(scope="session")
|
|
49
54
|
def session_kwargs(request: pytest.FixtureRequest, tabular_endpoint_urls):
|
|
50
55
|
return dict(
|
|
@@ -59,6 +64,23 @@ def session(session_kwargs):
|
|
|
59
64
|
return vastdb.connect(**session_kwargs)
|
|
60
65
|
|
|
61
66
|
|
|
67
|
+
class SessionFactory(Protocol):
    """Callable type of the ``session_factory`` fixture: builds a new ``Session``.

    ``with_adbc=True`` requests a session configured with an ADBC driver.
    NOTE(review): the fixture's implementation also accepts ``with_adbc``
    positionally and defaults it to False; this protocol is stricter.
    """

    def __call__(self, *, with_adbc: bool) -> Session: ...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.fixture(scope="session")
def session_factory(session_kwargs) -> SessionFactory:
    """Return a factory that creates a fresh session from ``session_kwargs`` on each call.

    With ``with_adbc=True`` the session also gets the ADBC driver, downloaded on
    first use into the local driver cache.
    """
    def create_session(with_adbc: bool = False) -> Session:
        # Work on a copy: ``session_kwargs`` is session-scoped and shared. Mutating
        # it would leak the ADBC driver into every later session, including
        # with_adbc=False ones.
        kwargs = dict(session_kwargs)

        if with_adbc:
            # TODO use other not hard coded driver
            kwargs['adbc_driver'] = AdbcDriver.from_url(url=_get_adbc_driver_url("2103686"))

        return vastdb.connect(**kwargs)

    return create_session
|
|
82
|
+
|
|
83
|
+
|
|
62
84
|
@pytest.fixture(scope="session")
|
|
63
85
|
def num_workers(request: pytest.FixtureRequest):
|
|
64
86
|
return int(request.config.getoption("--num-workers"))
|
vastdb/errors.py
CHANGED
|
@@ -231,6 +231,11 @@ class ConnectionError(Exception):
|
|
|
231
231
|
self.args = [vars(self)]
|
|
232
232
|
|
|
233
233
|
|
|
234
|
+
class ApiResponseError(Exception):
    """Raised when a server response is logically invalid or internally inconsistent.

    Used, for example, when a response omits fields that must be present together.
    """
|
|
237
|
+
|
|
238
|
+
|
|
234
239
|
def handle_unavailable(**kwargs):
|
|
235
240
|
if kwargs['code'] == 'SlowDown':
|
|
236
241
|
raise Slowdown(**kwargs)
|
vastdb/schema.py
CHANGED
|
@@ -14,6 +14,7 @@ from vastdb.table_metadata import TableMetadata, TableRef, TableType
|
|
|
14
14
|
|
|
15
15
|
from . import bucket, errors, schema, table
|
|
16
16
|
from ._ibis_support import validate_ibis_support_schema
|
|
17
|
+
from ._internal import VectorIndexSpec
|
|
17
18
|
|
|
18
19
|
if TYPE_CHECKING:
|
|
19
20
|
from .table import Table
|
|
@@ -80,11 +81,24 @@ class Schema:
|
|
|
80
81
|
return result
|
|
81
82
|
|
|
82
83
|
def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True,
|
|
83
|
-
use_external_row_ids_allocation=False, sorting_key=[]
|
|
84
|
+
use_external_row_ids_allocation=False, sorting_key=[],
|
|
85
|
+
vector_index: Optional[VectorIndexSpec] = None) -> "Table":
|
|
84
86
|
"""Create a new table under this schema.
|
|
85
87
|
|
|
86
88
|
A virtual `vastdb_rowid` column (of `int64` type) can be created to access and filter by internal VAST row IDs.
|
|
87
89
|
See https://support.vastdata.com/s/article/UUID-48d0a8cf-5786-5ef3-3fa3-9c64e63a0967 for more details.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
table_name: Name of the table to create
|
|
93
|
+
columns: PyArrow schema defining the table columns
|
|
94
|
+
fail_if_exists: Whether to fail if the table already exists
|
|
95
|
+
use_external_row_ids_allocation: Whether to use external row ID allocation
|
|
96
|
+
sorting_key: List of column names to use as sorting key (for Elysium tables)
|
|
97
|
+
vector_index: Optional vector index.
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
The created table
|
|
101
|
+
|
|
88
102
|
"""
|
|
89
103
|
if current := self.table(table_name, fail_if_missing=False):
|
|
90
104
|
if fail_if_exists:
|
|
@@ -97,7 +111,8 @@ class Schema:
|
|
|
97
111
|
validate_ibis_support_schema(columns)
|
|
98
112
|
self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid,
|
|
99
113
|
use_external_row_ids_allocation=use_external_row_ids_allocation,
|
|
100
|
-
sorting_key=sorting_key
|
|
114
|
+
sorting_key=sorting_key,
|
|
115
|
+
vector_index=vector_index)
|
|
101
116
|
log.info("Created table: %s", table_name)
|
|
102
117
|
return self.table(table_name) # type: ignore[return-value]
|
|
103
118
|
|
vastdb/session.py
CHANGED
|
@@ -10,6 +10,7 @@ For more details see:
|
|
|
10
10
|
import os
|
|
11
11
|
from typing import TYPE_CHECKING, Optional
|
|
12
12
|
|
|
13
|
+
from vastdb._adbc import AdbcDriver
|
|
13
14
|
from vastdb.transaction import Transaction
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
@@ -23,7 +24,8 @@ class Session:
|
|
|
23
24
|
*,
|
|
24
25
|
ssl_verify=True,
|
|
25
26
|
timeout=None,
|
|
26
|
-
backoff_config: Optional["BackoffConfig"] = None
|
|
27
|
+
backoff_config: Optional["BackoffConfig"] = None,
|
|
28
|
+
adbc_driver: Optional[AdbcDriver] = None):
|
|
27
29
|
"""Connect to a VAST Database endpoint, using specified credentials."""
|
|
28
30
|
from . import _internal, features
|
|
29
31
|
|
|
@@ -34,14 +36,19 @@ class Session:
|
|
|
34
36
|
if endpoint is None:
|
|
35
37
|
endpoint = os.environ['AWS_S3_ENDPOINT_URL']
|
|
36
38
|
|
|
39
|
+
self.endpoint = endpoint
|
|
40
|
+
self.access = access
|
|
41
|
+
self.secret = secret
|
|
42
|
+
|
|
37
43
|
self.api = _internal.VastdbApi(
|
|
38
|
-
endpoint=endpoint,
|
|
39
|
-
access_key=access,
|
|
40
|
-
secret_key=secret,
|
|
44
|
+
endpoint=self.endpoint,
|
|
45
|
+
access_key=self.access,
|
|
46
|
+
secret_key=self.secret,
|
|
41
47
|
ssl_verify=ssl_verify,
|
|
42
48
|
timeout=timeout,
|
|
43
49
|
backoff_config=backoff_config)
|
|
44
50
|
self.features = features.Features(self.api.vast_version)
|
|
51
|
+
self.adbc_driver: Optional[AdbcDriver] = adbc_driver
|
|
45
52
|
|
|
46
53
|
def __repr__(self):
|
|
47
54
|
"""Don't show the secret key."""
|
|
@@ -56,4 +63,4 @@ class Session:
|
|
|
56
63
|
tx.bucket("bucket").create_schema("schema")
|
|
57
64
|
"""
|
|
58
65
|
from . import transaction
|
|
59
|
-
return transaction.Transaction(self)
|
|
66
|
+
return transaction.Transaction(self, _adbc_driver=self.adbc_driver)
|