vastdb 1.3.9__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vastdb/_internal.py CHANGED
@@ -69,6 +69,7 @@ import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.Date as fb_date
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.Decimal as fb_decimal
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.Field as fb_field
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.FixedSizeBinary as fb_fixed_size_binary
+ import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.FixedSizeList as fb_fixed_size_list
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.FloatingPoint as fb_floating_point
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.Int as fb_int
  import vastdb.vast_flatbuf.org.apache.arrow.flatbuf.List as fb_list
@@ -497,7 +498,13 @@ class Predicate:
  fb_bool.Start(self.builder)
  field_type = fb_bool.End(self.builder)

- value = True if value == 'true' else False # not cover all cases
+ # Handle both boolean values and string representations
+ if isinstance(value, bool):
+ value = value
+ elif isinstance(value, str):
+ value = value.lower() == 'true'
+ else:
+ value = bool(value)
  elif isinstance(field.type, pa.Decimal128Type):
  literal_type = fb_decimal_lit
  literal_impl = LiteralImpl.DecimalLiteral
@@ -608,7 +615,7 @@ class FieldNode:
  self.debug = debug
  if isinstance(self.type, pa.StructType):
  self.children = [FieldNode(field, index_iter, parent=self) for field in self.type]
- elif isinstance(self.type, pa.ListType):
+ elif pa.types.is_list(self.type) or pa.types.is_fixed_size_list(self.type):
  self.children = [FieldNode(self.type.value_field, index_iter, parent=self)]
  elif isinstance(self.type, pa.MapType):
  # Map is represented as List<Struct<K, V>> in Arrow
@@ -752,7 +759,7 @@ def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
  if not column.type.num_fields == 1: # Note: VAST serializes only a single struct field at a time
  raise ValueError(f'column.type.num_fields: {column.type.num_fields} not eq to 1')
  yield from _iter_nested_arrays(column.field(0))
- elif isinstance(column.type, pa.ListType):
+ elif pa.types.is_list(column.type) or pa.types.is_fixed_size_list(column.type):
  yield from _iter_nested_arrays(column.values) # Note: Map is serialized in VAST as a List<Struct<K, V>>


@@ -853,10 +860,11 @@ class VastdbApi:
  VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')

  def __init__(self, endpoint, access_key, secret_key,
- *,
- ssl_verify=True,
- timeout=None,
- backoff_config: Optional[BackoffConfig] = None):
+ *,
+ ssl_verify=True,
+ timeout=None,
+ backoff_config: Optional[BackoffConfig] = None,
+ version_check=True):

  from . import version # import lazily here (to avoid circular dependencies)
  self.client_sdk_version = f"VAST Database Python SDK {version()} - 2024 (c)"
@@ -896,29 +904,30 @@ class VastdbApi:
  aws_region='',
  aws_service='s3')

- # probe the cluster for its version
- res = self._request(method="GET", url=self._url(command="transaction"), skip_status_check=True) # used only for the response headers
- _logger.debug("headers=%s code=%s content=%s", res.headers, res.status_code, res.content)
- server_header = res.headers.get("Server")
- if server_header is None:
- _logger.error("Response doesn't contain 'Server' header")
- else:
- if not server_header.startswith(self.VAST_SERVER_PREFIX):
- raise UnsupportedServer(f'{self.url} is not a VAST DB server endpoint ("{server_header}")')
-
- if m := self.VAST_VERSION_REGEX.match(server_header):
- self.vast_version: Tuple[int, ...] = tuple(int(v) for v in m.group(1).split("."))
- return
+ if version_check:
+ # probe the cluster for its version
+ res = self._request(method="GET", url=self._url(command="transaction"), skip_status_check=True) # used only for the response headers
+ _logger.debug("headers=%s code=%s content=%s", res.headers, res.status_code, res.content)
+ server_header = res.headers.get("Server")
+ if server_header is None:
+ _logger.error("Response doesn't contain 'Server' header")
  else:
- _logger.error("'Server' header '%s' doesn't match the expected pattern", server_header)
+ if not server_header.startswith(self.VAST_SERVER_PREFIX):
+ raise UnsupportedServer(f'{self.url} is not a VAST DB server endpoint ("{server_header}")')

- msg = (
- f'Please use `vastdb` <= 0.0.5.x with current VAST cluster version ("{server_header or "N/A"}"). '
- 'To use the latest SDK, please upgrade your cluster to the latest service pack. '
- 'Please contact customer.support@vastdata.com for more details.'
- )
- _logger.critical(msg)
- raise NotImplementedError(msg)
+ if m := self.VAST_VERSION_REGEX.match(server_header):
+ self.vast_version: Tuple[int, ...] = tuple(int(v) for v in m.group(1).split("."))
+ return
+ else:
+ _logger.error("'Server' header '%s' doesn't match the expected pattern", server_header)
+
+ msg = (
+ f'Please use `vastdb` <= 0.0.5.x with current VAST cluster version ("{server_header or "N/A"}"). '
+ 'To use the latest SDK, please upgrade your cluster to the latest service pack. '
+ 'Please contact customer.support@vastdata.com for more details.'
+ )
+ _logger.critical(msg)
+ raise NotImplementedError(msg)

  def __enter__(self):
  """Allow using this session as a context manager."""
@@ -935,7 +944,8 @@ class VastdbApi:
  secret_key=self.secret_key,
  ssl_verify=self._session.verify,
  timeout=self.timeout,
- backoff_config=self.backoff_config)
+ backoff_config=self.backoff_config,
+ version_check=False)

  def _single_request(self, *, method, url, skip_status_check=False, **kwargs):
  _logger.debug("Sending request: %s %s %s timeout=%s", method, url, kwargs, self.timeout)
@@ -1349,12 +1359,12 @@ class VastdbApi:
  lists = list_tables.GetRootAs(res.content)
  tables_length = lists.TablesLength()
  count = int(res_headers['tabular-list-count']) if 'tabular-list-count' in res_headers else tables_length
- return lists, is_truncated, count
+ return lists, next_key, is_truncated, count

  def _list_tables_internal(self, bucket, schema, parse_properties, txid=0, client_tags=[], max_keys=1000, next_key=0, name_prefix="",
  exact_match=False, expected_retvals=[], include_list_stats=False, count_only=False):
  tables = []
- lists, is_truncated, count = self._list_tables_raw(bucket, schema, txid=txid, client_tags=client_tags, max_keys=max_keys,
+ lists, next_key, is_truncated, count = self._list_tables_raw(bucket, schema, txid=txid, client_tags=client_tags, max_keys=max_keys,
  next_key=next_key, name_prefix=name_prefix, exact_match=exact_match, expected_retvals=expected_retvals,
  include_list_stats=include_list_stats, count_only=count_only)
  bucket_name = lists.BucketName().decode()
@@ -1368,7 +1378,7 @@ class VastdbApi:
  return bucket_name, schema_name, tables, next_key, is_truncated, count

  def raw_sorting_score(self, bucket, schema, txid, name):
- lists, _, _ = self._list_tables_raw(bucket, schema, txid=txid, exact_match=True, name_prefix=name, include_list_stats=True)
+ lists, _, _, _ = self._list_tables_raw(bucket, schema, txid=txid, exact_match=True, name_prefix=name, include_list_stats=True)
  bucket_name = lists.BucketName().decode()
  if not bucket.startswith(bucket_name): # ignore snapshot name
  raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
@@ -2267,11 +2277,17 @@ def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
  fb_struct.Start(builder)
  field_type = fb_struct.End(builder)

- elif isinstance(field.type, pa.ListType):
+ elif pa.types.is_list(field.type):
  field_type_type = Type.List
  fb_list.Start(builder)
  field_type = fb_list.End(builder)

+ elif pa.types.is_fixed_size_list(field.type):
+ field_type_type = Type.FixedSizeList
+ fb_fixed_size_list.Start(builder)
+ fb_fixed_size_list.AddListSize(builder, field.type.list_size)
+ field_type = fb_fixed_size_list.End(builder)
+
  elif isinstance(field.type, pa.MapType):
  field_type_type = Type.Map
  fb_map.Start(builder)
@@ -2293,7 +2309,7 @@ def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
  children = None
  if isinstance(f.type, pa.StructType):
  children = [build_field(builder, child, child.name) for child in list(f.type)]
- if isinstance(f.type, pa.ListType):
+ if pa.types.is_list(f.type) or pa.types.is_fixed_size_list(f.type):
  children = [build_field(builder, f.type.value_field, "item")]
  if isinstance(f.type, pa.MapType):
  children = [
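
With the FixedSizeList handling added above (get_field_type, build_field, and the FieldNode/_iter_nested_arrays type checks), fixed-size list columns can now be expressed in a table schema. A minimal sketch, assuming the usual vastdb.connect() entry point and placeholder endpoint, credentials, and bucket/schema names:

    import pyarrow as pa
    import vastdb

    # 'embedding' is a fixed-size list of 16 float32 values; this is the Arrow type
    # that the new fb_fixed_size_list serialization covers.
    columns = pa.schema([
        ('id', pa.int64()),
        ('embedding', pa.list_(pa.float32(), 16)),
    ])

    session = vastdb.connect(endpoint='http://vip-pool.example:80',
                             access='ACCESS_KEY', secret='SECRET_KEY')
    with session.transaction() as tx:
        schema = tx.bucket('my-bucket').schema('my-schema')
        table = schema.create_table('vectors', columns)

Separately, the new version_check keyword on VastdbApi.__init__ lets internal session clones skip the extra GET probe against the transaction endpoint when the cluster version is already known.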
vastdb/bench/test_perf.py CHANGED
@@ -1,3 +1,4 @@
+ import datetime as dt
  import logging
  import time

@@ -5,6 +6,7 @@ import pytest

  from vastdb import util
  from vastdb.table import ImportConfig, QueryConfig
+ from vastdb.tests.util import compare_pyarrow_tables

  log = logging.getLogger(__name__)

@@ -12,17 +14,74 @@ log = logging.getLogger(__name__)
  @pytest.mark.benchmark
  def test_bench(session, test_bucket_name, parquets_path, crater_path):
  files = [str(parquets_path / f) for f in (parquets_path.glob('**/*.pq'))]
+ stats = None

  with session.transaction() as tx:
  b = tx.bucket(test_bucket_name)
  s = b.create_schema('s1')
- t = util.create_table_from_files(s, 't1', files, config=ImportConfig(import_concurrency=8))
+ util.create_table_from_files(s, 't1', files, config=ImportConfig(import_concurrency=8))
+ t2 = util.create_table_from_files(s, 't2', files, config=ImportConfig(import_concurrency=8))
+ # Enabling Elysium with 4 sorting keys - ts, sid, ask_open, ask_close
+ t2.add_sorting_key([2, 0, 3, 4])
+ stats = t2.get_stats()
+ log.info("Added sorting keys")
+
+ assert stats
+ # Waiting up to 2 hours for sorting to complete.
+ start_time = time.time()
+ while not stats.sorting_done:
+ if time.time() - start_time > 7200:
+ raise TimeoutError("Sorting did not complete after waiting for 2 hours.")
+ time.sleep(30)
+ with session.transaction() as tx:
+ table = tx.bucket(test_bucket_name).schema('s1').table('t2')
+ stats = table.get_stats()
+ log.info("Sorting completed")
+
+ queries = [
+ {'query_str': "select sid from {t} where sid = 10033007".format, 'columns': ['sid'],
+ 'predicate': lambda t: t['sid'] == 10033007},
+ {'query_str': "select last_trade_price from {t} where ts between "
+ "TIMESTAMP'2018-01-04 20:30:00' AND TIMESTAMP'2018-01-05 20:30:00'".format,
+ 'columns': ['last_trade_price'], 'predicate': lambda t: (t['ts'].between(
+ dt.datetime(2018, 1, 4, 20, 30, 00, 00), dt.datetime(2018, 1, 5, 20, 30, 00, 00)))},
+ {'query_str': "select ts,ask_close,ask_open from {t} where bid_qty = 684000 and ask_close > 1".format,
+ 'columns': ['ts', 'ask_close', 'ask_open'],
+ 'predicate': lambda t: ((t['bid_qty'] == 684000) & (t['ask_close'] > 1))},
+ {'query_str': "select ts,ticker from {t} where "
+ "ask_open between 4374 and 4375 OR ask_open between 380 and 381".format,
+ 'columns': ['ts', 'ticker'],
+ 'predicate': lambda t: ((t['ask_open'].between(4374, 4375)) | (t['ask_open'].between(380, 381)))},
+ {
+ 'query_str': "select trade_close, trade_high, trade_low, trade_open from {t} where ticker in ('BANR', 'KELYB')".format,
+ 'columns': ['trade_close', 'trade_high', 'trade_low', 'trade_open'],
+ 'predicate': lambda t: (t['ticker'].isin(['BANR', 'KELYB']))}
+ ]
+
+ log.info("Starting to run queries")
+ with session.transaction() as tx:
+ schema = tx.bucket(test_bucket_name).schema('s1')
+ t1 = schema.table("t1")
+ t2 = schema.table("t2")
+
  config = QueryConfig(num_splits=8, num_sub_splits=4)
- s = time.time()
- pa_table = t.select(columns=['sid'], predicate=t['sid'] == 10033007, config=config).read_all()
- e = time.time()
- log.info("'SELECT sid from TABLE WHERE sid = 10033007' returned in %s seconds.", e - s)
- if crater_path:
- with open(f'{crater_path}/bench_results', 'a') as f:
- f.write(f"'SELECT sid FROM TABLE WHERE sid = 10033007' returned in {e - s} seconds")
- assert pa_table.num_rows == 255_075
+
+ for q in queries:
+ normal_table_res, els_table_res = None, None
+ for table in [t1, t2]:
+ log.info("Starting query: %s", q['query_str'](t=table.name))
+ s = time.time()
+ res = table.select(columns=q['columns'], predicate=q['predicate'](table), config=config).read_all()
+ e = time.time()
+ if table == t1:
+ normal_table_res = res
+ else:
+ els_table_res = res
+ log.info("Query %s returned in %s seconds.", q['query_str'](t=table.name), e - s)
+ if crater_path:
+ with open(f'{crater_path}/bench_results', 'a') as f:
+ f.write(f"Query '{q['query_str'](t=table)}' returned in {e - s} seconds")
+
+ assert normal_table_res, f"missing result for {t1} table"
+ assert els_table_res, f"missing result for {t2} table"
+ assert compare_pyarrow_tables(normal_table_res, els_table_res)
vastdb/conftest.py CHANGED
@@ -6,6 +6,7 @@ import boto3
  import pytest

  import vastdb
+ import vastdb.errors


  def pytest_addoption(parser):
@@ -65,8 +66,14 @@ def clean_bucket_name(request, test_bucket_name, session):
  b = tx.bucket(test_bucket_name)
  for top_schema in b.schemas():
  for s in iter_schemas(top_schema):
- for t in s.tables():
- t.drop()
+ for t_name in s.tablenames():
+ try:
+ t = s.table(t_name)
+ t.drop()
+ except vastdb.errors.NotSupportedSchema:
+ # Use internal API to drop the table in case unsupported schema prevents creating a table
+ # object.
+ tx._rpc.api.drop_table(b.name, s.name, t_name, txid=tx.txid)
  s.drop()
  return test_bucket_name

vastdb/errors.py CHANGED
@@ -2,7 +2,9 @@ import logging
  import xml.etree.ElementTree
  from dataclasses import dataclass
  from enum import Enum
+ from typing import Optional

+ import pyarrow as pa
  import requests


@@ -89,6 +91,9 @@ class ImportFilesError(Exception):
  message: str
  error_dict: dict

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  class InvalidArgument(Exception):
  pass
@@ -122,18 +127,27 @@ class NotSupported(Exception):
  class MissingBucket(Missing):
  bucket: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class MissingSnapshot(Missing):
  bucket: str
  snapshot: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class MissingSchema(Missing):
  bucket: str
  schema: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class MissingTable(Missing):
@@ -141,6 +155,9 @@ class MissingTable(Missing):
  schema: str
  table: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class MissingProjection(Missing):
@@ -149,6 +166,9 @@ class MissingProjection(Missing):
  table: str
  projection: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  class Exists(Exception):
  pass
@@ -159,6 +179,9 @@ class SchemaExists(Exists):
  bucket: str
  schema: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class TableExists(Exists):
@@ -166,6 +189,9 @@ class TableExists(Exists):
  schema: str
  table: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class NotSupportedCommand(NotSupported):
@@ -173,18 +199,37 @@ class NotSupportedCommand(NotSupported):
  schema: str
  table: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class NotSupportedVersion(NotSupported):
  err_msg: str
  version: str

+ def __post_init__(self):
+ self.args = [vars(self)]
+
+
+ @dataclass
+ class NotSupportedSchema(NotSupported):
+ message: Optional[str] = None
+ schema: Optional[pa.Schema] = None
+ cause: Optional[Exception] = None
+
+ def __post_init__(self):
+ self.args = [vars(self)]
+

  @dataclass
  class ConnectionError(Exception):
  cause: Exception
  may_retry: bool

+ def __post_init__(self):
+ self.args = [vars(self)]
+

  def handle_unavailable(**kwargs):
  if kwargs['code'] == 'SlowDown':
@@ -192,7 +237,7 @@ def handle_unavailable(**kwargs):
  raise ServiceUnavailable(**kwargs)


- ERROR_TYPES_MAP = {
+ HTTP_ERROR_TYPES_MAP = {
  HttpStatus.BAD_REQUEST: BadRequest,
  HttpStatus.FOBIDDEN: Forbidden,
  HttpStatus.NOT_FOUND: NotFound,
@@ -205,6 +250,10 @@ ERROR_TYPES_MAP = {
  HttpStatus.INSUFFICIENT_CAPACITY: InsufficientCapacity,
  }

+ SPECIFIC_ERROR_TYPES_MAP = {
+ 'TabularUnsupportedColumnType': NotSupportedSchema,
+ }
+

  def from_response(res: requests.Response):
  if res.status_code == HttpStatus.SUCCESS.value:
@@ -234,5 +283,10 @@ def from_response(res: requests.Response):
  )
  log.warning("RPC failed: %s", kwargs)
  status = HttpStatus(res.status_code)
- error_type = ERROR_TYPES_MAP.get(status, UnexpectedError)
- return error_type(**kwargs) # type: ignore
+ http_error_type = HTTP_ERROR_TYPES_MAP.get(status, UnexpectedError)
+ http_error = http_error_type(**kwargs) # type: ignore
+ # Wrap specific error types if applicable
+ if code_str in SPECIFIC_ERROR_TYPES_MAP:
+ error_type = SPECIFIC_ERROR_TYPES_MAP[code_str]
+ return error_type(message=message_str, cause=http_error)
+ return http_error
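
With the new SPECIFIC_ERROR_TYPES_MAP, a 'TabularUnsupportedColumnType' RPC failure now surfaces as errors.NotSupportedSchema (carrying the original HTTP error in .cause) instead of a bare status-code exception. A sketch of how calling code might handle it, mirroring the conftest.py change earlier in this diff; the transaction and schema objects are assumed to come from an open vastdb session:

    import vastdb.errors

    def drop_all_tables(tx, schema):
        # Fall back to the raw RPC when the table's schema cannot be represented,
        # the same pattern the test fixture above uses.
        for name in schema.tablenames():
            try:
                schema.table(name).drop()
            except vastdb.errors.NotSupportedSchema:
                tx._rpc.api.drop_table(schema.bucket.name, schema.name, name, txid=tx.txid)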
vastdb/features.py CHANGED
@@ -4,7 +4,7 @@ import logging

  from .errors import NotSupportedVersion

- log = logging.getLogger()
+ log = logging.getLogger(__name__)


  class Features:
@@ -39,6 +39,10 @@ class Features:
  "Zip import requires 5.3.1+ VAST release",
  vast_version >= (5, 3, 1))

+ self.check_timezone = self._check(
+ "Timezone support requires 5.4+ Vast release",
+ vast_version >= (5, 4))
+

  def _check(self, msg, supported):
  log.debug("%s (current version is %s): supported=%s", msg, self.vast_version, supported)
vastdb/schema.py CHANGED
@@ -91,6 +91,7 @@ class Schema:
  if use_external_row_ids_allocation:
  self.tx._rpc.features.check_external_row_ids_allocation()

+ table.Table.validate_ibis_support_schema(columns)
  self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid,
  use_external_row_ids_allocation=use_external_row_ids_allocation,
  sorting_key=sorting_key)
@@ -109,14 +110,14 @@ class Schema:
  log.debug("Found table: %s", t[0])
  return t[0]

- def _iter_tables(self, table_name=None):
+ def _iter_tables(self, table_name=None, page_size=1000):
  next_key = 0
  name_prefix = table_name if table_name else ""
  exact_match = bool(table_name)
  while True:
  _bucket_name, _schema_name, curr_tables, next_key, is_truncated, _ = \
  self.tx._rpc.api.list_tables(
- bucket=self.bucket.name, schema=self.name, next_key=next_key, txid=self.tx.txid,
+ bucket=self.bucket.name, schema=self.name, next_key=next_key, max_keys=page_size, txid=self.tx.txid,
  exact_match=exact_match, name_prefix=name_prefix, include_list_stats=exact_match)
  if not curr_tables:
  break
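
The page_size argument added here (and forwarded by tables() and tablenames() in the next hunk) only controls how many entries each list_tables RPC returns while _iter_tables paginates; the full listing is still returned. A short sketch, assuming an existing session from vastdb.connect() and placeholder bucket/schema names:

    with session.transaction() as tx:
        schema = tx.bucket('my-bucket').schema('my-schema')
        # Page through the listing 500 names per RPC instead of the default 1000.
        names = schema.tablenames(page_size=500)
        tables = schema.tables(page_size=500)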
@@ -124,19 +125,19 @@ class Schema:
  if not is_truncated:
  break

- def tables(self, table_name: str = "") -> List["Table"]:
+ def tables(self, table_name: str = "", page_size=1000) -> List["Table"]:
  """List all tables under this schema if `table_name` is empty.

  Otherwise, list only the specific table (if exists).
  """
  return [
  _parse_table_info(table_info, self)
- for table_info in self._iter_tables(table_name=table_name)
+ for table_info in self._iter_tables(table_name=table_name, page_size=page_size)
  ]

- def tablenames(self) -> List[str]:
+ def tablenames(self, page_size=1000) -> List[str]:
  """List all table names under this schema."""
- return [table_info.name for table_info in self._iter_tables()]
+ return [table_info.name for table_info in self._iter_tables(page_size=page_size)]

  def drop(self) -> None:
  """Delete this schema."""
vastdb/table.py CHANGED
@@ -1,9 +1,11 @@
  """VAST Database table."""

  import concurrent.futures
+ import copy
  import logging
  import os
  import queue
+ import sys
  from dataclasses import dataclass, field
  from math import ceil
  from threading import Event
@@ -124,11 +126,35 @@ class Table:
  _imports_table: bool
  sorted_table: bool

+ @staticmethod
+ def validate_ibis_support_schema(arrow_schema: pa.Schema):
+ """Validate that the provided Arrow schema is compatible with Ibis.
+
+ Raises NotSupportedSchema if the schema contains unsupported fields.
+ """
+ unsupported_fields = []
+ first_exception = None
+ for f in arrow_schema:
+ try:
+ ibis.Schema.from_pyarrow(pa.schema([f]))
+ except Exception as e:
+ if first_exception is None:
+ first_exception = e
+ unsupported_fields.append(f)
+
+ if unsupported_fields:
+ raise errors.NotSupportedSchema(
+ message=f"Ibis does not support the schema {unsupported_fields=}",
+ schema=arrow_schema,
+ cause=first_exception
+ )
+
  def __post_init__(self):
  """Also, load columns' metadata."""
  self.arrow_schema = self.columns()

  self._table_path = f'{self.schema.bucket.name}/{self.schema.name}/{self.name}'
+ self.validate_ibis_support_schema(self.arrow_schema)
  self._ibis_table = ibis.table(ibis.Schema.from_pyarrow(self.arrow_schema), self._table_path)

  @property
@@ -333,7 +359,8 @@ class Table:
  predicate: Union[ibis.expr.types.BooleanColumn, ibis.common.deferred.Deferred] = None,
  config: Optional[QueryConfig] = None,
  *,
- internal_row_id: bool = False) -> pa.RecordBatchReader:
+ internal_row_id: bool = False,
+ limit_rows: Optional[int] = None) -> pa.RecordBatchReader:
  """Execute a query over this table.

  To read a subset of the columns, specify their names via `columns` argument. Otherwise, all columns will be read.
@@ -342,15 +369,13 @@ class Table:

  Query-execution configuration options can be specified via the optional `config` argument.
  """
- if config is None:
- config = QueryConfig()
+ config = copy.deepcopy(config) if config else QueryConfig()
+
+ if limit_rows:
+ config.limit_rows_per_sub_split = limit_rows

- stats = None
- # Retrieve snapshots only if needed
  if config.data_endpoints is None:
- stats = self.get_stats()
- log.debug("stats: %s", stats)
- endpoints = stats.endpoints
+ endpoints = tuple([self.tx._rpc.api.url])
  else:
  endpoints = tuple(config.data_endpoints)
  log.debug("endpoints: %s", endpoints)
@@ -380,8 +405,7 @@ class Table:
  num_rows = self._get_row_estimate(columns, predicate, query_schema)
  log.debug(f'sorted estimate: {num_rows}')
  if num_rows == 0:
- if stats is None:
- stats = self.get_stats()
+ stats = self.get_stats()
  num_rows = stats.num_rows

  config.num_splits = max(1, num_rows // config.rows_per_split)
@@ -402,7 +426,7 @@ class Table:
  for split in range(config.num_splits):
  splits_queue.put(split)

- # this queue shouldn't be large it is marely a pipe through which the results
+ # this queue shouldn't be large it is merely a pipe through which the results
  # are sent to the main thread. Most of the pages actually held in the
  # threads that fetch the pages.
  record_batches_queue: queue.Queue[pa.RecordBatch] = queue.Queue(maxsize=2)
@@ -458,8 +482,9 @@ class Table:
  if config.query_id:
  threads_prefix = threads_prefix + "-" + config.query_id

+ total_num_rows = limit_rows if limit_rows else sys.maxsize
  with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp: # TODO: concurrency == enpoints is just a heuristic
- futures = [tp.submit(single_endpoint_worker, endpoint) for endpoint in endpoints]
+ futures = [tp.submit(single_endpoint_worker, endpoint) for endpoint in endpoints[:config.num_splits]]
  tasks_running = len(futures)
  try:
  while tasks_running > 0:
@@ -467,7 +492,14 @@ class Table:

  batch = record_batches_queue.get()
  if batch is not None:
- yield batch
+ if batch.num_rows < total_num_rows:
+ yield batch
+ total_num_rows -= batch.num_rows
+ else:
+ yield batch.slice(length=total_num_rows)
+ log.info("reached limit rows per query: %d - stop query", limit_rows)
+ stop_event.set()
+ break
  else:
  tasks_running -= 1
  log.debug("one worker thread finished, remaining: %d", tasks_running)
@@ -510,6 +542,9 @@ class Table:
  """Insert a RecordBatch into this table."""
  if self._imports_table:
  raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+ if 0 == rows.num_rows:
+ log.debug("Ignoring empty insert into %s", self.name)
+ return pa.chunked_array([], type=INTERNAL_ROW_ID_FIELD.type)
  try:
  row_ids = []
  serialized_slices = util.iter_serialized_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
@@ -522,7 +557,7 @@ class Table:
  self.tx._rpc.features.check_return_row_ids()
  except errors.NotSupportedVersion:
  return # type: ignore
- return pa.chunked_array(row_ids)
+ return pa.chunked_array(row_ids, type=INTERNAL_ROW_ID_FIELD.type)
  except errors.TooWideRow:
  self.tx._rpc.features.check_return_row_ids()
  return self.insert_in_column_batches(rows)
@@ -596,7 +631,7 @@ class Table:
  self.name = new_name

  def add_sorting_key(self, sorting_key: list) -> None:
- """Ads a sorting key to a table that doesn't have any."""
+ """Add a sorting key to a table that doesn't have any."""
  self.tx._rpc.features.check_elysium()
  self.tx._rpc.api.alter_table(
  self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, sorting_key=sorting_key)
@@ -606,6 +641,7 @@ class Table:
  """Add a new column."""
  if self._imports_table:
  raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+ self.validate_ibis_support_schema(new_column)
  self.tx._rpc.api.add_columns(self.bucket.name, self.schema.name, self.name, new_column, txid=self.tx.txid)
  log.info("Added column(s): %s", new_column)
  self.arrow_schema = self.columns()
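
Taken together, the table.py changes above add a client-side row cap to select(): limit_rows is copied into config.limit_rows_per_sub_split for the server request, and the batch loop slices the final batch and signals the stop event once the cap is reached. A usage sketch, assuming an existing session from vastdb.connect() and placeholder bucket/schema/table names:

    with session.transaction() as tx:
        t = tx.bucket('my-bucket').schema('my-schema').table('my-table')
        # At most 10,000 matching rows are yielded; remaining splits are cancelled.
        reader = t.select(columns=['sid'], predicate=(t['sid'] > 0), limit_rows=10_000)
        sample = reader.read_all()
        assert sample.num_rows <= 10_000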