vastdb 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vastdb/bucket.py CHANGED
@@ -6,7 +6,7 @@ It is possible to list and access VAST snapshots generated over a bucket.
 
 import logging
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Iterable, Optional
 
 from . import errors, schema, transaction
 
@@ -55,7 +55,7 @@ class Bucket:
 
         return Bucket(name=f'{self.name}/{expected_name}', tx=self.tx)
 
-    def snapshots(self) -> List["Bucket"]:
+    def snapshots(self) -> Iterable["Bucket"]:
        """List bucket's snapshots."""
        snapshots = []
        next_key = 0
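
Since `snapshots()` is now annotated as returning an `Iterable` rather than a `List`, callers that relied on `len()` or indexing should materialize the result explicitly. A minimal sketch of doing so, assuming placeholder credentials and a hypothetical bucket name:

    import vastdb

    session = vastdb.connect(access='...', secret='...', endpoint='http://vast-vip-pool:80')
    with session.transaction() as tx:
        snaps = list(tx.bucket('mybucket').snapshots())  # materialize before len()/indexing
        print(len(snaps))
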
vastdb/conftest.py CHANGED
@@ -9,11 +9,15 @@ import vastdb
 
 def pytest_addoption(parser):
     parser.addoption("--tabular-bucket-name", help="Name of the S3 bucket with Tabular enabled", default="vastdb")
-    parser.addoption("--tabular-access-key", help="Access key with Tabular permissions (AWS_ACCESS_KEY_ID)", default=os.environ.get("AWS_ACCESS_KEY_ID", None))
-    parser.addoption("--tabular-secret-key", help="Secret key with Tabular permissions (AWS_SECRET_ACCESS_KEY)", default=os.environ.get("AWS_SECRET_ACCESS_KEY", None))
+    parser.addoption("--tabular-access-key", help="Access key with Tabular permissions (AWS_ACCESS_KEY_ID)",
+                     default=os.environ.get("AWS_ACCESS_KEY_ID", None))
+    parser.addoption("--tabular-secret-key", help="Secret key with Tabular permissions (AWS_SECRET_ACCESS_KEY)",
+                     default=os.environ.get("AWS_SECRET_ACCESS_KEY", None))
     parser.addoption("--tabular-endpoint-url", help="Tabular server endpoint", default="http://localhost:9090")
     parser.addoption("--data-path", help="Data files location", default=None)
     parser.addoption("--crater-path", help="Save benchmark results in a dedicated location", default=None)
+    parser.addoption("--schema-name", help="Name of schema for the test to operate on", default=None)
+    parser.addoption("--table-name", help="Name of table for the test to operate on", default=None)
 
 
 @pytest.fixture(scope="session")
@@ -67,3 +71,13 @@ def parquets_path(request):
 @pytest.fixture(scope="function")
 def crater_path(request):
     return request.config.getoption("--crater-path")
+
+
+@pytest.fixture(scope="function")
+def schema_name(request):
+    return request.config.getoption("--schema-name")
+
+
+@pytest.fixture(scope="function")
+def table_name(request):
+    return request.config.getoption("--table-name")
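
The two new options feed matching function-scoped fixtures, so a test can be pointed at a pre-existing schema and table from the command line. A hypothetical consuming test (the fixture wiring is from this diff; the `bucket().schema().table()` accessors are assumed from the SDK's fluent API and are not part of this change):

    # run with: pytest --schema-name my_schema --table-name my_table
    import pytest

    def test_existing_table(session, schema_name, table_name):
        if not (schema_name and table_name):
            pytest.skip("requires --schema-name and --table-name")
        with session.transaction() as tx:
            t = tx.bucket("vastdb").schema(schema_name).table(table_name)  # "vastdb" is the default bucket option
            assert t.name == table_name
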
vastdb/errors.py CHANGED
@@ -175,6 +175,12 @@ class NotSupportedVersion(NotSupported):
     version: str
 
 
+@dataclass
+class ConnectionError(Exception):
+    cause: Exception
+    may_retry: bool
+
+
 def handle_unavailable(**kwargs):
     if kwargs['code'] == 'SlowDown':
         raise Slowdown(**kwargs)
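
`vastdb.errors.ConnectionError` wraps the underlying transport failure and records whether a retry is considered safe; it is what surfaces once the session-level backoff gives up (see the `test_bad_endpoint` change below). A minimal sketch of application-level handling, assuming an existing `session` and hypothetical bucket/schema/table names:

    import logging

    import vastdb.errors

    log = logging.getLogger(__name__)

    try:
        with session.transaction() as tx:
            rows = tx.bucket('mybucket').schema('s').table('t').select(columns=['a']).read_all()
    except vastdb.errors.ConnectionError as e:
        # `cause` carries the original exception; `may_retry` hints whether retrying is safe
        if e.may_retry:
            log.warning("retriable connection failure: %s", e.cause)
        else:
            raise
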
vastdb/schema.py CHANGED
@@ -6,7 +6,7 @@ It is possible to list and access VAST snapshots generated over a bucket.
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
 import pyarrow as pa
 
@@ -62,7 +62,7 @@ class Schema:
         assert len(names) == 1, f"Expected to receive only a single schema, but got {len(schemas)}: ({schemas})"
         return schema.Schema(name=self._subschema_full_name(names[0]), bucket=self.bucket)
 
-    def schemas(self, batch_size=None) -> List["Schema"]:
+    def schemas(self, batch_size=None) -> Iterable["Schema"]:
         """List child schemas."""
         next_key = 0
         if not batch_size:
@@ -76,14 +76,18 @@ class Schema:
                 break
         return result
 
-    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True) -> "Table":
+    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True, use_external_row_ids_allocation=False) -> "Table":
         """Create a new table under this schema."""
         if current := self.table(table_name, fail_if_missing=False):
             if fail_if_exists:
                 raise errors.TableExists(self.bucket.name, self.name, table_name)
             else:
                 return current
-        self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
+        if use_external_row_ids_allocation:
+            self.tx._rpc.features.check_external_row_ids_allocation()
+
+        self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid,
+                                      use_external_row_ids_allocation=use_external_row_ids_allocation)
         log.info("Created table: %s", table_name)
         return self.table(table_name)  # type: ignore[return-value]
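
Creating a table with externally allocated row IDs is gated behind a feature check (5.1+ clusters, per the `Features` change in session.py below). A minimal sketch of the new call, with hypothetical names:

    import pyarrow as pa

    columns = pa.schema([('a', pa.int32()), ('b', pa.float64())])
    with session.transaction() as tx:
        s = tx.bucket('mybucket').create_schema('my_schema')
        # the feature check raises if the cluster is older than 5.1
        t = s.create_table('my_table', columns, use_external_row_ids_allocation=True)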
89
93
 
vastdb/session.py CHANGED
@@ -9,10 +9,12 @@ For more details see:
 
 import logging
 import os
+from typing import Optional
 
 import boto3
 
-from . import errors, internal_commands, transaction
+from . import _internal, errors, transaction
+from ._internal import BackoffConfig
 
 log = logging.getLogger()
 
@@ -36,6 +38,10 @@ class Features:
             "Semi-sorted projection enforcement requires 5.1+ VAST release",
             vast_version >= (5, 1))
 
+        self.check_external_row_ids_allocation = self._check(
+            "External row IDs allocation requires 5.1+ VAST release",
+            vast_version >= (5, 1))
+
     def _check(self, msg, supported):
         log.debug("%s (current version is %s): supported=%s", msg, self.vast_version, supported)
         if not supported:
@@ -51,7 +57,10 @@ class Features:
 class Session:
     """VAST database session."""
 
-    def __init__(self, access=None, secret=None, endpoint=None, ssl_verify=True):
+    def __init__(self, access=None, secret=None, endpoint=None,
+                 *,
+                 ssl_verify=True,
+                 backoff_config: Optional[BackoffConfig] = None):
         """Connect to a VAST Database endpoint, using specified credentials."""
         if access is None:
             access = os.environ['AWS_ACCESS_KEY_ID']
@@ -60,9 +69,13 @@ class Session:
         if endpoint is None:
             endpoint = os.environ['AWS_S3_ENDPOINT_URL']
 
-        self.api = internal_commands.VastdbApi(endpoint, access, secret, ssl_verify=ssl_verify)
-        version_tuple = tuple(int(part) for part in self.api.vast_version.split('.'))
-        self.features = Features(version_tuple)
+        self.api = _internal.VastdbApi(
+            endpoint=endpoint,
+            access_key=access,
+            secret_key=secret,
+            ssl_verify=ssl_verify,
+            backoff_config=backoff_config)
+        self.features = Features(self.api.vast_version)
         self.s3 = boto3.client('s3',
                                aws_access_key_id=access,
                                aws_secret_access_key=secret,
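
The constructor is now keyword-only after the credentials, and retry behaviour is configured per session via `BackoffConfig` (re-exported from the renamed `_internal` module). A minimal sketch; `max_tries` is the only field exercised in the tests below, and the credentials/endpoint are placeholders:

    import vastdb
    from vastdb.session import BackoffConfig

    session = vastdb.connect(
        access='AKIA...',
        secret='...',
        endpoint='http://vast-vip-pool:80',
        backoff_config=BackoffConfig(max_tries=3),
    )
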
vastdb/table.py CHANGED
@@ -7,14 +7,13 @@ import queue
 from dataclasses import dataclass, field
 from math import ceil
 from threading import Event
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
-import backoff
 import ibis
 import pyarrow as pa
-import requests
+import urllib3
 
-from . import errors, internal_commands, schema, util
+from . import _internal, errors, schema, util
 
 log = logging.getLogger(__name__)
 
@@ -40,12 +39,6 @@ class TableStats:
     endpoints: Tuple[str, ...] = ()
 
 
-RETRIABLE_ERRORS = (
-    errors.Slowdown,
-    requests.exceptions.ConnectionError,
-)
-
-
 @dataclass
 class QueryConfig:
     """Query execution configiration."""
@@ -80,8 +73,10 @@ class QueryConfig:
     # used for worker threads' naming
     query_id: str = ""
 
-    # allows retrying QueryData when the server is overloaded
-    backoff_func: Any = field(default=backoff.on_exception(backoff.expo, RETRIABLE_ERRORS, max_tries=10))
+    # non-negative integer, used for server-side prioritization of queued requests:
+    # - requests with lower values will be served before requests with higher values.
+    # - if unset, the request will be added to the queue's end.
+    queue_priority: Optional[int] = None
 
 
 @dataclass
@@ -102,42 +97,58 @@ class SelectSplitState:
         self.query_data_request = query_data_request
         self.table = table
 
-    def batches(self, api: internal_commands.VastdbApi):
-        """Execute QueryData request, and yield parsed RecordBatch objects.
+    def process_split(self, api: _internal.VastdbApi, record_batches_queue: queue.Queue[pa.RecordBatch], check_stop: Callable):
+        """Execute a sequence of QueryData requests, and queue the parsed RecordBatch objects.
 
-        Can be called repeatedly, to allow pagination.
+        Can be called repeatedly, to support resuming the query after a disconnection / retriable error.
         """
-        while not self.done:
-            query_with_backoff = self.config.backoff_func(api.query_data)
-            response = query_with_backoff(
-                bucket=self.table.bucket.name,
-                schema=self.table.schema.name,
-                table=self.table.name,
-                params=self.query_data_request.serialized,
-                split=(self.split_id, self.config.num_splits, self.config.num_row_groups_per_sub_split),
-                num_sub_splits=self.config.num_sub_splits,
-                response_row_id=False,
-                txid=self.table.tx.txid,
-                limit_rows=self.config.limit_rows_per_sub_split,
-                sub_split_start_row_ids=self.subsplits_state.items(),
-                enable_sorted_projections=self.config.use_semi_sorted_projections,
-                query_imports_table=self.table._imports_table,
-                projection=self.config.semi_sorted_projection_name)
-            pages_iter = internal_commands.parse_query_data_response(
-                conn=response.raw,
-                schema=self.query_data_request.response_schema,
-                start_row_ids=self.subsplits_state,
-                parser=self.query_data_request.response_parser)
-
-            for page in pages_iter:
-                for batch in page.to_batches():
-                    if len(batch) > 0:
-                        yield batch
+        try:
+            # contains RecordBatch parts received from the server, must be re-created in case of a retry
+            while not self.done:
+                # raises if request parsing fails or throttled due to server load, and will be externally retried
+                response = api.query_data(
+                    bucket=self.table.bucket.name,
+                    schema=self.table.schema.name,
+                    table=self.table.name,
+                    params=self.query_data_request.serialized,
+                    split=(self.split_id, self.config.num_splits, self.config.num_row_groups_per_sub_split),
+                    num_sub_splits=self.config.num_sub_splits,
+                    response_row_id=False,
+                    txid=self.table.tx.txid,
+                    limit_rows=self.config.limit_rows_per_sub_split,
+                    sub_split_start_row_ids=self.subsplits_state.items(),
+                    schedule_id=self.config.queue_priority,
+                    enable_sorted_projections=self.config.use_semi_sorted_projections,
+                    query_imports_table=self.table._imports_table,
+                    projection=self.config.semi_sorted_projection_name)
+
+                # can raise during response parsing (e.g. due to disconnections), and will be externally retried
+                # the pagination state is stored in `self.subsplits_state` and must be correct in case of a reconnection
+                # the partial RecordBatch chunks are managed internally in `parse_query_data_response`
+                response_iter = _internal.parse_query_data_response(
+                    conn=response.raw,
+                    schema=self.query_data_request.response_schema,
+                    parser=self.query_data_request.response_parser)
+
+                for stream_id, next_row_id, table_chunk in response_iter:
+                    # in case of I/O error, `response_iter` will be closed and an appropriate exception will be thrown.
+                    self.subsplits_state[stream_id] = next_row_id
+                    # we have parsed a pyarrow.Table successfully, self.subsplits_state is now correctly updated
+                    # if the below loop fails, the query is not retried
+                    for batch in table_chunk.to_batches():
+                        check_stop()  # may raise StoppedException to early-exit the query (without retries)
+                        if batch:
+                            record_batches_queue.put(batch)
+        except urllib3.exceptions.ProtocolError as err:
+            log.warning("Failed parsing QueryData response table=%r split=%s/%s offsets=%s cause=%s",
+                        self.table, self.split_id, self.config.num_splits, self.subsplits_state, err)
+            # since this is a read-only idempotent operation, it is safe to retry
+            raise errors.ConnectionError(cause=err, may_retry=True)
 
     @property
     def done(self):
         """Returns true iff the pagination over."""
-        return all(row_id == internal_commands.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+        return all(row_id == _internal.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
 
 
 @dataclass
@@ -187,14 +198,14 @@ class Table:
         """Get a specific semi-sorted projection of this table."""
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
-        projs = self.projections(projection_name=name)
+        projs = tuple(self.projections(projection_name=name))
         if not projs:
             raise errors.MissingProjection(self.bucket.name, self.schema.name, self.name, name)
         assert len(projs) == 1, f"Expected to receive only a single projection, but got: {len(projs)}. projections: {projs}"
         log.debug("Found projection: %s", projs[0])
         return projs[0]
 
-    def projections(self, projection_name=None) -> List["Projection"]:
+    def projections(self, projection_name=None) -> Iterable["Projection"]:
         """List all semi-sorted projections of this table."""
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
@@ -214,7 +225,7 @@ class Table:
                 break
         return [_parse_projection_info(projection, self) for projection in projections]
 
-    def import_files(self, files_to_import: List[str], config: Optional[ImportConfig] = None) -> None:
+    def import_files(self, files_to_import: Iterable[str], config: Optional[ImportConfig] = None) -> None:
         """Import a list of Parquet files into this table.
 
         The files must be on VAST S3 server and be accessible using current credentials.
@@ -283,7 +294,7 @@ class Table:
                 max_workers=config.import_concurrency, thread_name_prefix='import_thread') as pool:
             try:
                 for endpoint in endpoints:
-                    session = internal_commands.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
+                    session = _internal.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
                     futures.append(pool.submit(import_worker, files_queue, session))
 
                 log.debug("Waiting for import workers to finish")
@@ -316,10 +327,15 @@ class Table:
         if config is None:
             config = QueryConfig()
 
-        # Take a snapshot of enpoints
-        stats = self.get_stats()
-        log.debug("stats: %s", stats)
-        endpoints = stats.endpoints if config.data_endpoints is None else config.data_endpoints
+        # Retrieve snapshots only if needed
+        if config.data_endpoints is None or config.num_splits is None:
+            stats = self.get_stats()
+            log.debug("stats: %s", stats)
+
+        if config.data_endpoints is None:
+            endpoints = stats.endpoints
+        else:
+            endpoints = tuple(config.data_endpoints)
         log.debug("endpoints: %s", endpoints)
 
         if config.num_splits is None:
@@ -342,13 +358,13 @@
         if predicate is True:
             predicate = None
         if predicate is False:
-            response_schema = internal_commands.get_response_schema(schema=query_schema, field_names=columns)
+            response_schema = _internal.get_response_schema(schema=query_schema, field_names=columns)
             return pa.RecordBatchReader.from_batches(response_schema, [])
 
         if isinstance(predicate, ibis.common.deferred.Deferred):
             predicate = predicate.resolve(self._ibis_table)  # may raise if the predicate is invalid (e.g. wrong types / missing column)
 
-        query_data_request = internal_commands.build_query_data_request(
+        query_data_request = _internal.build_query_data_request(
             schema=query_schema,
             predicate=predicate,
             field_names=columns)
@@ -376,7 +392,8 @@
 
         def single_endpoint_worker(endpoint: str):
             try:
-                host_api = internal_commands.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
+                host_api = _internal.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
+                backoff_decorator = self.tx._rpc.api._backoff_decorator
                 while True:
                     check_stop()
                     try:
@@ -390,9 +407,9 @@
                             split_id=split,
                             config=config)
 
-                        for batch in split_state.batches(host_api):
-                            check_stop()
-                            record_batches_queue.put(batch)
+                        process_with_retries = backoff_decorator(split_state.process_split)
+                        process_with_retries(host_api, record_batches_queue, check_stop)
+
                     except StoppedException:
                         log.debug("stop signal.", exc_info=True)
                         return
@@ -473,7 +490,7 @@
         for slice in serialized_slices:
             res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
                                                txid=self.tx.txid)
-            (batch,) = pa.RecordBatchStreamReader(res.raw)
+            (batch,) = pa.RecordBatchStreamReader(res.content)
             row_ids.append(batch[INTERNAL_ROW_ID])
         try:
             self.tx._rpc.features.check_return_row_ids()
@@ -509,6 +526,8 @@
         else:
             update_rows_rb = rows
 
+        update_rows_rb = util.sort_record_batch_if_needed(update_rows_rb, INTERNAL_ROW_ID)
+
         serialized_slices = util.iter_serialized_slices(update_rows_rb, MAX_ROWS_PER_BATCH)
         for slice in serialized_slices:
             self.tx._rpc.api.update_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
@@ -528,6 +547,8 @@
         delete_rows_rb = pa.record_batch(schema=pa.schema([(INTERNAL_ROW_ID, pa.uint64())]),
                                          data=[_combine_chunks(rows_chunk)])
 
+        delete_rows_rb = util.sort_record_batch_if_needed(delete_rows_rb, INTERNAL_ROW_ID)
+
         serialized_slices = util.iter_serialized_slices(delete_rows_rb, MAX_ROWS_PER_BATCH)
         for slice in serialized_slices:
             self.tx._rpc.api.delete_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
@@ -593,7 +614,7 @@
         return self.imports_table()  # type: ignore[return-value]
 
     def imports_table(self) -> Optional["Table"]:
-        """Get the imports table under of this table."""
+        """Get the imports table of this table."""
         self.tx._rpc.features.check_imports_table()
         return Table(name=self.name, schema=self.schema, handle=int(self.handle), stats=self.stats, _imports_table=True)
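
With this change, retries are driven by the session-level backoff decorator instead of a per-query `backoff_func`, and `QueryConfig.queue_priority` is forwarded as the QueryData `schedule_id`. A minimal sketch of a prioritized select, mirroring the `test_select_with_priority` test further below (bucket/schema/table names are hypothetical):

    from vastdb.table import QueryConfig

    config = QueryConfig()
    config.queue_priority = 0   # lower values are served first; None appends to the queue's end
    with session.transaction() as tx:
        t = tx.bucket('mybucket').schema('s').table('t')
        result = t.select(columns=['a'], config=config).read_all()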
 
@@ -56,6 +56,6 @@ def test_closed_tx(session, clean_bucket_name):
         res = conn.execute('SELECT a FROM batches')
         log.debug("closing tx=%s after first batch=%s", t.tx, first)
 
-    # transaction is closed, collecting the result should fail
-    with pytest.raises(duckdb.InvalidInputException, match="Detail: Python exception: MissingTransaction"):
+    # transaction is closed, collecting the result should fail internally in DuckDB
+    with pytest.raises(duckdb.InvalidInputException):
         res.arrow()
@@ -105,7 +105,11 @@ def test_query_data_with_projection(session, clean_bucket_name):
         t = s.table(table_name)
         projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
         # no projection supply - need to be with p1 projeciton
-        assert expected_projection_p1 == projection_actual
+        # doing this since we also run this test against production clusters
+        if expected_projection_p1 != projection_actual:
+            config.num_row_groups_per_sub_split = 8
+            projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
+        assert expected_projection_p1 == projection_actual
 
         config.semi_sorted_projection_name = 'p1'
         projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
@@ -5,9 +5,8 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
 from itertools import cycle
 
 import pytest
-import requests
 
-import vastdb
+import vastdb.errors
 
 log = logging.getLogger(__name__)
 
@@ -25,8 +24,9 @@ def test_bad_credentials(session):
 
 
 def test_bad_endpoint(session):
-    with pytest.raises(requests.exceptions.ConnectionError):
-        vastdb.connect(access='BAD', secret='BAD', endpoint='http://invalid-host-name-for-tests:12345')
+    backoff_config = vastdb.session.BackoffConfig(max_tries=3)
+    with pytest.raises(vastdb.errors.ConnectionError):
+        vastdb.connect(access='BAD', secret='BAD', endpoint='http://invalid-host-name-for-tests:12345', backoff_config=backoff_config)
 
 
 def test_version_extraction():
@@ -36,7 +36,7 @@ def test_version_extraction():
         ("5", None),  # major
         ("5.2", None),  # major.minor
         ("5.2.0", None),  # major.minor.patch
-        ("5.2.0.10", "5.2.0.10"),  # major.minor.patch.protocol
+        ("5.2.0.10", (5, 2, 0, 10)),  # major.minor.patch.protocol
         ("5.2.0.10 some other things", None),  # suffix
         ("5.2.0.10.20", None),  # extra version
     ]
@@ -58,7 +58,7 @@ def test_tables(session, clean_bucket_name):
         }
 
         columns_to_delete = pa.schema([(INTERNAL_ROW_ID, pa.uint64())])
-        rb = pa.record_batch(schema=columns_to_delete, data=[[0]])  # delete rows 0,1
+        rb = pa.record_batch(schema=columns_to_delete, data=[[0]])  # delete row 0
         t.delete(rb)
 
         selected_rows = t.select(columns=['b'], predicate=(t['a'] == 222), internal_row_id=True).read_all()
@@ -81,6 +81,19 @@ def test_insert_wide_row(session, clean_bucket_name):
     assert actual == expected
 
 
+def test_insert_empty(session, clean_bucket_name):
+    columns = pa.schema([('a', pa.int8()), ('b', pa.float32())])
+    data = [[None] * 5, [None] * 5]
+    all_nulls = pa.table(schema=columns, data=data)
+    no_columns = all_nulls.select([])
+
+    with session.transaction() as tx:
+        t = tx.bucket(clean_bucket_name).create_schema('s').create_table('t', columns)
+        t.insert(all_nulls)
+        with pytest.raises(errors.NotImplemented):
+            t.insert(no_columns)
+
+
 def test_exists(session, clean_bucket_name):
     with session.transaction() as tx:
         s = tx.bucket(clean_bucket_name).create_schema('s1')
@@ -156,6 +169,27 @@ def test_update_table(session, clean_bucket_name):
             'b': [0.5, 1.5, 2.5]
         }
 
+        # test update for not sorted rows:
+        rb = pa.record_batch(schema=columns_to_update, data=[
+            [2, 0],  # update rows 0,2
+            [231, 235]
+        ])
+        t.update(rb)
+        actual = t.select(columns=['a', 'b']).read_all()
+        assert actual.to_pydict() == {
+            'a': [235, 2222, 231],
+            'b': [0.5, 1.5, 2.5]
+        }
+
+        # test delete for not sorted rows:
+        rb = pa.record_batch(schema=pa.schema([(INTERNAL_ROW_ID, pa.uint64())]), data=[[2, 0]])
+        t.delete(rb)
+        actual = t.select(columns=['a', 'b']).read_all()
+        assert actual.to_pydict() == {
+            'a': [2222],
+            'b': [1.5]
+        }
+
 
 def test_select_with_multisplits(session, clean_bucket_name):
     columns = pa.schema([
@@ -174,6 +208,25 @@ def test_select_with_multisplits(session, clean_bucket_name):
     assert actual == expected
 
 
+def test_select_with_priority(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int32())
+    ])
+    expected = pa.table(schema=columns, data=[range(100)])
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+        config = QueryConfig()
+
+        config.queue_priority = 0
+        assert t.select(config=config).read_all() == expected
+
+        config.queue_priority = 12345
+        assert t.select(config=config).read_all() == expected
+
+        config.queue_priority = -1
+        with pytest.raises(errors.BadRequest):
+            t.select(config=config).read_all()
+
+
 def test_types(session, clean_bucket_name):
     columns = pa.schema([
         ('tb', pa.bool_()),
vastdb/tests/test_util.py CHANGED
@@ -33,6 +33,12 @@ def test_wide_row():
         list(util.iter_serialized_slices(t))
 
 
+def test_expand_ip_ranges():
+    endpoints = ["http://172.19.101.1-3"]
+    expected = ["http://172.19.101.1", "http://172.19.101.2", "http://172.19.101.3"]
+    assert util.expand_ip_ranges(endpoints) == expected
+
+
 def _parse(bufs):
     for buf in bufs:
         with pa.ipc.open_stream(buf) as reader:
vastdb/transaction.py CHANGED
@@ -8,7 +8,7 @@ A transcation is used as a context manager, since every Database-related operati
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Iterable, Optional
 
 import botocore
 
@@ -72,7 +72,7 @@ class Transaction:
             raise
         return bucket.Bucket(name, self)
 
-    def catalog_snapshots(self) -> List["Bucket"]:
+    def catalog_snapshots(self) -> Iterable["Bucket"]:
         """Return VAST Catalog bucket snapshots."""
         return bucket.Bucket(VAST_CATALOG_BUCKET_NAME, self).snapshots()
 
vastdb/util.py CHANGED
@@ -1,7 +1,9 @@
 import logging
+import re
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.parquet as pq
 
 from .errors import InvalidArgument, TooWideRow
@@ -88,8 +90,11 @@ MAX_QUERY_DATA_REQUEST_SIZE = int(0.9 * MAX_TABULAR_REQUEST_SIZE)
 
 def iter_serialized_slices(batch: Union[pa.RecordBatch, pa.Table], max_rows_per_slice=None):
     """Iterate over a list of record batch slices."""
+    if batch.nbytes:
+        rows_per_slice = int(0.9 * len(batch) * MAX_RECORD_BATCH_SLICE_SIZE / batch.nbytes)
+    else:
+        rows_per_slice = len(batch)  # if the batch has no buffers (no rows/columns)
 
-    rows_per_slice = int(0.9 * len(batch) * MAX_RECORD_BATCH_SLICE_SIZE / batch.nbytes)
     if max_rows_per_slice is not None:
         rows_per_slice = min(rows_per_slice, max_rows_per_slice)
 
@@ -113,3 +118,37 @@ def serialize_record_batch(batch: Union[pa.RecordBatch, pa.Table]):
     with pa.ipc.new_stream(sink, batch.schema) as writer:
         writer.write(batch)
     return sink.getvalue()
+
+
+def expand_ip_ranges(endpoints):
+    """Expands endpoint strings that include an IP range in the format 'http://172.19.101.1-16'."""
+    expanded_endpoints = []
+    pattern = re.compile(r"(http://\d+\.\d+\.\d+)\.(\d+)-(\d+)")
+
+    for endpoint in endpoints:
+        match = pattern.match(endpoint)
+        if match:
+            base_url = match.group(1)
+            start_ip = int(match.group(2))
+            end_ip = int(match.group(3))
+            if start_ip > end_ip:
+                raise ValueError("Start IP cannot be greater than end IP in the range.")
+            expanded_endpoints.extend(f"{base_url}.{ip}" for ip in range(start_ip, end_ip + 1))
+        else:
+            expanded_endpoints.append(endpoint)
+    return expanded_endpoints
+
+
+def is_sorted(arr):
+    """Check if the array is sorted."""
+    return pc.all(pc.greater(arr[1:], arr[:-1])).as_py()
+
+
+def sort_record_batch_if_needed(record_batch, sort_column):
+    """Sort the RecordBatch by the specified column if it is not already sorted."""
+    column_data = record_batch[sort_column]
+
+    if not is_sorted(column_data):
+        return record_batch.sort_by(sort_column)
+    else:
+        return record_batch
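
`update()` and `delete()` in table.py now pass their row-id batches through `sort_record_batch_if_needed` before slicing. A small illustrative usage with made-up column names (the helpers themselves are from this diff):

    import pyarrow as pa

    from vastdb import util

    rb = pa.record_batch(schema=pa.schema([('row_id', pa.uint64()), ('v', pa.int32())]),
                         data=[[2, 0, 1], [20, 0, 10]])
    sorted_rb = util.sort_record_batch_if_needed(rb, 'row_id')
    assert util.is_sorted(sorted_rb['row_id'])  # strictly increasing row IDs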